Source code for rxn.metrics.classification_translation

import logging
from pathlib import Path
from typing import Optional, Union

from rxn.chemutils.tokenization import file_is_tokenized, tokenize_file
from rxn.onmt_utils import translate
from rxn.utilities.files import dump_list_to_file, is_path_exists_or_creatable

from .metrics_files import RetroFiles
from .tokenize_file import (
    classification_file_is_tokenized,
    detokenize_classification_file,
    tokenize_classification_file,
)
from .utils import combine_precursors_and_products_from_files

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def maybe_classify_predictions(
    classification_model: Optional[Path],
    retro_files: RetroFiles,
    batch_size: int,
    gpu: bool,
) -> None:
    """Classify the reactions for determining the diversity metric.

    Only executed if a classification model is available."""
    if classification_model is None:
        return

    create_rxn_from_files(
        retro_files.predicted_canonical,
        retro_files.predicted_products_canonical,
        retro_files.predicted_rxn_canonical,
    )
    classification_translation(
        src_file=retro_files.predicted_rxn_canonical,
        tgt_file=None,
        pred_file=retro_files.predicted_classes,
        model=classification_model,
        n_best=1,
        beam_size=5,
        batch_size=batch_size,
        gpu=gpu,
    )
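
# Illustrative usage sketch, not part of the original module: the two code
# paths of maybe_classify_predictions. The RetroFiles instance is taken as an
# argument here because its construction is not shown in this module, and the
# model path is hypothetical.
def _example_maybe_classify_predictions(retro_files: RetroFiles) -> None:
    # Without a classification model, the call returns immediately.
    maybe_classify_predictions(None, retro_files, batch_size=64, gpu=False)
    # With a model, the reaction classes are written to
    # retro_files.predicted_classes.
    maybe_classify_predictions(
        Path("classification_model.pt"), retro_files, batch_size=64, gpu=True
    )
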
def create_rxn_from_files(
    input_file_precursors: Union[str, Path],
    input_file_products: Union[str, Path],
    output_file: Union[str, Path],
) -> None:
    """Combine a precursors file and a products file into a file of reactions."""
    logger.info(
        f'Combining files "{input_file_precursors}" and "{input_file_products}" -> "{output_file}".'
    )
    dump_list_to_file(
        combine_precursors_and_products_from_files(
            precursors_file=input_file_precursors,
            products_file=input_file_products,
        ),
        output_file,
    )
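
# Minimal sketch, not part of the original module: combining two line-aligned
# files into a reactions file. All file names are hypothetical.
def _example_create_rxn_from_files() -> None:
    create_rxn_from_files(
        input_file_precursors="precursors.txt",  # one precursor set per line
        input_file_products="products.txt",  # one product set per line
        output_file="reactions.txt",  # combined reactions, one per line
    )
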
def classification_translation(
    src_file: Union[str, Path],
    tgt_file: Optional[Union[str, Path]],
    pred_file: Union[str, Path],
    model: Union[str, Path],
    n_best: int,
    beam_size: int,
    batch_size: int,
    gpu: bool,
    max_length: int = 3,
    as_external_command: bool = False,
) -> None:
    """
    Do a classification translation.

    This function takes care of tokenizing/detokenizing the input.

    Note: no check is made that the source is canonical.

    Args:
        src_file: source file (tokenized or detokenized).
        tgt_file: ground truth class file (tokenized or detokenized); optional.
        pred_file: file where to save the predictions.
        model: model to do the translation.
        n_best: number of predictions to make for each input.
        beam_size: beam size.
        batch_size: batch size.
        gpu: whether to use the GPU.
        max_length: maximum sequence length.
        as_external_command: whether to run the translation as an external command.
    """
    if not is_path_exists_or_creatable(pred_file):
        raise RuntimeError(f'The file "{pred_file}" cannot be created.')

    # src: tokenize only if the file is not tokenized already.
    if file_is_tokenized(src_file):
        tokenized_src = src_file
    else:
        tokenized_src = str(src_file) + ".tokenized"
        tokenize_file(src_file, tokenized_src, fallback_value="")

    # tgt: optional; tokenize only if given and not tokenized already.
    if tgt_file is None:
        tokenized_tgt = None
    elif classification_file_is_tokenized(tgt_file):
        tokenized_tgt = tgt_file
    else:
        tokenized_tgt = str(tgt_file) + ".tokenized"
        tokenize_classification_file(tgt_file, tokenized_tgt)

    tokenized_pred = str(pred_file) + ".tokenized"

    translate(
        model=model,
        src=tokenized_src,
        tgt=tokenized_tgt,
        output=tokenized_pred,
        n_best=n_best,
        beam_size=beam_size,
        max_length=max_length,
        batch_size=batch_size,
        gpu=gpu,
        as_external_command=as_external_command,
    )

    detokenize_classification_file(tokenized_pred, pred_file)
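
# Usage sketch, not part of the original module: classify reactions without a
# ground-truth file. The input and model paths are hypothetical; the function
# tokenizes the source if needed and writes detokenized classes to pred_file.
def _example_classification_translation() -> None:
    classification_translation(
        src_file="reactions.txt",  # reaction SMILES, one per line
        tgt_file=None,  # no ground-truth classes available
        pred_file="predicted_classes.txt",
        model="classification_model.pt",
        n_best=1,
        beam_size=5,
        batch_size=64,
        gpu=False,
    )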