Train a reaction BERT on the USPTO 1k TPL data set. The task is to predict which of the 1000 template classes a chemical reaction belongs to, given its reaction SMILES. The data set is strongly imbalanced and contains noisy reactions.
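For orientation, a reaction SMILES lists reactants and products separated by ">>" (with reagents optionally placed between the two ">" characters); the esterification below is a hypothetical illustration, not an entry from the data set.
example_rxn = 'CC(=O)O.CCO>>CC(=O)OCC'  # acetic acid + ethanol >> ethyl acetate (hypothetical example)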
import os
import numpy as np
import pandas as pd
import torch
import logging
import random
import pkg_resources
import sklearn
from rxnfp.models import SmilesClassificationModel
logger = logging.getLogger(__name__)
Track the training
We will be using wandb to keep track of our training. You can use an account on wandb or set up your own instance following the instructions in the documentation.
If you then create an .env file in the root folder and specify the WANDB_API_KEY= (and the WANDB_BASE_URL=), you can use dotenv to load those environment variables.
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
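As a quick, optional sanity check you can confirm that the variables were actually loaded into the environment (a minimal sketch using the os module imported above):
assert os.getenv('WANDB_API_KEY') is not None, 'WANDB_API_KEY was not loaded from the .env file'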
# Load the combined train/valid set with canonical reaction SMILES and template-class labels
df = pd.read_csv('../data/uspto_1k_TPL/data_set/uspto_1k_TPL_train_valid.tsv.gzip', compression='gzip', sep='\t', index_col=0)
df[['canonical_rxn', 'labels']].head()
# Use the first 400k reactions for training and a small held-out slice for evaluation during training
train_df = df.iloc[:400000]
eval_df = df[['canonical_rxn', 'labels']].iloc[400000:400604]
eval_df.columns = ['text', 'labels']  # the classification model expects 'text' and 'labels' columns
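Because the data set is strongly imbalanced, it can be instructive to look at the label distribution before training; this optional sketch only uses the columns defined above.
print(train_df.labels.nunique(), 'template classes in the training split')
print(train_df.labels.value_counts().describe())  # min vs. max per-class counts expose the imbalance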
# Augment the training set: wherever the reaction SMILES with fragment information
# differs from the plain canonical reaction SMILES, add that variant as an extra example
all_train_reactions = train_df.canonical_rxn_with_fragment_info[train_df.canonical_rxn_with_fragment_info != train_df.canonical_rxn].values.tolist() + train_df.canonical_rxn.values.tolist()
corresponding_labels = train_df[train_df.canonical_rxn_with_fragment_info != train_df.canonical_rxn].labels.values.tolist() + train_df.labels.values.tolist()
final_train_df = pd.DataFrame({'text': all_train_reactions, 'labels': corresponding_labels})
# Shuffle so that augmented and original reactions are mixed
final_train_df = final_train_df.sample(frac=1., random_state=42)
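To see how many extra examples the fragment-info variants contribute, you can compare the augmented set against the original split (optional sketch):
print(f'{len(final_train_df)} training examples after augmentation, from {len(train_df)} original reactions')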
Load model pretrained on a Masked Language Modeling task and train
This will currently only work if you have installed the library from the github repo with pip install -e ., as the models/transformers/bert_mlm_1k_tpl and models/transformers/bert_class_1k_tpl models are not included in the pip package.
model_args = {
    'wandb_project': 'nmi_uspto_1000_class',
    'num_train_epochs': 5,
    'overwrite_output_dir': True,
    'learning_rate': 2e-5,
    'gradient_accumulation_steps': 1,
    'regression': False,
    'num_labels': len(final_train_df.labels.unique()),
    'fp16': False,
    'evaluate_during_training': True,
    'manual_seed': 42,
    'max_seq_length': 512,
    'train_batch_size': 8,
    'warmup_ratio': 0.00,
    'output_dir': '../out/bert_class_1k_tpl',
    'thread_count': 8,
}
# Start from the BERT model pretrained with masked language modeling on the same reactions
model_path = pkg_resources.resource_filename("rxnfp", "models/transformers/bert_mlm_1k_tpl")
model = SmilesClassificationModel("bert", model_path, num_labels=len(final_train_df.labels.unique()), args=model_args, use_cuda=torch.cuda.is_available())
# Fine-tune; accuracy and Matthews correlation coefficient are computed on eval_df during training
model.train_model(final_train_df, eval_df=eval_df, acc=sklearn.metrics.accuracy_score, mcc=sklearn.metrics.matthews_corrcoef)
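After training, the fine-tuned checkpoints end up in the output_dir set above (../out/bert_class_1k_tpl). In the following we instead load the already fine-tuned classification model that ships with the repo installation and use it to predict on a test set.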
train_model_path = pkg_resources.resource_filename("rxnfp", "models/transformers/bert_class_1k_tpl")
model = SmilesClassificationModel("bert", train_model_path, use_cuda=torch.cuda.is_available())
# Load the held-out test split; the file name is an assumption, mirroring the train/valid file above
test_df = pd.read_csv('../data/uspto_1k_TPL/data_set/uspto_1k_TPL_test.tsv.gzip', compression='gzip', sep='\t', index_col=0)[['canonical_rxn', 'labels']]
test_df.columns = ['text', 'labels']
y_preds = model.predict(test_df.text.values)
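To score the predictions against the test labels, here is a minimal sketch (assuming predict returns the class predictions first, as in simpletransformers, hence the defensive tuple check):
predictions = y_preds[0] if isinstance(y_preds, tuple) else y_preds
print('accuracy:', sklearn.metrics.accuracy_score(test_df.labels.values, predictions))
print('MCC:', sklearn.metrics.matthews_corrcoef(test_df.labels.values, predictions))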