import importlib.util
import logging
from typing import Dict, List, Optional, Sequence
from rxn.chemutils.conversion import canonicalize_smiles
from rxn.chemutils.miscellaneous import smiles_has_atom_mapping
from rxn.chemutils.reaction_smiles import parse_any_reaction_smiles
from rxn.chemutils.utils import remove_atom_mapping
from rxn.utilities.containers import chunker
from rxn.utilities.files import dump_list_to_file
from .metrics_files import RetroFiles
from .utils import combine_precursors_and_products_from_files, get_sequence_multiplier
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
[docs]def true_reactant_environment_check(do_reactant_check: bool) -> None:
"""Make sure that the Python packages required for the determination of true
reactants are available (if applicable).
Raises:
RuntimeError: if the environment lacks a dependency.
"""
if not do_reactant_check:
# no reactant check to do, i.e. no packages are needed.
return
spec = importlib.util.find_spec("rxnmapper")
if spec is None:
raise RuntimeError(
'The package "rxnmapper" is not available. Install it, or deactivate '
"the calculation of the true reactant accuracy."
)
[docs]def maybe_determine_true_reactants(
do_reactant_check: bool, retro_files: RetroFiles, batch_size: int
) -> None:
if not do_reactant_check:
return
# Importing only here, so that the scripts work without the rxnmapper
# package if the true reactant accuracy is not needed.
from rxnmapper import BatchedMapper
# Mute the rxnmapper log entries
rxnmapper_logger = logging.getLogger("rxnmapper")
old_logger_level = rxnmapper_logger.level
rxnmapper_logger.setLevel(logging.ERROR)
logger.info(
"The user opted in for the true reactant accuracy; "
"the ground truth and predicted reactions will be atom-mapped."
)
mapper = BatchedMapper(batch_size=batch_size)
logger.info("Atom-mapping the ground truth reactions...")
gt_reactions = combine_precursors_and_products_from_files(
precursors_file=retro_files.gt_tgt, products_file=retro_files.gt_src
)
dump_list_to_file(mapper.map_reactions(gt_reactions), retro_files.gt_mapped)
logger.info("Atom-mapping the ground truth reactions... Done.")
logger.info("Atom-mapping the predicted reactions...")
predicted_reactions = combine_precursors_and_products_from_files(
precursors_file=retro_files.predicted_canonical,
products_file=retro_files.gt_src,
)
dump_list_to_file(
mapper.map_reactions(predicted_reactions), retro_files.predicted_mapped
)
logger.info("Atom-mapping the predicted reactions... Done.")
# Reset the logger level
rxnmapper_logger.setLevel(old_logger_level)
[docs]def get_standardized_true_reactants(mapped_rxn_smiles: str) -> Optional[List[str]]:
"""
Get the reactants that contribute atoms to the product, and standardize them.
Returns:
The sorted list of "true reactants", None if something is not right.
"""
try:
reactants = parse_any_reaction_smiles(mapped_rxn_smiles).reactants
true_reactants = [r for r in reactants if smiles_has_atom_mapping(r)]
true_reactants = [remove_atom_mapping(r) for r in true_reactants]
canonical_true_reactants = [
canonicalize_smiles(r, check_valence=False) for r in true_reactants
]
if len(canonical_true_reactants) == 0:
return None
return sorted(canonical_true_reactants)
except Exception as e:
logger.debug(
f'Error when determining the true reactants in "{mapped_rxn_smiles}": {e}'
)
return None
[docs]def true_reactant_accuracy(
ground_truth_mapped: Sequence[str], predictions_mapped: Sequence[str]
) -> Dict[int, float]:
"""
Compute the top-n "true reactant" accuracy values (i.e. discarding reagents).
Args:
ground_truth_mapped: list of atom-mapped reactions from the ground truth.
predictions_mapped: list of atom-mapped reactions from the predictions.
Raises:
ValueError: if the list sizes are incompatible, forwarded from get_sequence_multiplier().
Returns:
Dictionary of top-n accuracy values.
"""
multiplier = get_sequence_multiplier(
ground_truth=ground_truth_mapped, predictions=predictions_mapped
)
# we will count, for each "n", how many predictions are correct
correct_for_topn: List[int] = [0 for _ in range(multiplier)]
# We will process sample by sample - for that, we need to chunk the predictions
prediction_chunks = chunker(predictions_mapped, chunk_size=multiplier)
for gt, predictions in zip(ground_truth_mapped, prediction_chunks):
gt_true_reactants = get_standardized_true_reactants(gt)
# if the ground truth has no mapping info: count as a negative
if gt_true_reactants is None:
continue
pred_true_reactants = [get_standardized_true_reactants(p) for p in predictions]
for i in range(multiplier):
correct = gt_true_reactants in pred_true_reactants[: i + 1]
correct_for_topn[i] += int(correct)
return {
i + 1: correct_for_topn[i] / len(ground_truth_mapped) for i in range(multiplier)
}