Tutorial: how to train a reaction language model
import os
import numpy as np
import pandas as pd
import torch
import logging
import random
from rxnfp.models import SmilesLanguageModelingModel

logger = logging.getLogger(__name__)
wandb: WARNING W&B installed but not logged in.  Run `wandb login` or set the WANDB_API_KEY env variable.
This extension has only been tested with simpletransformers==0.34.4

Track the training

We will be using wandb to keep track of our training. You can use an account on wandb or set up your own instance by following the instructions in the documentation.

If you then create a .env file in the root folder and specify the WANDB_API_KEY= (and the WANDB_BASE_URL=), you can use dotenv to load those environment variables.
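A minimal .env file could look like the following (placeholder values; replace them with your own API key and, if you host your own instance, its URL):

WANDB_API_KEY=your_api_key_here
WANDB_BASE_URL=https://your-wandb-instance.example.com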

from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
True
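If load_dotenv returns True, the variables are now available in the process environment. An optional sanity check (not part of the original notebook):

import os
assert 'WANDB_API_KEY' in os.environ, 'WANDB_API_KEY was not loaded from the .env file'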

Set up the MLM training

Choose the hyperparameters you want and start the training. The default parameters will train a BERT model with 12 layers and 4 attention heads per layer. The training task is Masked Language Modeling (MLM), where tokens from the input reactions are randomly masked and predicted by the model given the context.
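As a toy illustration of the masking objective (a conceptual sketch only; the actual masking is handled internally by the library during training, and the token list below is made up):

import random

def mask_tokens(tokens, mask_prob=0.15, mask_token='[MASK]'):
    # Randomly hide a fraction of the tokens; the model is trained to
    # predict the hidden tokens from the surrounding context.
    masked, targets = [], []
    for token in tokens:
        if random.random() < mask_prob:
            masked.append(mask_token)  # hide this token from the model
            targets.append(token)      # ...and ask the model to recover it
        else:
            masked.append(token)
            targets.append(None)       # no prediction needed at this position
    return masked, targets

reaction_tokens = ['C', 'C', '(', '=', 'O', ')', 'Cl', '.', 'O', 'C', 'C', '>>',
                   'C', 'C', '(', '=', 'O', ')', 'O', 'C', 'C']
print(mask_tokens(reaction_tokens))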

Once the config is defined, the training is launched in three lines of code using our adapter for the SimpleTransformers library (based on Hugging Face Transformers).

To make it work you will have to install simpletransformers:

pip install simpletransformers==0.34.4
config = {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 4,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
"type_vocab_size": 2,
}
vocab_path = '../data/uspto_1k_TPL/individual_files/vocab.txt'

args = {'config': config, 
        'vocab_path': vocab_path, 
        'wandb_project': 'uspto_mlm_temp_1000',
        'train_batch_size': 32,
        'manual_seed': 42,
        "fp16": False,
        "num_train_epochs": 50,
        'max_seq_length': 256,
        'evaluate_during_training': True,
        'overwrite_output_dir': True,
        'output_dir': '../out/bert_mlm_1k_tpl',
        'learning_rate': 1e-4
       }
model = SmilesLanguageModelingModel(model_type='bert', model_name=None, args=args)
Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up.
Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up.
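The training and evaluation data are plain-text files, assumed here to contain one reaction SMILES per line (see the dataset preparation steps for how the USPTO 1k TPL files are generated). The snippet below only illustrates that assumed format, using made-up reactions and a hypothetical file name:

# Illustration of the assumed input format: one reaction SMILES per line.
example_reactions = [
    'CC(=O)Cl.OCC>>CC(=O)OCC',
    'Brc1ccccc1.OB(O)c1ccccc1>>c1ccc(-c2ccccc2)cc1',
]
with open('mlm_file_format_example.txt', 'w') as f:
    f.write('\n'.join(example_reactions))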
train_file = '../data/uspto_1k_TPL/individual_files/mlm_train_file.txt'
eval_file = '../data/uspto_1k_TPL/individual_files/mlm_eval_file_1k.txt'
model.train_model(train_file=train_file, eval_file=eval_file)
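Checkpoints are written to the output_dir specified in args ('../out/bert_mlm_1k_tpl'). As a sketch, assuming the SimpleTransformers convention that a path to a saved checkpoint can be passed as model_name, a trained model could later be reloaded like this:

# Sketch: reload a trained checkpoint from output_dir (assumes model_name
# accepts a checkpoint directory, as in SimpleTransformers).
trained_model = SmilesLanguageModelingModel(
    model_type='bert',
    model_name='../out/bert_mlm_1k_tpl',
    args=args,
)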