Source code for rxn.metrics.forward_metrics

from typing import Any, Dict, Iterable

from rxn.utilities.files import PathLike, iterate_lines_from_file

from .metrics import top_n_accuracy
from .metrics_calculator import MetricsCalculator
from .metrics_files import ForwardFiles, MetricsFiles


class ForwardMetrics(MetricsCalculator):
    """
    Metrics calculator for forward-prediction models.

    Instances hold the ground-truth products and the predicted products
    in memory; the alternate constructors load them from files.

    Note: all files are expected to be standardized (canonicalized,
    sorted, etc.) before being handed to this class.
    """

    def __init__(self, gt_products: Iterable[str], predicted_products: Iterable[str]):
        """
        Store the ground truth and the predictions.

        Both iterables are consumed immediately and kept as lists, so
        one-shot iterators (e.g. generators) may be passed safely.

        Args:
            gt_products: ground-truth product strings.
            predicted_products: predicted product strings.
        """
        self.gt_products = list(gt_products)
        self.predicted_products = list(predicted_products)

    def get_metrics(self) -> Dict[str, Any]:
        """
        Compute the metrics for the stored products.

        Returns:
            A dictionary with the single key ``"accuracy"``, mapped to
            whatever ``top_n_accuracy`` returns for the stored ground
            truth and predictions.
        """
        return {
            "accuracy": top_n_accuracy(
                ground_truth=self.gt_products,
                predictions=self.predicted_products,
            )
        }

    @classmethod
    def from_metrics_files(cls, metrics_files: MetricsFiles) -> "ForwardMetrics":
        """
        Build an instance from a ``ForwardFiles`` object.

        Args:
            metrics_files: file collection; must be a ``ForwardFiles``
                instance. Its ``gt_tgt`` and ``predicted_canonical``
                attributes are used as the two input files.

        Raises:
            ValueError: if ``metrics_files`` is not a ``ForwardFiles``.
        """
        if isinstance(metrics_files, ForwardFiles):
            return cls.from_raw_files(
                gt_products_file=metrics_files.gt_tgt,
                predicted_products_file=metrics_files.predicted_canonical,
            )
        raise ValueError("Invalid type provided")

    @classmethod
    def from_raw_files(
        cls,
        gt_products_file: PathLike,
        predicted_products_file: PathLike,
    ) -> "ForwardMetrics":
        """
        Build an instance directly from two files.

        The entries are whatever ``iterate_lines_from_file`` yields for
        each path (presumably one product string per line -- confirm
        against that helper).

        Args:
            gt_products_file: path to the ground-truth products file.
            predicted_products_file: path to the predicted products file.
        """
        gt_lines = iterate_lines_from_file(gt_products_file)
        predicted_lines = iterate_lines_from_file(predicted_products_file)
        return cls(gt_products=gt_lines, predicted_products=predicted_lines)