Classification Report

Module Interface

class torchmetrics.ClassificationReport(**kwargs)[source]

Compute precision, recall, F-measure and support for each class.

\[
\begin{aligned}
\text{Precision} &= \frac{TP}{TP + FP} \\
\text{Recall} &= \frac{TP}{TP + FN} \\
\text{F1} &= 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} \\
\text{Support} &= \sum_{i=1}^{N} 1(y_i = k)
\end{aligned}
\]

Where \(TP\) is true positives, \(FP\) is false positives, \(FN\) is false negatives, \(y\) is a tensor of target values, \(k\) is the class, and \(N\) is the number of samples.

This module is a simple wrapper that dispatches to the task-specific version of this metric, selected by setting the task argument to 'binary', 'multiclass' or 'multilabel'. See the documentation of BinaryClassificationReport, MulticlassClassificationReport and MultilabelClassificationReport for details on how each argument influences the metric and for examples.

Example (Binary Classification):
>>> from torch import tensor
>>> from torchmetrics.classification import ClassificationReport
>>> target = tensor([0, 1, 0, 1])
>>> preds = tensor([0, 1, 1, 1])
>>> target_names = ['0', '1']
>>> report = ClassificationReport(
...     task="binary",
...     target_names=target_names,
...     digits=2
... )
>>> report.update(preds, target)
>>> print(report.compute()) 
                      precision  recall f1-score support

0                          1.00    0.50     0.67       2
1                          0.67    1.00     0.80       2

accuracy                                    0.75       4
macro avg                  0.83    0.75     0.73       4
weighted avg               0.83    0.75     0.73       4
static __new__(cls, task, threshold=0.5, num_classes=None, num_labels=None, target_names=None, sample_weight=None, digits=2, output_dict=False, zero_division='warn', ignore_index=None, top_k=1, **kwargs)[source]

Initialize task metric.

Return type:

Metric
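
Because the wrapper returns a standard Metric (see the return type above), the report can be accumulated over several batches before being rendered. A minimal sketch of that workflow, relying only on the usual update/compute/reset interface of torchmetrics metrics; the batch tensors are illustrative:

>>> from torch import tensor
>>> from torchmetrics import ClassificationReport
>>> report = ClassificationReport(task="multiclass", num_classes=3, digits=2)
>>> # metric state accumulates across update() calls
>>> report.update(tensor([0, 1, 2]), tensor([0, 1, 1]))
>>> report.update(tensor([2, 0]), tensor([2, 0]))
>>> summary = report.compute()  # formatted string report over all batches seen
>>> report.reset()              # clear the state before the next epoch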

BinaryClassificationReport

class torchmetrics.classification.BinaryClassificationReport(threshold=0.5, target_names=None, sample_weight=None, digits=2, output_dict=False, zero_division='warn', ignore_index=None, **kwargs)[source]

Compute precision, recall, F-measure and support for binary classification tasks.

The classification report provides detailed metrics for each class in a binary classification task: precision, recall, F1-score, and support.

\[
\begin{aligned}
\text{Precision} &= \frac{TP}{TP + FP} \\
\text{Recall} &= \frac{TP}{TP + FN} \\
\text{F1} &= 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} \\
\text{Support} &= \sum_{i=1}^{N} 1(y_i = k)
\end{aligned}
\]

Where \(TP\) is true positives, \(FP\) is false positives, \(FN\) is false negatives, \(y\) is a tensor of target values, \(k\) is the class, and \(N\) is the number of samples.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A tensor of predictions of shape (N, ...) where N is the batch size. If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and will automatically apply sigmoid per element. The result is then converted to an int tensor by thresholding with the value given in threshold.

  • target (Tensor): A tensor of targets of shape (N, ...) where N is the batch size.

As output to forward and compute the metric returns either:

  • A formatted string report if output_dict=False

  • A dictionary of metrics if output_dict=True

Parameters:
  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • target_names (Optional[Sequence[str]]) – Optional list of names for each class

  • sample_weight (Optional[Tensor]) – Optional weights for each sample

  • digits (int) – Number of decimal places to display in the report

  • output_dict (bool) – If True, return a dict instead of a string report

  • zero_division (Union[str, int]) – Value to use when a metric's denominator is zero (e.g. precision for a class that receives no predictions)

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

Example

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryClassificationReport
>>> target = tensor([0, 1, 0, 1])
>>> preds = tensor([0, 1, 1, 1])
>>> target_names = ['0', '1']
>>> report = BinaryClassificationReport(
...     target_names=target_names,
...     digits=2
... )
>>> report.update(preds, target)
>>> print(report.compute())
                      precision  recall f1-score support

0                          1.00    0.50     0.67       2
1                          0.67    1.00     0.80       2

accuracy                                    0.75       4
macro avg                  0.83    0.75     0.73       4
weighted avg               0.83    0.75     0.73       4
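
The example above passes hard 0/1 predictions. As described in the input section, floating point predictions outside the [0, 1] range are treated as logits; a minimal sketch of that case, with illustrative logit values:

>>> from torch import tensor
>>> from torchmetrics.classification import BinaryClassificationReport
>>> logits = tensor([-2.0, 1.5, 0.3, 4.0])  # outside [0, 1] -> interpreted as logits
>>> target = tensor([0, 1, 0, 1])
>>> metric = BinaryClassificationReport(threshold=0.5)
>>> metric.update(logits, target)           # sigmoid then thresholding at 0.5
>>> summary = metric.compute()              # formatted string report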

MulticlassClassificationReport

class torchmetrics.classification.MulticlassClassificationReport(num_classes, target_names=None, sample_weight=None, digits=2, output_dict=False, zero_division='warn', ignore_index=None, top_k=1, **kwargs)[source]

Compute precision, recall, F-measure and support for multiclass classification tasks.

The classification report provides detailed metrics for each class in a multiclass classification task: precision, recall, F1-score, and support.

\[
\begin{aligned}
\text{Precision} &= \frac{TP}{TP + FP} \\
\text{Recall} &= \frac{TP}{TP + FN} \\
\text{F1} &= 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} \\
\text{Support} &= \sum_{i=1}^{N} 1(y_i = k)
\end{aligned}
\]

Where \(TP\) is true positives, \(FP\) is false positives, \(FN\) is false negatives, \(y\) is a tensor of target values, \(k\) is the class, and \(N\) is the number of samples.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A tensor of predictions. If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and will automatically apply softmax per sample. The result is then converted to an int tensor by taking the argmax over the class dimension.

  • target (Tensor): A tensor of integer targets.

As output to forward and compute the metric returns either:

  • A formatted string report if output_dict=False

  • A dictionary of metrics if output_dict=True

Parameters:
  • num_classes (int) – Number of classes in the dataset

  • target_names (Optional[Sequence[str]]) – Optional list of names for each class

  • sample_weight (Optional[Tensor]) – Optional weights for each sample

  • digits (int) – Number of decimal places to display in the report

  • output_dict (bool) – If True, return a dict instead of a string report

  • zero_division (Union[str, int]) – Value to use when a metric's denominator is zero (e.g. precision for a class that receives no predictions)

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassClassificationReport
>>> target = tensor([0, 1, 2, 2, 2])
>>> preds = tensor([0, 0, 2, 2, 1])
>>> target_names = ["class 0", "class 1", "class 2"]
>>> report = MulticlassClassificationReport(
...     num_classes=3,
...     target_names=target_names,
...     digits=2
... )
>>> report.update(preds, target)
>>> print(report.compute())
                    precision  recall f1-score support

class 0                  0.50    1.00     0.67       1
class 1                  0.00    0.00     0.00       1
class 2                  1.00    0.67     0.80       3

accuracy                                  0.60       5
macro avg                0.50    0.56     0.49       5
weighted avg             0.70    0.60     0.61       5
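
The example above passes integer class predictions. As described in the input section, per-class probabilities (or logits) can be passed instead, in which case the predicted class is obtained by argmax; a minimal sketch with illustrative probabilities:

>>> from torch import tensor
>>> from torchmetrics.classification import MulticlassClassificationReport
>>> probs = tensor([[0.7, 0.2, 0.1],
...                 [0.6, 0.3, 0.1],
...                 [0.1, 0.2, 0.7],
...                 [0.2, 0.1, 0.7],
...                 [0.1, 0.8, 0.1]])  # one row of class scores per sample
>>> target = tensor([0, 1, 2, 2, 2])
>>> metric = MulticlassClassificationReport(num_classes=3, digits=2)
>>> metric.update(probs, target)        # argmax over the class dimension
>>> summary = metric.compute()          # formatted string report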

MultilabelClassificationReport

class torchmetrics.classification.MultilabelClassificationReport(num_labels, target_names=None, threshold=0.5, sample_weight=None, digits=2, output_dict=False, zero_division='warn', ignore_index=None, **kwargs)[source]

Compute precision, recall, F-measure and support for multilabel classification tasks.

The classification report provides detailed metrics for each class in a multilabel classification task: precision, recall, F1-score, and support.

\[
\begin{aligned}
\text{Precision} &= \frac{TP}{TP + FP} \\
\text{Recall} &= \frac{TP}{TP + FN} \\
\text{F1} &= 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} \\
\text{Support} &= \sum_{i=1}^{N} 1(y_i = k)
\end{aligned}
\]

Where \(TP\) is true positives, \(FP\) is false positives, \(FN\) is false negatives, \(y\) is a tensor of target values, \(k\) is the class, and \(N\) is the number of samples.

As input to forward and update the metric accepts the following input:

  • preds (Tensor): A tensor of predictions of shape (N, C) where N is the batch size and C is the number of labels. If preds is a floating point tensor with values outside the [0, 1] range, we consider the input to be logits and will automatically apply sigmoid per element. The result is then converted to an int tensor by thresholding with the value given in threshold.

  • target (Tensor): A tensor of targets of shape (N, C) where N is the batch size and C is the number of labels.

As output to forward and compute the metric returns either:

  • A formatted string report if output_dict=False

  • A dictionary of metrics if output_dict=True

Parameters:
  • num_labels (int) – Number of labels in the dataset

  • target_names (Optional[Sequence[str]]) – Optional list of names for each label

  • threshold (float) – Threshold for transforming probability to binary (0,1) predictions

  • sample_weight (Optional[Tensor]) – Optional weights for each sample

  • digits (int) – Number of decimal places to display in the report

  • output_dict (bool) – If True, return a dict instead of a string report

  • zero_division (Union[str, int]) – Value to use when a metric's denominator is zero (e.g. precision for a class that receives no predictions)

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

Example

>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelClassificationReport
>>> target = tensor([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
>>> preds = tensor([[1, 0, 1], [0, 1, 1], [1, 0, 0]])
>>> target_names = ["Label A", "Label B", "Label C"]
>>> report = MultilabelClassificationReport(
...     num_labels=len(target_names),
...     target_names=target_names,
...     digits=2,
... )
>>> report.update(preds, target)
>>> print(report.compute())
                    precision  recall f1-score support

Label A                  1.00    1.00     1.00       2
Label B                  1.00    0.50     0.67       2
Label C                  0.50    1.00     0.67       1

micro avg                0.80    0.80     0.80       5
macro avg                0.83    0.83     0.78       5
weighted avg             0.90    0.80     0.80       5
samples avg              0.83    0.83     0.78       5
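
As in the binary case, floating point predictions are thresholded per element (after a sigmoid if they fall outside [0, 1]). A minimal sketch with illustrative probabilities:

>>> from torch import tensor
>>> from torchmetrics.classification import MultilabelClassificationReport
>>> probs = tensor([[0.9, 0.1, 0.8],
...                 [0.2, 0.7, 0.6],
...                 [0.8, 0.3, 0.4]])   # one score per sample and label
>>> target = tensor([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
>>> metric = MultilabelClassificationReport(num_labels=3, threshold=0.5)
>>> metric.update(probs, target)        # elementwise thresholding at 0.5
>>> summary = metric.compute()          # formatted string report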

Functional Interface

torchmetrics.functional.classification.classification_report(preds, target, task, threshold=0.5, num_classes=None, num_labels=None, target_names=None, digits=2, output_dict=False, zero_division=0.0, ignore_index=None, validate_args=True, labels=None, top_k=1)[source]

Compute a classification report for various classification tasks.

The classification report shows the precision, recall, F1 score, and support for each class/label.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with ground truth labels

  • task (Literal['binary', 'multiclass', 'multilabel']) – The classification task - either ‘binary’, ‘multiclass’, or ‘multilabel’

  • threshold (float) – Threshold for converting probabilities to binary predictions (for binary and multilabel tasks)

  • num_classes (Optional[int]) – Number of classes (for multiclass tasks)

  • num_labels (Optional[int]) – Number of labels (for multilabel tasks)

  • target_names (Optional[List[str]]) – Optional list of names for the classes/labels

  • digits (int) – Number of decimal places to display in the report

  • output_dict (bool) – If True, return a dict instead of a string report

  • zero_division (Union[str, float]) – Value to use when a metric's denominator is zero (e.g. precision for a class that receives no predictions)

  • ignore_index (Optional[int]) – Optional index to ignore in the target (for multiclass tasks)

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness

  • labels (Optional[List[int]]) – Optional list of label indices to include in the report (for multiclass tasks)

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits and task is ‘multiclass’.

Return type:

Union[str, Dict[str, Union[float, Dict[str, Union[float, int]]]]]

Returns:

If output_dict=True, a dictionary with the classification report data. Otherwise, a formatted string with the classification report.

Examples

>>> from torch import tensor
>>> from torchmetrics.functional.classification.classification_report import classification_report
>>>
>>> # Binary classification example
>>> binary_target = tensor([0, 1, 0, 1])
>>> binary_preds = tensor([0, 1, 1, 1])
>>> binary_report = classification_report(
...     preds=binary_preds,
...     target=binary_target,
...     task="binary",
...     target_names=['Class 0', 'Class 1'],
...     digits=2
... )
>>> print(binary_report) 
                      precision  recall f1-score support

Class 0                     1.00    0.50     0.67       2
Class 1                     0.67    1.00     0.80       2

accuracy                                    0.75       4
macro avg                   0.83    0.75     0.73       4
weighted avg                0.83    0.75     0.73       4
>>>
>>> # Multiclass classification example
>>> multiclass_target = tensor([0, 1, 2, 2, 2])
>>> multiclass_preds = tensor([0, 0, 2, 2, 1])
>>> multiclass_report = classification_report(
...     preds=multiclass_preds,
...     target=multiclass_target,
...     task="multiclass",
...     num_classes=3,
...     target_names=["Class 0", "Class 1", "Class 2"],
...     digits=2
... )
>>> print(multiclass_report) 
                      precision  recall f1-score support

Class 0                    0.50    1.00     0.67       1
Class 1                    0.00    0.00     0.00       1
Class 2                    1.00    0.67     0.80       3

accuracy                                    0.60       5
macro avg                  0.50    0.56     0.49       5
weighted avg               0.70    0.60     0.61       5
>>>
>>> # Multilabel classification example
>>> multilabel_target = tensor([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
>>> multilabel_preds = tensor([[1, 0, 1], [0, 1, 1], [1, 0, 0]])
>>> multilabel_report = classification_report(
...     preds=multilabel_preds,
...     target=multilabel_target,
...     task="multilabel",
...     num_labels=3,
...     target_names=["Label A", "Label B", "Label C"],
...     digits=2
... )
>>> print(multilabel_report) 
                      precision  recall f1-score support

Label A                    1.00    1.00     1.00       2
Label B                    1.00    0.50     0.67       2
Label C                    0.50    1.00     0.67       1

micro avg                  0.80    0.80     0.80       5
macro avg                  0.83    0.83     0.78       5
weighted avg               0.90    0.80     0.80       5
samples avg                0.83    0.83     0.78       5
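
With output_dict=True the function returns the nested mapping described under Return type above: top-level entries are either aggregate floats or per-class dictionaries of float/int values. A minimal sketch reusing the binary tensors from the example above:

>>> report_data = classification_report(
...     preds=binary_preds,
...     target=binary_target,
...     task="binary",
...     target_names=['Class 0', 'Class 1'],
...     output_dict=True,
... )
>>> isinstance(report_data, dict)
True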

binary_classification_report

torchmetrics.functional.classification.binary_classification_report(preds, target, threshold=0.5, target_names=None, digits=2, output_dict=False, zero_division=0.0, ignore_index=None, validate_args=True)[source]

Compute a classification report for binary classification tasks.

The classification report shows the precision, recall, F1 score, and support for each class.

Parameters:
  • preds (Tensor) – Tensor with predictions

  • target (Tensor) – Tensor with ground truth labels

  • threshold (float) – Threshold for converting probabilities to binary predictions

  • target_names (Optional[List[str]]) – Optional list of names for the classes

  • digits (int) – Number of decimal places to display in the report

  • output_dict (bool) – If True, return a dict instead of a string report

  • zero_division (Union[str, float]) – Value to use when a metric's denominator is zero (e.g. precision for a class that receives no predictions)

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness

Return type:

Union[str, Dict[str, Union[float, Dict[str, Union[float, int]]]]]

Returns:

If output_dict=True, a dictionary with the classification report data. Otherwise, a formatted string with the classification report.

Example

>>> from torch import tensor
>>> from torchmetrics.functional.classification.classification_report import binary_classification_report
>>> target = tensor([0, 1, 0, 1])
>>> preds = tensor([0, 1, 1, 1])
>>> target_names = ['0', '1']
>>> report = binary_classification_report(
...     preds=preds,
...     target=target,
...     target_names=target_names,
...     digits=2
... )
>>> print(report) 
                      precision  recall f1-score support

0                          1.00    0.50     0.67       2
1                          0.67    1.00     0.80       2

accuracy                                    0.75       4
macro avg                  0.83    0.75     0.73       4
weighted avg               0.83    0.75     0.73       4
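
The ignore_index argument can be used to drop padded or otherwise invalid positions from the statistics. A minimal sketch reusing the imports above, assuming the usual torchmetrics convention that target entries equal to ignore_index simply do not contribute; the choice of -1 as the sentinel is illustrative:

>>> padded_target = tensor([0, 1, 0, 1, -1])   # -1 marks an entry to ignore
>>> padded_preds = tensor([0, 1, 1, 1, 0])
>>> summary = binary_classification_report(
...     preds=padded_preds,
...     target=padded_target,
...     ignore_index=-1,
... )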

multiclass_classification_report

torchmetrics.functional.classification.multiclass_classification_report(preds, target, num_classes, target_names=None, digits=2, output_dict=False, zero_division=0.0, ignore_index=None, validate_args=True, labels=None, top_k=1)[source]

Compute a classification report for multiclass classification tasks.

The classification report shows the precision, recall, F1 score, and support for each class.

Parameters:
  • preds (Tensor) – Tensor with predictions of shape (N, …) or (N, C, …) where C is the number of classes

  • target (Tensor) – Tensor with ground truth labels of shape (N, …)

  • num_classes (int) – Number of classes

  • target_names (Optional[List[str]]) – Optional list of names for the classes

  • digits (int) – Number of decimal places to display in the report

  • output_dict (bool) – If True, return a dict instead of a string report

  • zero_division (Union[str, float]) – Value to use when a metric's denominator is zero (e.g. precision for a class that receives no predictions)

  • ignore_index (Optional[int]) – Optional index to ignore in the target

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness

  • labels (Optional[List[int]]) – Optional list of label indices to include in the report

  • top_k (int) – Number of highest probability or logit score predictions considered to find the correct label. Only works when preds contain probabilities/logits.

Return type:

Union[str, Dict[str, Union[float, Dict[str, Union[float, int]]]]]

Returns:

If output_dict=True, a dictionary with the classification report data. Otherwise, a formatted string with the classification report.

Example

>>> from torch import tensor
>>> from torchmetrics.functional.classification.classification_report import multiclass_classification_report
>>> target = tensor([0, 1, 2, 2, 2])
>>> preds = tensor([0, 0, 2, 2, 1])
>>> target_names = ["class 0", "class 1", "class 2"]
>>> report = multiclass_classification_report(
...     preds=preds,
...     target=target,
...     num_classes=3,
...     target_names=target_names,
...     digits=2
... )
>>> print(report) 
                    precision  recall f1-score support

class 0                  0.50    1.00     0.67       1
class 1                  0.00    0.00     0.00       1
class 2                  1.00    0.67     0.80       3

accuracy                                  0.60       5
macro avg                0.50    0.56     0.49       5
weighted avg             0.70    0.60     0.61       5
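
The labels argument selects which class indices are included in the report. A minimal sketch reusing the tensors from the example above; the chosen indices are illustrative only:

>>> summary = multiclass_classification_report(
...     preds=preds,
...     target=target,
...     num_classes=3,
...     labels=[0, 2],        # include only classes 0 and 2 in the report
... )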

multilabel_classification_report

torchmetrics.functional.classification.multilabel_classification_report(preds, target, num_labels, threshold=0.5, target_names=None, digits=2, output_dict=False, zero_division=0.0, ignore_index=None, validate_args=True)[source]

Compute a classification report for multilabel classification tasks.

The classification report shows the precision, recall, F1 score, and support for each label.

Parameters:
  • preds (Tensor) – Tensor with predictions of shape (N, L, …) where L is the number of labels

  • target (Tensor) – Tensor with ground truth labels of shape (N, L, …)

  • num_labels (int) – Number of labels

  • threshold (float) – Threshold for converting probabilities to binary predictions

  • target_names (Optional[List[str]]) – Optional list of names for the labels

  • digits (int) – Number of decimal places to display in the report

  • output_dict (bool) – If True, return a dict instead of a string report

  • zero_division (Union[str, float]) – Value to use when a metric's denominator is zero (e.g. precision for a class that receives no predictions)

  • ignore_index (Optional[int]) – Specifies a target value that is ignored and does not contribute to the metric calculation

  • validate_args (bool) – bool indicating if input arguments and tensors should be validated for correctness

Return type:

Union[str, Dict[str, Union[float, Dict[str, Union[float, int]]]]]

Returns:

If output_dict=True, a dictionary with the classification report data. Otherwise, a formatted string with the classification report.

Example

>>> from torch import tensor
>>> from torchmetrics.functional.classification.classification_report import multilabel_classification_report
>>> target = tensor([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
>>> preds = tensor([[1, 0, 1], [0, 1, 1], [1, 0, 0]])
>>> target_names = ["Label A", "Label B", "Label C"]
>>> report = multilabel_classification_report(
...     preds=preds,
...     target=target,
...     num_labels=len(target_names),
...     target_names=target_names,
...     digits=2,
... )
>>> print(report) 
                    precision  recall f1-score support

Label A                  1.00    1.00     1.00       2
Label B                  1.00    0.50     0.67       2
Label C                  0.50    1.00     0.67       1

micro avg                0.80    0.80     0.80       5
macro avg                0.83    0.83     0.78       5
weighted avg             0.90    0.80     0.80       5
samples avg              0.83    0.83     0.78       5