Source code for rxn.metrics.run_metrics

"""
Functions to launch metrics calculations on forward, retro, or context models.
"""
import json
import logging
from pathlib import Path
from typing import Dict, Type

from rxn.chemutils.miscellaneous import canonicalize_file
from rxn.chemutils.tokenization import copy_as_detokenized
from rxn.onmt_models import rxn_translation
from rxn.utilities.files import PathLike, ensure_directory_exists_and_is_empty
from rxn.utilities.logging import setup_console_and_file_logger

from .context_metrics import ContextMetrics
from .forward_metrics import ForwardMetrics
from .metrics_calculator import MetricsCalculator
from .metrics_files import ContextFiles, ForwardFiles, MetricsFiles, RetroFiles
from .retro_metrics import RetroMetrics

logger = logging.getLogger(__name__)
# Library-style logging: do nothing unless the application configures handlers.
logger.addHandler(logging.NullHandler())

# Task name -> MetricsFiles subclass describing the file layout for that task.
_FILES_MAPPING: Dict[str, Type[MetricsFiles]] = dict(
    forward=ForwardFiles,
    context=ContextFiles,
    retro=RetroFiles,
)

# Task name -> MetricsCalculator subclass computing the metrics for that task.
_CALCULATOR_MAPPING: Dict[str, Type[MetricsCalculator]] = dict(
    forward=ForwardMetrics,
    context=ContextMetrics,
    retro=RetroMetrics,
)


def get_metrics_files(task: str, files_path: PathLike) -> MetricsFiles:
    """Instantiate the metrics-files helper registered for ``task``.

    Args:
        task: task name; must be a key of ``_FILES_MAPPING``
            ("forward", "context" or "retro").
        files_path: path handed to the constructor of the matching
            ``MetricsFiles`` subclass.

    Returns:
        An instance of the ``MetricsFiles`` subclass registered for ``task``.

    Raises:
        KeyError: if ``task`` is not supported. The message lists the
            supported task names (same exception type as a plain dict lookup,
            so existing callers are unaffected).
    """
    try:
        files_class = _FILES_MAPPING[task]
    except KeyError:
        supported = ", ".join(sorted(_FILES_MAPPING))
        # Suppress the bare KeyError context; our message is more informative.
        raise KeyError(
            f'Unknown task "{task}"; expected one of: {supported}.'
        ) from None
    return files_class(files_path)
def get_metrics_calculator(task: str, files: MetricsFiles) -> MetricsCalculator:
    """Build the metrics calculator registered for ``task`` from ``files``.

    Args:
        task: task name; must be a key of ``_CALCULATOR_MAPPING``
            ("forward", "context" or "retro").
        files: metrics files passed to the calculator class's
            ``from_metrics_files`` constructor.

    Returns:
        The calculator returned by ``from_metrics_files`` of the class
        registered for ``task``.

    Raises:
        KeyError: if ``task`` is not supported. The message lists the
            supported task names. Errors raised by ``from_metrics_files``
            itself are propagated unchanged.
    """
    try:
        calculator_class = _CALCULATOR_MAPPING[task]
    except KeyError:
        supported = ", ".join(sorted(_CALCULATOR_MAPPING))
        # Only the mapping lookup is wrapped, so KeyErrors from inside
        # from_metrics_files are not masked by this message.
        raise KeyError(
            f'Unknown task "{task}"; expected one of: {supported}.'
        ) from None
    return calculator_class.from_metrics_files(files)
def evaluate_metrics(task: str, files_path: PathLike) -> None:
    """Compute the metrics for ``task`` and save them as indented JSON.

    The metrics files are located via ``get_metrics_files(task, files_path)``,
    and the dict returned by the task's calculator is written to
    ``metrics_file``. A warning is logged when an existing file is overwritten.

    Args:
        task: task name ("forward", "context" or "retro").
        files_path: directory holding the files produced for the metrics run.
    """
    logger.info(f"Evaluating the {task} metrics...")

    metrics_files = get_metrics_files(task, files_path)
    metrics = get_metrics_calculator(task, metrics_files).get_metrics()

    destination = metrics_files.metrics_file
    if destination.exists():
        logger.warning(f'Overwriting "{destination}"!')
    with open(destination, "wt") as json_out:
        json.dump(metrics, json_out, indent=2)

    logger.info(f'Evaluating the {task} metrics... Saved to "{destination}".')
def run_model_for_metrics(
    task: str,
    model_path: Path,
    src_file: Path,
    tgt_file: Path,
    output_dir: Path,
    n_best: int,
    beam_size: int,
    batch_size: int,
    gpu: bool,
    initialize_logger: bool = False,
) -> None:
    """Run a translation model and write the files needed for metrics.

    Works for any task registered in the metrics mappings (forward, retro or
    context). The file layout inside ``output_dir`` is the one given by
    ``get_metrics_files(task, output_dir)``.

    Steps:
        1. ``ensure_directory_exists_and_is_empty(output_dir)``.
        2. Optionally set up console + file logging to ``files.log_file``.
        3. Copy ``src_file`` / ``tgt_file`` detokenized to ``files.gt_src`` /
           ``files.gt_tgt``.
        4. Translate ``files.gt_src`` with the model, writing ``n_best``
           predictions per input to ``files.predicted``.
        5. Canonicalize the predictions into ``files.predicted_canonical``
           (``fallback_value=""``, ``sort_molecules=True``).

    Args:
        task: task name ("forward", "context" or "retro").
        model_path: model passed to ``rxn_translation``.
        src_file: ground-truth source file.
        tgt_file: ground-truth target file.
        output_dir: directory receiving all produced files.
        n_best: number of predictions to keep per input.
        beam_size: beam size for the translation.
        batch_size: batch size for the translation.
        gpu: whether to run the translation on GPU.
        initialize_logger: if True, configure console and file logging to the
            task's log file inside ``output_dir``.
    """
    ensure_directory_exists_and_is_empty(output_dir)
    files = get_metrics_files(task, output_dir)

    if initialize_logger:
        setup_console_and_file_logger(files.log_file)

    copy_as_detokenized(src_file, files.gt_src)
    copy_as_detokenized(tgt_file, files.gt_tgt)

    # Model prediction for the given task. The detokenized copies are used as
    # input; presumably rxn_translation handles tokenization itself — confirm.
    rxn_translation(
        src_file=files.gt_src,
        tgt_file=files.gt_tgt,
        pred_file=files.predicted,
        model=model_path,
        n_best=n_best,
        beam_size=beam_size,
        batch_size=batch_size,
        gpu=gpu,
    )

    # fallback_value="" presumably replaces predictions that cannot be
    # canonicalized with an empty line — verify against canonicalize_file.
    canonicalize_file(
        files.predicted,
        files.predicted_canonical,
        fallback_value="",
        sort_molecules=True,
    )