import os
import numpy as np
import pandas as pd
import torch
import logging
import random
from rxnfp.models import SmilesLanguageModelingModel
logger = logging.getLogger(__name__)
Track the training
We will be using wandb (Weights & Biases) to keep track of our training. You can use an account on wandb or set up your own instance by following the instructions in the documentation.
If you then create a .env file in the root folder and specify WANDB_API_KEY= (and WANDB_BASE_URL=, if you run your own instance), you can use python-dotenv to load those environment variables.
# !pip install python-dotenv
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
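To show what load_dotenv does with such a file, here is a minimal standard-library stand-in. load_env_file is a hypothetical helper, not part of python-dotenv. It ignores features the real library supports, such as variable expansion and multi-line values. Like load_dotenv's default, it does not overwrite variables that are already set unless asked to.

```python
import os
from pathlib import Path


def load_env_file(path, override=False):
    """Hypothetical minimal .env loader: read KEY=VALUE lines into os.environ.

    Skips blank lines and '#' comments, strips one pair of matching quotes,
    and leaves already-set variables untouched unless override=True.
    Returns every parsed KEY -> VALUE pair.
    """
    parsed = {}
    for raw in Path(path).read_text().splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)  # split on the first '=' only
        key, value = key.strip(), value.strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
            value = value[1:-1]
        if override or key not in os.environ:
            os.environ[key] = value
        parsed[key] = value
    return parsed
```

After loading, os.environ.get("WANDB_API_KEY") should return your key, which is where wandb picks it up from.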
Set up MLM training
Choose the hyperparameters you want and start the training. The default parameters below train a BERT model with 12 layers, a hidden size of 256, and 4 attention heads per layer. The training task is Masked Language Modeling (MLM): tokens in the input reactions are randomly masked, and the model learns to predict them from the surrounding context.
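The masking step can be sketched in plain Python. This assumes the standard BERT recipe, as implemented by Hugging Face's DataCollatorForLanguageModeling: each token is selected with probability 0.15, and a selected token is replaced by [MASK] 80% of the time, by a random vocabulary token 10% of the time, and left unchanged 10% of the time. The exact settings of the rxnfp adapter are not stated here. mask_tokens is a hypothetical helper that works on token strings rather than token ids.

```python
import random

MASK = "[MASK]"


def mask_tokens(tokens, vocab, mlm_probability=0.15, seed=None):
    """Return (inputs, labels) for one sequence under BERT-style masking.

    labels[i] is the original token at selected positions and None elsewhere
    (Hugging Face uses the id -100 for "ignore in the loss" instead of None).
    """
    rng = random.Random(seed)
    inputs = list(tokens)
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mlm_probability:
            continue  # not selected: no prediction target at this position
        labels[i] = tok  # the model has to recover the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = MASK  # 80%: replace with the mask token
        elif r < 0.9:
            inputs[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return inputs, labels


tokens = ["C", "C", "O", ".", "C", "C", "(", "=", "O", ")", "O", ">>",
          "C", "C", "O", "C", "(", "C", ")", "=", "O"]
inputs, labels = mask_tokens(tokens, vocab=sorted(set(tokens)), seed=0)
```

The loss is computed only where labels is not None, so keeping 10% of selected tokens unchanged forces the model to predict even tokens it can see.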
After defining the config, the training is launched in 3 lines of code using our adapter for the SimpleTransformers library (which is built on Hugging Face Transformers).
To make it work, you will have to install simpletransformers:
pip install simpletransformers
config = {
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 256,
"initializer_range": 0.02,
"intermediate_size": 512,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 4,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
}
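As a sanity check on model size, the config above can be turned into a parameter count. This sketch assumes the Hugging Face BertForMaskedLM layout: there is no pooler, and the output projection shares its weights with the word embeddings, so it adds only a bias of vocab_size parameters. bert_mlm_param_count is a hypothetical helper, and vocab_size must be taken from the vocab.txt you actually use.

```python
def bert_mlm_param_count(vocab_size, hidden=256, intermediate=512, layers=12,
                         max_positions=512, type_vocab=2):
    """Count trainable parameters of a BERT masked-LM with tied input/output embeddings."""
    H, I = hidden, intermediate
    # word + position + token-type embeddings, plus one LayerNorm (weight and bias)
    embeddings = (vocab_size + max_positions + type_vocab) * H + 2 * H
    per_layer = (
        4 * (H * H + H)  # query, key, value and attention-output projections
        + 2 * H          # LayerNorm after attention
        + (H * I + I)    # feed-forward up-projection
        + (I * H + H)    # feed-forward down-projection
        + 2 * H          # LayerNorm after feed-forward
    )
    # MLM head: dense transform + LayerNorm + output bias (decoder weight is tied)
    mlm_head = (H * H + H) + 2 * H + vocab_size
    return embeddings + layers * per_layer + mlm_head
```

With the config values above, the total works out to 257 × vocab_size + 6,523,648, i.e. roughly 6.5 million parameters plus the vocabulary term.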
vocab_path = '../data/uspto_1k_TPL/individual_files/vocab.txt'
args = {'config': config,
'vocab_path': vocab_path,
'wandb_project': 'uspto_mlm_temp_1000',
'train_batch_size': 32,
'manual_seed': 42,
"fp16": False,
"num_train_epochs": 50,
'max_seq_length': 256,
'evaluate_during_training': True,
'overwrite_output_dir': True,
'output_dir': '../out/bert_mlm_1k_tpl',
'learning_rate': 1e-4
}
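Before training, it can help to check how reactions are split into tokens and whether they fit within max_seq_length. The sketch below uses the regex SMILES tokenizer popularised by the Molecular Transformer work. It is an assumption that rxnfp's tokenizer uses exactly this pattern, and also that two special tokens ([CLS] and [SEP]) are added per sequence. tokenize_reaction and fits_max_seq_length are hypothetical helpers.

```python
import re

# Regex SMILES tokenizer in the style of the Molecular Transformer; the pattern
# used by rxnfp may differ slightly (assumption).
SMILES_PATTERN = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
)


def tokenize_reaction(rxn_smiles):
    """Split a reaction SMILES into tokens; '>>' is kept as a single token."""
    tokens = SMILES_PATTERN.findall(rxn_smiles)
    # findall silently skips characters that no alternative matches,
    # so make sure the tokens reassemble into the full input.
    if "".join(tokens) != rxn_smiles:
        raise ValueError(f"untokenizable characters in {rxn_smiles!r}")
    return tokens


def fits_max_seq_length(rxn_smiles, max_seq_length=256, n_special=2):
    """True if the tokens plus n_special special tokens fit without truncation."""
    return len(tokenize_reaction(rxn_smiles)) + n_special <= max_seq_length
```

Assuming the training file holds one reaction per line, running fits_max_seq_length over its lines gives a quick estimate of how many reactions would be truncated.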
model = SmilesLanguageModelingModel(model_type='bert', model_name=None, args=args)
# !unzip ../data/uspto_1k_TPL/individual_files/mlm_training.zip -d ../data/uspto_1k_TPL/individual_files/
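Outside a notebook (or without the unzip binary), the same extraction can be done with Python's standard zipfile module. extract_zip is a hypothetical helper. It is an assumption that the archive contains the mlm_train_file.txt and mlm_eval_file_1k.txt files used next.

```python
import zipfile
from pathlib import Path


def extract_zip(zip_path, dest_dir):
    """Pure-Python equivalent of `unzip <zip_path> -d <dest_dir>`; returns the member names."""
    Path(dest_dir).mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
        return zf.namelist()


# Usage with the paths from this notebook:
# extract_zip('../data/uspto_1k_TPL/individual_files/mlm_training.zip',
#             '../data/uspto_1k_TPL/individual_files/')
```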
train_file = '../data/uspto_1k_TPL/individual_files/mlm_train_file.txt'
eval_file = '../data/uspto_1k_TPL/individual_files/mlm_eval_file_1k.txt'
model.train_model(train_file=train_file, eval_file=eval_file)
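To estimate how long the run will be (and what the step axis in wandb will reach), the number of optimizer steps follows from train_batch_size and num_train_epochs. This is a sketch under stated assumptions: one training example per line of the training file, the last partial batch kept, and SimpleTransformers' gradient_accumulation_steps left at its default of 1 as in the args above. total_training_steps is a hypothetical helper.

```python
import math


def total_training_steps(n_examples, train_batch_size=32, num_train_epochs=50,
                         gradient_accumulation_steps=1):
    """Optimizer steps for a full run, assuming the last partial batch is kept."""
    batches_per_epoch = math.ceil(n_examples / train_batch_size)
    # one optimizer step every `gradient_accumulation_steps` batches
    return (batches_per_epoch // gradient_accumulation_steps) * num_train_epochs


steps = total_training_steps(n_examples=1000)  # hypothetical dataset size
```

For 1,000 examples this gives ceil(1000 / 32) = 32 batches per epoch and 1,600 steps over 50 epochs.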