Analysis of the predictions of the different approaches on the USPTO 1K TPL data set
import pickle
import faiss
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Optional
from collections import Counter
from pycm import ConfusionMatrix, Compare
BERT classifier and k-NN classifier predictions
To compute the evaluation metrics we use PyCM, which can be installed using `pip install pycm==2.7`. The nearest neighbours for the different fingerprint approaches are computed using faiss (version 1.5.3).
We have already computed the confusion matrices; you can download them (together with the data set and the fingerprints) from MappingChemicalReactions on Box.
def get_nearest_neighbours_prediction(train_X: np.array, train_y: np.array,
                                      eval_X: np.array, n_neighbours: int = 5) -> list:
    """
    Use faiss to make a K-nearest neighbour prediction.

    Builds an exact L2 index (``faiss.IndexFlatL2``) over the rows of
    ``train_X``, retrieves the ``n_neighbours`` closest training rows for
    every row of ``eval_X`` and returns the majority label computed by
    ``get_pred``.

    Args:
        train_X: 2-D array of training vectors, one row per sample.
        train_y: labels aligned with the rows of ``train_X``.
        eval_X: 2-D array of query vectors, same width as ``train_X``.
        n_neighbours: number of neighbours that vote; clamped to the number
            of training rows.

    Returns:
        One predicted label per row of ``eval_X``.

    Raises:
        ValueError: if ``train_X`` is empty or not 2-D, if ``eval_X`` does
            not have the same number of columns, or if ``n_neighbours < 1``.
    """
    # faiss only accepts C-contiguous float32 matrices; this is a no-op when
    # the input already satisfies that (as in the callers in this notebook).
    train_X = np.ascontiguousarray(train_X, dtype="float32")
    eval_X = np.ascontiguousarray(eval_X, dtype="float32")
    if train_X.ndim != 2 or train_X.shape[0] == 0:
        raise ValueError("train_X must be a non-empty 2-D array")
    if eval_X.ndim != 2 or eval_X.shape[1] != train_X.shape[1]:
        raise ValueError(
            f"eval_X must be 2-D with {train_X.shape[1]} columns, got shape {eval_X.shape}"
        )
    if n_neighbours < 1:
        raise ValueError("n_neighbours must be at least 1")
    # Indexing
    index = faiss.IndexFlatL2(train_X.shape[1])
    index.add(train_X)
    # Querying. faiss pads the result with -1 when k exceeds the number of
    # indexed vectors, and -1 would then be read as y[-1] (the last training
    # label) during scoring, so never ask for more neighbours than exist.
    k = min(n_neighbours, index.ntotal)
    _, results = index.search(eval_X, k)
    # Scoring
    return get_pred(train_y, results)
def get_pred(y: list, results: list) -> list:
    """
    Get the most common label from each nearest-neighbour list.

    Args:
        y: labels of the indexed (training) samples, as a list or NumPy array.
        results: one sequence of neighbour indices into ``y`` per query, e.g.
            the index matrix returned by faiss ``Index.search``. Negative
            indices are treated as padding and ignored (faiss uses -1 when
            fewer than k neighbours are available).

    Returns:
        One label per row of ``results``. Ties are broken in favour of the
        label encountered first in the row, because ``Counter.most_common``
        keeps insertion order for equal counts. For faiss results the row is
        presumably sorted by distance, so this is the nearest tied label.

    Raises:
        ValueError: if a row contains no valid (non-negative) index.
    """
    labels = np.asarray(y)  # allows plain lists to be fancy-indexed
    y_pred = []
    for row in results:
        row = np.asarray(row)
        # Drop faiss padding; labels[-1] would otherwise silently add a vote
        # for the last training label.
        valid = row[row >= 0]
        if valid.size == 0:
            raise ValueError("nearest-neighbour row contains no valid index")
        y_pred.append(Counter(labels[valid]).most_common(1)[0][0])
    return y_pred
def get_cache_confusion_matrix(
    name: str, actual_vector: list, predict_vector: list
) -> ConfusionMatrix:
    """
    Return the confusion matrix cached at ``{name}.pickle``, building it if absent.

    On a cache miss, a pycm ``ConfusionMatrix`` is built from the two
    vectors, an HTML report is written with ``save_html(name)`` and the
    matrix is pickled to ``{name}.pickle``.

    On a cache hit, ``actual_vector`` and ``predict_vector`` are neither used
    nor compared with the cached matrix, so callers may pass placeholders.
    Delete the pickle to force a recomputation.

    Args:
        name: path prefix (without extension) for the HTML report and pickle.
        actual_vector: ground-truth labels.
        predict_vector: predicted labels, aligned with ``actual_vector``.

    Returns:
        The cached or freshly computed confusion matrix.
    """
    cache_path = f"{name}.pickle"
    cm_cached = load_confusion_matrix(cache_path)
    if cm_cached is not None:
        return cm_cached
    # Make sure the output folder exists; otherwise writing the pickle below
    # fails with FileNotFoundError on a fresh checkout.
    Path(name).parent.mkdir(parents=True, exist_ok=True)
    cm = ConfusionMatrix(actual_vector=actual_vector, predict_vector=predict_vector)
    cm.save_html(name)
    with open(cache_path, "wb") as f:
        pickle.dump(cm, f)
    return cm
def load_confusion_matrix(path: str) -> "Optional[ConfusionMatrix]":
    """
    Load a pickled confusion matrix from ``path`` if the file exists.

    Args:
        path: location of the pickle file.

    Returns:
        The unpickled object, or ``None`` when ``path`` is not an existing
        regular file.

    Note:
        ``pickle.load`` can execute arbitrary code. Only load cache files
        produced by this notebook or obtained from a trusted source.
    """
    if not Path(path).is_file():
        return None
    # Context manager so the handle is closed deterministically.
    with open(path, "rb") as f:
        return pickle.load(f)
def _read_int_labels(path: str) -> list:
    """Return the integer class labels stored in *path*, one label per line."""
    with open(path, 'r') as handle:
        return [int(raw.strip()) for raw in handle]


labels_true = _read_int_labels('../data/uspto_1k_TPL/individual_files/test_labels.txt')
labels_bert_class = _read_int_labels(
    '../data/uspto_1k_TPL/individual_files/predictions_bert_classifier.txt'
)

# Confusion matrix (cached on disk) for the BERT classifier predictions.
cm_bert_classifier = get_cache_confusion_matrix(
    '../data/uspto_1k_TPL/results/cm_bert_classifier', labels_true, labels_bert_class
)
for metric_label, stat_key in (
    ("Accuracy", "Overall ACC"),
    ("MCC", "Overall MCC"),
    ("CEN", "Overall CEN"),
):
    print(f"{metric_label} : {cm_bert_classifier.overall_stat[stat_key]:.4f}")
def _knn_confusion_matrix(cm_name: str, fingerprint_file: str,
                          actual_vector: list, n_neighbours: int = 5):
    """
    Return the (cached) k-NN confusion matrix for one fingerprint set.

    The faiss k-NN prediction is computed only when
    ``../data/uspto_1k_TPL/results/{cm_name}.pickle`` does not exist yet.
    Otherwise ``get_cache_confusion_matrix`` returns the cached matrix and
    the empty prediction vector passed to it is not used.

    Args:
        cm_name: base name of the cached matrix inside the results folder.
        fingerprint_file: pickle inside the fingerprints folder holding
            ``"train_valid"`` and ``"test"`` sequences of fingerprint vectors.
        actual_vector: ground-truth test labels.
        n_neighbours: number of neighbours that vote.

    Returns:
        The pycm confusion matrix for the k-NN predictions.
    """
    cm_path = f"../data/uspto_1k_TPL/results/{cm_name}"
    # Bound up front so the call below also works when the cache exists and
    # the prediction step is skipped (previously a NameError).
    labels_predicted = []
    if not Path(f"{cm_path}.pickle").is_file():
        # NOTE(review): pickle.load can execute arbitrary code; only use
        # fingerprint files from a trusted source.
        with open(f"../data/uspto_1k_TPL/fingerprints/{fingerprint_file}", "rb") as f:
            fingerprints = pickle.load(f)  # loaded once for both splits
        train_X = np.vstack(fingerprints["train_valid"]).astype("float32")
        eval_X = np.vstack(fingerprints["test"]).astype("float32")
        with open("../data/uspto_1k_TPL/individual_files/train_valid_labels.txt") as f:
            train_y = np.array([line.strip() for line in f])
        labels_predicted = [
            int(i)
            for i in get_nearest_neighbours_prediction(train_X, train_y, eval_X, n_neighbours)
        ]
    return get_cache_confusion_matrix(cm_path, actual_vector, labels_predicted)


def _print_overall_stats(cm, digits: int = 3) -> None:
    """Print the overall accuracy, MCC and CEN of a pycm confusion matrix."""
    stats = cm.overall_stat
    print(f"Accuracy : {stats['Overall ACC']:.{digits}f}")
    print(f"MCC : {stats['Overall MCC']:.{digits}f}")
    print(f"CEN : {stats['Overall CEN']:.{digits}f}")


# Get confusion matrix for the Schneider FP
cm_schneider_fp = cm = _knn_confusion_matrix(
    'cm_schneider_fp_5NN',
    'schneider_AP3_folded_256_plus_agent_features_total_df_1000_role.pkl',
    labels_true,
)
_print_overall_stats(cm_schneider_fp)

# Get confusion matrix for the pretrained BERT fingerprint
cm_bert_mlm_fp = cm = _knn_confusion_matrix(
    'cm_bert_mlm_fp_5NN', 'USPTO_1k_TPL_bert_mlm.pkl', labels_true
)
_print_overall_stats(cm_bert_mlm_fp)

# Get confusion matrix for the fine-tuned BERT fingerprint
cm_bert_class_fp = cm = _knn_confusion_matrix(
    'cm_bert_class_fp_5NN', 'USPTO_1k_TPL_bert_class.pkl', labels_true
)
_print_overall_stats(cm_bert_class_fp)