frarch.modules.metrics.base module

class frarch.modules.metrics.base.AggregationModes(value)

Bases: frarch.utils.enums.base.StringEnum

Enumeration of the aggregation modes accepted by Metric.get_metric().

MAX = 'max'
MEAN = 'mean'
MIN = 'min'
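Because AggregationModes derives from StringEnum, each member carries its lowercase string value, so a mode can be looked up from a plain string such as one read from a config file. The sketch below approximates that behavior with the standard library's `(str, Enum)` mixin; this is an assumption about how `frarch.utils.enums.base.StringEnum` behaves, not its actual implementation.

```python
from enum import Enum


class AggregationModes(str, Enum):
    """Stand-in for frarch's AggregationModes; StringEnum is approximated by (str, Enum)."""

    MAX = "max"
    MEAN = "mean"
    MIN = "min"


# Look a member up from its string value.
mode = AggregationModes("mean")
print(mode is AggregationModes.MEAN)  # True

# The str mixin makes members compare equal to their plain string values.
print(mode == "mean")  # True
```

An unknown string such as `AggregationModes("median")` raises `ValueError`, which is the standard `Enum` lookup behavior.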
class frarch.modules.metrics.base.Metric

Bases: object

Abstract base class for metric objects.

Example

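Simple usage of the Metric class (MyModel and dataset stand in for your own model and data):

```python
import torch

from frarch.modules.metrics.base import Metric


class MyMetric(Metric):
    def _update(self, predictions: torch.Tensor, truth: torch.Tensor) -> float:
        # Compute some metric for this batch, e.g. classification accuracy.
        metric_value = (predictions.argmax(dim=1) == truth).float().mean().item()
        return metric_value


model = MyModel()  # placeholder: your model
mymetric = MyMetric()
for batch, labels in dataset:  # placeholder: any iterable of (inputs, labels)
    predictions = model(batch)
    mymetric.update(predictions, labels)

print(mymetric.get_metric(mode="mean"))
```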

aggregation_methods = {AggregationModes.MAX: <lambda>, AggregationModes.MEAN: <lambda>, AggregationModes.MIN: <lambda>}

Maps each aggregation mode to the function that get_metric() applies to the stored values.
get_metric(mode: frarch.modules.metrics.base.AggregationModes = AggregationModes.MEAN) → float

Aggregate all metric values stored so far.

Parameters

mode (AggregationModes or str, optional) – aggregation type: "mean", "max" or "min". Defaults to AggregationModes.MEAN.

Raises

ValueError – if the aggregation mode is not supported.

Returns

The aggregated metric value.

Return type

float
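Since aggregation_methods maps each mode to a lambda, get_metric can be read as a dictionary lookup followed by a call on the stored values. The sketch below shows that dispatch over a plain list of floats; the helper name `aggregate`, the `AGGREGATION_METHODS` table, and the choice to raise `ValueError` when the mode string cannot be parsed are assumptions for illustration, not frarch's code. It also assumes at least one value has been stored (the mean of an empty list would divide by zero).

```python
from enum import Enum
from typing import Callable, Dict, List, Union


class AggregationModes(str, Enum):
    # Stand-in for frarch's AggregationModes (StringEnum approximated by str + Enum).
    MAX = "max"
    MEAN = "mean"
    MIN = "min"


# Dispatch table: one aggregation function per mode, mirroring aggregation_methods.
AGGREGATION_METHODS: Dict[AggregationModes, Callable[[List[float]], float]] = {
    AggregationModes.MAX: lambda values: max(values),
    AggregationModes.MEAN: lambda values: sum(values) / len(values),
    AggregationModes.MIN: lambda values: min(values),
}


def aggregate(
    values: List[float], mode: Union[AggregationModes, str] = AggregationModes.MEAN
) -> float:
    """Hypothetical helper: aggregate stored values the way get_metric is documented to."""
    try:
        # Accept either a member or its string value ("mean", "max", "min").
        mode = AggregationModes(mode)
    except ValueError:
        raise ValueError(f"aggregation mode not supported: {mode!r}") from None
    return AGGREGATION_METHODS[mode](values)


batch_values = [0.5, 0.75, 1.0]
print(aggregate(batch_values))         # 0.75
print(aggregate(batch_values, "max"))  # 1.0
print(aggregate(batch_values, "min"))  # 0.5
```

Keeping the functions in a dictionary means a new mode only needs a new enum member and a new table entry, with no change to the lookup logic.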

reset() → None

Clear all stored metric values.

update(predictions: torch.Tensor, truth: torch.Tensor) → None

Compute the metric value for one batch and append it to the metrics array.

Parameters
  • predictions (torch.Tensor) – output tensor from the model.

  • truth (torch.Tensor) – ground truth tensor.
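The class example above has subclasses implement `_update`, while the public update records one value per call and reset clears them. A minimal sketch of that template-method structure follows, using plain Python sequences instead of torch.Tensor so it runs without PyTorch; `MetricSketch`, `Accuracy`, and the `metrics` list attribute are hypothetical names, not frarch's internals.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class MetricSketch(ABC):
    """Hypothetical stand-in for Metric: stores one value per update() call."""

    def __init__(self) -> None:
        self.metrics: List[float] = []  # the "metrics array"

    @abstractmethod
    def _update(self, predictions: Sequence[int], truth: Sequence[int]) -> float:
        """Subclasses compute the metric value for a single batch."""

    def update(self, predictions: Sequence[int], truth: Sequence[int]) -> None:
        # Compute the batch value via the subclass hook and append it.
        self.metrics.append(self._update(predictions, truth))

    def reset(self) -> None:
        # Clear all stored values, e.g. at the start of a new epoch.
        self.metrics = []


class Accuracy(MetricSketch):
    def _update(self, predictions: Sequence[int], truth: Sequence[int]) -> float:
        # Fraction of positions where the predicted label matches the ground truth.
        correct = sum(p == t for p, t in zip(predictions, truth))
        return correct / len(truth)


acc = Accuracy()
acc.update([1, 0, 1, 1], [1, 0, 0, 1])  # 3 of 4 correct
acc.update([0, 0], [0, 0])              # 2 of 2 correct
print(acc.metrics)  # [0.75, 1.0]
acc.reset()
print(acc.metrics)  # []
```

Subclasses only decide how a single batch is scored; storing, clearing, and aggregating the values stay in the base class.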