Analysis of the predictions of the different approaches on the USPTO 1K TPL data set
import pickle
import faiss
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Optional
from collections import Counter
from pycm import ConfusionMatrix, Compare

BERT classifier and k-NN classifier predictions

To compute the evaluation metrics we use PyCM, which can be installed using pip install pycm==2.7. The nearest neighbours for the different fingerprint approaches are computed using faiss (version 1.5.3).

We have already computed the confusion matrices; you can download them (together with the data set and the fingerprints) from MappingChemicalReactions on Box.

def get_nearest_neighbours_prediction(train_X: np.array, train_y: np.array, 
                                      eval_X: np.array, n_neighbours: int=5) -> list:
    """
    Predict one label per row of eval_X by voting among its nearest training rows.

    A faiss IndexFlatL2 index is built over train_X and queried with eval_X for
    n_neighbours hits per row; the hit indices are mapped to labels of train_y
    and reduced to the most common label by get_pred.
    """
    # The index dimensionality is the length of a single training vector
    dimension = len(train_X[0])
    l2_index = faiss.IndexFlatL2(dimension)
    l2_index.add(train_X)

    # search returns a (distances, indices) pair; only the indices are used
    neighbour_ids = l2_index.search(eval_X, n_neighbours)[1]

    return get_pred(train_y, neighbour_ids)
    

def get_pred(y: list, results: list) -> list:
    """
    Get the most common label among the nearest neighbours of each query.

    Args:
        y: labels of the indexed (training) samples; a list or a 1-D array.
        results: one sequence of neighbour indices into `y` per query, as
            returned by a faiss index search.

    Returns:
        A list with one predicted label per row of `results`. On a tie the
        label that occurs first in the neighbour list wins.

    Raises:
        ValueError: if a query has no valid neighbour at all.
    """
    # Accept plain lists as well as arrays: a list cannot be indexed with a
    # sequence of positions, an array can.
    labels = np.asarray(y)
    y_pred = []
    for neighbours in results:
        neighbours = np.asarray(neighbours)
        # faiss pads with -1 when fewer than k neighbours exist; as an array
        # index -1 would silently select the last label, so drop the padding.
        found = neighbours[neighbours >= 0]
        if found.size == 0:
            raise ValueError("No neighbours were found for at least one query")
        y_pred.append(Counter(labels[found]).most_common(1)[0][0])
    return y_pred


def get_cache_confusion_matrix(
    name: str, actual_vector: list, predict_vector: list
) -> ConfusionMatrix:
    """
    Return the confusion matrix cached under `name`, creating the cache if missing.

    If f"{name}.pickle" can be loaded, that object is returned and the two
    vectors are not used. Otherwise a ConfusionMatrix is built from the
    vectors, exported with save_html(name), pickled to f"{name}.pickle"
    and returned.
    """
    pickle_path = f"{name}.pickle"
    cached = load_confusion_matrix(pickle_path)

    if cached is None:
        # Nothing cached yet: compute, export the HTML report, then pickle
        fresh = ConfusionMatrix(actual_vector=actual_vector, predict_vector=predict_vector)
        fresh.save_html(name)
        with open(pickle_path, "wb") as handle:
            pickle.dump(fresh, handle)
        return fresh

    return cached

def load_confusion_matrix(path: str) -> Optional[ConfusionMatrix]:
    """
    Load a pickled confusion matrix from `path` if that file exists.

    Args:
        path: location of the pickle file.

    Returns:
        The unpickled object, or None when `path` is not an existing file.
    """
    if not Path(path).is_file():
        return None
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    # The context manager closes the handle even if unpickling fails.
    with open(path, "rb") as f:
        return pickle.load(f)

Generate confusion matrices

The precomputed fingerprints and confusion matrices can be downloaded from MappingChemicalReactions on Box (link to be added).

# Read the reference labels used as the "actual" vector: one integer per line
with open('../data/uspto_1k_TPL/individual_files/test_labels.txt', 'r') as f:
    labels_true = [int(line.strip()) for line in f]

BERT classifier predictions

# Read the BERT classifier predictions: one integer label per line
with open('../data/uspto_1k_TPL/individual_files/predictions_bert_classifier.txt', 'r') as f:
    labels_bert_class = [int(line.strip()) for line in f]

# Build the confusion matrix against the true labels, or reuse the cached one
cm_bert_classifier = get_cache_confusion_matrix(
    '../data/uspto_1k_TPL/results/cm_bert_classifier',
    labels_true,
    labels_bert_class,
)

# Report the overall statistics
print(f"Accuracy : {cm_bert_classifier.overall_stat['Overall ACC']:.4f}")
print(f"MCC : {cm_bert_classifier.overall_stat['Overall MCC']:.4f}")
print(f"CEN : {cm_bert_classifier.overall_stat['Overall CEN']:.4f}")
Accuracy : 0.9893
MCC : 0.9893
CEN : 0.0056

AP3_folded_256_plus_agent_features fingerprint, 5-NN

# Get confusion matrix for the Schneider FP
# Default for the case where the cached confusion matrix already exists: the
# k-NN prediction is skipped and get_cache_confusion_matrix ignores this
# argument, but the name must still be bound for the call below.
labels_predicted = None
if not Path('../data/uspto_1k_TPL/results/cm_schneider_fp_5NN.pickle').is_file():

    # Unpickle the fingerprint file once and close it again; both splits
    # are read from the same dictionary.
    with open("../data/uspto_1k_TPL/fingerprints/schneider_AP3_folded_256_plus_agent_features_total_df_1000_role.pkl", "rb") as f:
        fingerprints = pickle.load(f)

    # The faiss index is fed float32 matrices
    train_X = np.vstack(fingerprints["train_valid"]).astype("float32")
    eval_X = np.vstack(fingerprints["test"]).astype("float32")

    # Labels stay strings for the vote and are converted to int afterwards
    with open("../data/uspto_1k_TPL/individual_files/train_valid_labels.txt") as f:
        train_y = np.array([line.strip() for line in f])

    labels_predicted = [int(i) for i in get_nearest_neighbours_prediction(train_X, train_y, eval_X)]
cm_schneider_fp = cm = get_cache_confusion_matrix('../data/uspto_1k_TPL/results/cm_schneider_fp_5NN', labels_true, labels_predicted)
print(f"Accuracy : {cm_schneider_fp.overall_stat['Overall ACC']:.3f}")
print(f"MCC : {cm_schneider_fp.overall_stat['Overall MCC']:.3f}")
print(f"CEN : {cm_schneider_fp.overall_stat['Overall CEN']:.3f}")
Accuracy : 0.295
MCC : 0.292
CEN : 0.424

BERT fingerprint pretrained with MLM

# Get confusion matrix for the pretrained BERT fingerprint
# Default for the case where the cached confusion matrix already exists: the
# k-NN prediction is skipped and get_cache_confusion_matrix ignores this
# argument. Resetting it also avoids passing stale predictions along.
labels_predicted = None
if not Path('../data/uspto_1k_TPL/results/cm_bert_mlm_fp_5NN.pickle').is_file():

    # Unpickle the fingerprint file once and close it again; both splits
    # are read from the same dictionary.
    with open("../data/uspto_1k_TPL/fingerprints/USPTO_1k_TPL_bert_mlm.pkl", "rb") as f:
        fingerprints = pickle.load(f)

    # The faiss index is fed float32 matrices
    train_X = np.vstack(fingerprints["train_valid"]).astype("float32")
    eval_X = np.vstack(fingerprints["test"]).astype("float32")

    # Labels stay strings for the vote and are converted to int afterwards
    with open("../data/uspto_1k_TPL/individual_files/train_valid_labels.txt") as f:
        train_y = np.array([line.strip() for line in f])

    labels_predicted = [int(i) for i in get_nearest_neighbours_prediction(train_X, train_y, eval_X)]

cm_bert_mlm_fp = cm = get_cache_confusion_matrix('../data/uspto_1k_TPL/results/cm_bert_mlm_fp_5NN', labels_true, labels_predicted)
print(f"Accuracy : {cm_bert_mlm_fp.overall_stat['Overall ACC']:.3f}")
print(f"MCC : {cm_bert_mlm_fp.overall_stat['Overall MCC']:.3f}")
print(f"CEN : {cm_bert_mlm_fp.overall_stat['Overall CEN']:.3f}")
Accuracy : 0.340
MCC : 0.337
CEN : 0.392

BERT fingerprint after fine-tuning on template labels

# Get confusion matrix for the fine-tuned BERT fingerprint
# Default for the case where the cached confusion matrix already exists: the
# k-NN prediction is skipped and get_cache_confusion_matrix ignores this
# argument. Resetting it also avoids passing stale predictions along.
labels_predicted = None
if not Path('../data/uspto_1k_TPL/results/cm_bert_class_fp_5NN.pickle').is_file():

    # Unpickle the fingerprint file once and close it again; both splits
    # are read from the same dictionary.
    with open("../data/uspto_1k_TPL/fingerprints/USPTO_1k_TPL_bert_class.pkl", "rb") as f:
        fingerprints = pickle.load(f)

    # The faiss index is fed float32 matrices
    train_X = np.vstack(fingerprints["train_valid"]).astype("float32")
    eval_X = np.vstack(fingerprints["test"]).astype("float32")

    # Labels stay strings for the vote and are converted to int afterwards
    with open("../data/uspto_1k_TPL/individual_files/train_valid_labels.txt") as f:
        train_y = np.array([line.strip() for line in f])

    labels_predicted = [int(i) for i in get_nearest_neighbours_prediction(train_X, train_y, eval_X)]
cm_bert_class_fp = cm = get_cache_confusion_matrix('../data/uspto_1k_TPL/results/cm_bert_class_fp_5NN', labels_true, labels_predicted)
print(f"Accuracy : {cm_bert_class_fp.overall_stat['Overall ACC']:.3f}")
print(f"MCC : {cm_bert_class_fp.overall_stat['Overall MCC']:.3f}")
print(f"CEN : {cm_bert_class_fp.overall_stat['Overall CEN']:.3f}")
Accuracy : 0.989
MCC : 0.989
CEN : 0.006