Source code for rxn.chemutils.smiles_augmenter

import logging
import random
from typing import Callable, List

from .miscellaneous import apply_to_any_smiles, apply_to_smiles_groups

# Module-level logger named after this module. The NullHandler follows the
# standard-library recommendation for libraries: without it, warnings emitted
# here would fall through to logging's last-resort stderr handler whenever the
# host application has not configured logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class SmilesAugmenter:
    """
    Class to augment any kind of SMILES string with the help of randomization
    and shuffling.
    """

    def __init__(
        self,
        augmentation_fn: Callable[[str], str],
        augmentation_probability: float = 1.0,
        shuffle: bool = True,
        ignore_exceptions: bool = True,
    ):
        """
        Args:
            augmentation_fn: Function for augmenting the individual SMILES strings,
                such as the functions provided in smiles_randomization.py.
            augmentation_probability: Probability with which to augment individual
                SMILES strings.
            shuffle: Whether to shuffle the order of the compounds.
            ignore_exceptions: Whether to ignore the error (and return the
                original string) when an augmentation fails. If False, exceptions
                will be propagated.
        """
        self.augmentation_fn = augmentation_fn
        self.augmentation_probability = augmentation_probability
        self.shuffle = shuffle
        self.ignore_exceptions = ignore_exceptions

    def augment(self, smiles: str, number_augmentations: int) -> List[str]:
        """
        Augment one SMILES string (of any kind).

        Args:
            smiles: SMILES string to augment.
            number_augmentations: how many times to do the augmentation.

        Returns:
            A list with one augmented SMILES string per requested augmentation
            (empty if ``number_augmentations`` is zero or negative).
        """
        # Augmentation of the individual compound SMILES. Each pass is
        # independent, so the random choices differ between the returned strings.
        augmented = [
            apply_to_any_smiles(
                smiles, self._augment_with_probability, force_multicomponent=True
            )
            for _ in range(number_augmentations)
        ]

        # Shuffle the order of the compounds.
        if self.shuffle:
            augmented = [
                apply_to_smiles_groups(s, SmilesAugmenter._shuffle) for s in augmented
            ]

        return augmented

    def _augment_with_probability(self, smiles: str) -> str:
        """
        Augment a SMILES, with the probability given by the member variable.

        Args:
            smiles: SMILES string to (possibly) augment.

        Returns:
            The result of ``augmentation_fn`` if the augmentation was selected and
            succeeded; otherwise the unchanged input string.

        Raises:
            Exception: whatever ``augmentation_fn`` raises, but only when
                ``ignore_exceptions`` is False.
        """
        # Note: no need to draw a random number if the augmentation probability
        # is 1.0 (this also leaves the random stream untouched in that case).
        # random.random() lies in [0.0, 1.0), so the strict comparison guarantees
        # that a probability of 0.0 never triggers an augmentation.
        do_augment = (
            self.augmentation_probability == 1.0
            or random.random() < self.augmentation_probability
        )
        if not do_augment:
            return smiles

        try:
            return self.augmentation_fn(smiles)
        except Exception as e:
            # Broad catch on purpose: augmentation_fn is an arbitrary callable
            # and the caller explicitly chose best-effort behaviour.
            if not self.ignore_exceptions:
                raise
            logger.warning("Augmentation failed for %s: %s", smiles, e)
            return smiles

    @staticmethod
    def _shuffle(smiles_list: List[str]) -> List[str]:
        """Return a shuffled copy of the list; the input list is not modified."""
        shuffled = smiles_list.copy()
        random.shuffle(shuffled)
        return shuffled