An example of a nearest-neighbour reaction search in the Schneider 50k data set, using an LSH Forest over weighted MinHash fingerprints.
# MinHash encoder that turns each weighted reaction fingerprint into a
# MinHash vector suitable for the LSH Forest below.
mh_encoder = tm.Minhash()

# LSH Forest index over the MinHash vectors. Per the tmap API the positional
# arguments are the vector dimensionality and the number of prefix trees.
# NOTE(review): tmap's LSHForest also has a `weighted` flag meant for weighted
# MinHash input; confirm whether it should be set, given the I2CWS encoding.
lf = tm.LSHForest(256, 128)
# Mapping from reaction-class code to a human-readable class name, used to
# label queries and their neighbours. JSON object keys are always strings.
# The encoding is given explicitly so the result does not depend on the
# platform's default locale.
with open('../data/rxnclass2name.json', 'r', encoding='utf-8') as f:
    rxnclass2name = json.load(f)
# Reaction table (tab-separated); the first column becomes the index.
schneider_df = pd.read_csv('../data/schneider50k.tsv', sep='\t', index_col=0)

# Precomputed reaction fingerprints. Indexing an NpzFile reads the array fully
# into memory, so the archive can be closed right away; the `with` block
# guarantees the underlying file handle is released.
with np.load('../data/fps_ft.npz') as fps_archive:
    ft_10k_fps = fps_archive['fps']

# One weighted MinHash (I2CWS) per fingerprint row, stored alongside the
# reactions. pandas raises if the fingerprint count differs from the row count.
# NOTE(review): assumes fingerprint rows are in the same order as the rows of
# schneider_df — confirm against the script that produced fps_ft.npz.
schneider_df['mhfp'] = [
    mh_encoder.from_weight_array(fp.tolist(), method="I2CWS")
    for fp in tqdm(ft_10k_fps)
]
# Restrict the search index to the training split. Chaining a non-inplace
# reset_index() avoids mutating a boolean-mask slice in place, which makes
# pandas emit a SettingWithCopyWarning. The resulting frame is identical:
# a fresh RangeIndex, with the old index kept as a column.
train_df = schneider_df[schneider_df.split == 'train'].reset_index()

# Add the training MinHashes in row order, then build the forest so it can be
# queried. Presumably the ids the forest returns are therefore positional row
# numbers in train_df, which is how query results are looked up (via .iloc).
lf.batch_add(train_df['mhfp'].tolist())
lf.index()
# Number of test reactions to sample as queries, and neighbours to show for each.
N_QUERIES = 10
N_NEIGHBOURS = 3

# Sample test reactions with a fixed seed so the output is reproducible, and
# show each one followed by its nearest training-set reactions.
test_sample = schneider_df[schneider_df.split == 'test'].sample(n=N_QUERIES, random_state=42)
for _, row in test_sample.iterrows():
    print('------------------------------------------------------------------------------------------------')
    print('Query: Reaction class - {} {}'.format(row['rxn_class'], rxnclass2name[row['rxn_class']]))
    display(AllChem.ReactionFromSmarts(row['rxn'], useSmiles=True))
    print(row['rxn'])
    print('------------------------------------------------------------------------------------------------')
    print()

    # Each hit is a (distance, index) pair; only the index is used here.
    # Per the tmap docs, kc scales how many candidates are pulled from the
    # forest before the exact linear-scan ranking.
    nns = lf.query_linear_scan(row['mhfp'], N_NEIGHBOURS, kc=200)
    for rank, (_, idx) in enumerate(nns, start=1):
        # Look up the neighbour row once instead of re-indexing for every field.
        neighbour = train_df.iloc[idx]
        rxn = neighbour['rxn']
        display(AllChem.ReactionFromSmarts(rxn, useSmiles=True))
        print(rxn)
        print('NN-{} - {} {}'.format(rank, neighbour['rxn_class'], rxnclass2name[neighbour['rxn_class']]))
        print()
    print()
    print()