# LICENSED INTERNAL CODE. PROPERTY OF IBM.
# IBM Research Zurich Licensed Internal Code
# (C) Copyright IBM Corp. 2021
# ALL RIGHTS RESERVED
""" A utility class to augment the dataset files """
import math
import random
from pathlib import Path
from typing import List, Set
import pandas as pd
from rxn.chemutils.smiles_randomization import (
randomize_smiles_restricted,
randomize_smiles_rotated,
randomize_smiles_unrestricted,
)
from rxn.reaction_preprocessing.config import AugmentConfig
from rxn.reaction_preprocessing.smiles_tokenizer import SmilesTokenizer
from rxn.reaction_preprocessing.utils import RandomType, ReactionSection
def molecules_permutation_given_index(
    molecules_list: List[str], permutation_index: int
) -> List[str]:
    """Return the permutation of ``molecules_list`` selected by ``permutation_index``.

    The index is decoded with successive ``divmod`` calls into a sequence of
    swaps, so one specific permutation can be produced without enumerating the
    others. For a list of ``n`` distinct items, the indices ``0 .. n! - 1`` give
    the ``n!`` distinct orderings; index ``0`` leaves the order unchanged.

    Based on:
    https://stackoverflow.com/questions/5602488/random-picks-from-permutation-generator

    Args:
        molecules_list: The items to permute. The input is not modified.
        permutation_index: The index of the permutation to generate.

    Returns:
        A new list containing the items in the selected order.
    """
    # Work on a copy so that the caller's list is left untouched.
    permuted = list(molecules_list)
    size = len(permuted)
    for position in range(size - 1):
        # The remainder picks which of the not-yet-placed items goes to
        # `position`; the quotient is carried over to the next position.
        permutation_index, offset = divmod(permutation_index, size - position)
        permuted[position], permuted[position + offset] = (
            permuted[position + offset],
            permuted[position],
        )
    return permuted
class Augmenter:
    """Augments the reaction SMILES stored in a pandas DataFrame.

    Note: Unlike the other classes, which are memory-efficient, this one loads the
    whole data in a pandas DataFrame for processing.
    """

    def __init__(
        self, df: pd.DataFrame, reaction_column_name: str, fragment_bond: str = "."
    ):
        """Creates a new instance of the Augmenter class.

        Args:
            df (pd.DataFrame): A pandas DataFrame containing the reaction SMILES.
                It is stored by reference, and ``augment`` adds columns to it.
            reaction_column_name: The name of the DataFrame column containing the reaction SMILES.
            fragment_bond (str): The fragment bond token contained in the SMILES.
        """
        self.df = df
        self.__reaction_column_name = reaction_column_name
        self.tokenizer = SmilesTokenizer()
        self.fragment_bond = fragment_bond
        # Names of the columns that hold the joined, augmented reaction SMILES.
        self.augmented_columns: Set[str] = set()

    #
    # Private Methods
    #
    def __randomize_smiles(
        self, smiles: str, random_type: RandomType, permutations: int
    ) -> List[str]:
        """
        Randomizes a molecules SMILES string that might contain fragment bonds
        and returns a number of augmented versions of the SMILES equal to permutations.

        Every fragment of every "."-separated group is randomized on its own; the
        fragments are then joined back with the fragment bond and the groups with ".".

        Args:
            smiles (str): The molecules SMILES to augment
            random_type (RandomType): The type of randomization to be applied.
            permutations (int): The number of permutations to deliver for the SMILES

        Raises:
            ValueError: if the SMILES string is empty.

        Returns:
            List[str]: The list of randomized SMILES
        """
        if not smiles:
            raise ValueError("Cannot randomize an empty SMILES string.")

        def randomize_once() -> str:
            return ".".join(
                self.fragment_bond.join(
                    Augmenter.__randomize_smiles_without_fragment(
                        fragment, random_type
                    )
                    for fragment in group.split(self.fragment_bond)
                )
                for group in smiles.split(".")
            )

        return [randomize_once() for _ in range(permutations)]

    #
    # Private Static Methods
    #
    @staticmethod
    def __randomize_smiles_without_fragment(
        smiles: str, random_type: RandomType
    ) -> str:
        """
        Generates a random version of a SMILES without a fragment bond

        Args:
            smiles (str): The SMILES string (without fragment bond) to randomize.
            random_type (RandomType): The type of randomization to be applied.

        Raises:
            InvalidSmiles: for invalid SMILES (raised via rxn.chemutils).
            ValueError: if an invalid randomization type is provided.

        Returns:
            str: the randomized SMILES
        """
        if random_type == RandomType.unrestricted:
            return randomize_smiles_unrestricted(smiles)
        if random_type == RandomType.restricted:
            return randomize_smiles_restricted(smiles)
        if random_type == RandomType.rotated:
            return randomize_smiles_rotated(smiles, with_order_reversal=True)
        raise ValueError(f"Invalid random type: {random_type}")

    @staticmethod
    def __randomize_molecules(smiles: str, permutations: int) -> List[str]:
        """
        Randomizes the order of the molecules inside a SMILES string that might
        contain fragment bonds and returns a number of augmented versions of the
        SMILES equal to permutations.

        For a number of molecules smaller than permutations, returns a number of
        permutations equal to the number of molecules

        Args:
            smiles (str): The molecules SMILES to augment
            permutations (int): The number of permutations to deliver for the SMILES

        Raises:
            ValueError: if the SMILES string is empty.

        Returns:
            List[str]: The list of randomized SMILES
        """
        if not smiles:
            raise ValueError(
                "Cannot randomize the molecules of an empty SMILES string."
            )

        molecules_list = smiles.split(".")
        # Upper bound on the permutation indices that are sampled from; it keeps
        # the population a manageable size when the factorial becomes huge.
        max_permutation_indices = 4000000
        total_permutations = range(
            min(math.factorial(len(molecules_list)), max_permutation_indices)
        )
        # Sampling without replacement: every selected index is distinct.
        permutation_indices = random.sample(
            total_permutations, min(permutations, len(molecules_list))
        )
        return [
            ".".join(molecules_permutation_given_index(molecules_list, idx))
            for idx in permutation_indices
        ]

    #
    # Public Methods
    #
    def augment(
        self,
        random_type: RandomType = RandomType.unrestricted,
        rxn_section_to_augment: ReactionSection = ReactionSection.precursors,
        permutations: int = 1,
    ) -> pd.DataFrame:
        """
        Creates samples for the augmentation. Returns a pandas DataFrame containing
        the augmented samples.

        The selected section of each reaction is replaced by its augmented
        versions (one row per version), and the reaction is re-assembled in a new
        column named ``rxn_<random_type name>``. That column name is also recorded
        in ``self.augmented_columns``.

        Args:
            random_type (RandomType): The string identifying the type of randomization to apply.
                "molecules" for randomization of the molecules (canonical SMILES kept)
                "unrestricted" for unrestricted randomization
                "restricted" for restricted randomization
                "rotated" for rotated randomization
                For details on the differences:
                https://github.com/undeadpixel/reinvent-randomized and
                https://github.com/GLambard/SMILES-X
            rxn_section_to_augment (ReactionSection): The section of the rxn SMILES to augment.
                "precursors" for augmenting only the precursors
                "products" for augmenting only the products
            permutations (int): The number of permutations to generate for each SMILES

        Raises:
            ValueError: if ``permutations`` is smaller than 1, or if the reaction
                section is neither the precursors nor the products.

        Returns:
            pd.DataFrame: A pandas DataFrame containing the augmented samples
                (the same object as ``self.df`` after the call).
        """
        # Without at least one permutation the augmented column would be empty and
        # the final join would fail on missing values; fail early and clearly.
        if permutations < 1:
            raise ValueError(
                f"The number of permutations must be at least 1, got {permutations}."
            )

        # Index 0 / 1: position of the section in "precursors>>products".
        if rxn_section_to_augment is ReactionSection.precursors:
            augmented_section, augmented_index = "precursors", 0
            fixed_section, fixed_index = "products", 1
        elif rxn_section_to_augment is ReactionSection.products:
            augmented_section, augmented_index = "products", 1
            fixed_section, fixed_index = "precursors", 0
        else:
            raise ValueError(
                f"Invalid reaction section to augment: {rxn_section_to_augment.name}"
            )

        def section_extractor(index: int):
            # Spaces are removed before splitting the reaction SMILES on ">>".
            return lambda rxn_smiles: rxn_smiles.replace(" ", "").split(">>")[index]

        augmented_column = f"{augmented_section}_{random_type.name}"
        self.df[augmented_column] = self.df[self.__reaction_column_name].apply(
            section_extractor(augmented_index)
        )
        # The section that is not augmented is only extracted if it is not
        # already available as a column.
        if fixed_section not in self.df.columns:
            self.df[fixed_section] = self.df[self.__reaction_column_name].apply(
                section_extractor(fixed_index)
            )

        # The reaction is always re-assembled as precursors first, products second.
        if augmented_index == 0:
            columns_to_join = [augmented_column, fixed_section]
        else:
            columns_to_join = [fixed_section, augmented_column]

        if random_type != RandomType.molecules:
            self.df[augmented_column] = self.df[augmented_column].apply(
                lambda smiles: self.__randomize_smiles(
                    smiles, random_type, permutations
                )
            )
        else:
            self.df[augmented_column] = self.df[augmented_column].apply(
                lambda smiles: self.__randomize_molecules(smiles, permutations)
            )

        # Exploding the dataframe column holding the list of augmented versions
        # of a SMILES, so that each augmented version gets its own row. All the
        # other columns are moved to the index so that their values are repeated.
        other_columns = [col for col in self.df.columns if col != augmented_column]
        self.df = (
            self.df.set_index(other_columns).apply(pd.Series.explode).reset_index()
        )

        augmented_column_name = f"rxn_{random_type.name}"
        self.augmented_columns.add(augmented_column_name)
        self.df[augmented_column_name] = self.df.apply(
            lambda x: ">>".join(x[columns_to_join]), axis=1
        )
        return self.df

    #
    # Public Static Methods
    #
    @staticmethod
    def read_csv(
        filepath: str, reaction_column_name: str, fragment_bond: str = "."
    ) -> "Augmenter":
        """A helper function to read a list or csv of SMILES.

        Args:
            filepath (str): The path to the text file containing the molecules SMILES.
            reaction_column_name: The name of the reaction column (or the name that will
                be given to the reaction column if the file has a single column).
            fragment_bond (str): The fragment token in the reaction SMILES

        Returns:
            Augmenter: A new augmenter instance.
        """
        # NOTE(review): read_csv is called with its default header handling, so
        # the first line of the file is used as the header. A file without a
        # header line would lose its first row to the rename below - confirm
        # whether headerless input is expected here.
        df = pd.read_csv(filepath, lineterminator="\n")
        if len(df.columns) == 1:
            df.rename(columns={df.columns[0]: reaction_column_name}, inplace=True)
        return Augmenter(df, reaction_column_name, fragment_bond)
def augment(cfg: AugmentConfig) -> None:
    """Augment the reactions of a CSV file and write the result to a CSV file.

    Reads ``cfg.input_file_path``, augments the reactions with the random type,
    reaction section and number of permutations given in ``cfg``, and writes the
    resulting DataFrame (without the index) to ``cfg.output_file_path``.

    Unless ``cfg.keep_intermediate_columns`` is set, only the columns of the
    input file and the augmented reaction column(s) are written.

    Args:
        cfg: The augmentation configuration.

    Raises:
        ValueError: if the input file does not exist.
    """
    output_file_path = Path(cfg.output_file_path)
    if not Path(cfg.input_file_path).exists():
        raise ValueError(
            f"Input file for augmentation does not exist: {cfg.input_file_path}"
        )

    # Create an instance of the Augmenter.
    ag = Augmenter.read_csv(
        cfg.input_file_path, cfg.reaction_column_name, cfg.fragment_bond.value
    )

    # Remember the original columns before the augmentation adds its own.
    columns_to_keep = list(ag.df.columns)

    # Perform augmentation
    ag.augment(
        random_type=cfg.random_type,
        rxn_section_to_augment=cfg.rxn_section_to_augment,
        permutations=cfg.permutations,
    )

    # The augmented column names are kept in a set; sort them so that the
    # column order of the output does not depend on set iteration order.
    columns_to_keep.extend(sorted(ag.augmented_columns))
    if not cfg.keep_intermediate_columns:
        ag.df = ag.df[columns_to_keep]

    # Exporting augmented samples
    ag.df.to_csv(output_file_path, index=False)