Source code for rxn.reaction_preprocessing.scripts.annotation_check

import csv
from typing import Generator

import click

from rxn.reaction_preprocessing.annotations.annotation_info import AnnotationInfo
from rxn.reaction_preprocessing.annotations.missing_annotation_detector import (
    MissingAnnotationDetector,
)
from rxn.reaction_preprocessing.annotations.molecule_annotation import (
    load_annotations_multiple,
)
from rxn.reaction_preprocessing.config import DEFAULT_ANNOTATION_FILES


def iterate_rxn_smiles(csv_file: str, column_name: str) -> Generator[str, None, None]:
    """Yield the values of one column of a CSV file, row by row.

    The first row of the file is treated as the header and is used to locate
    ``column_name``. Completely blank lines are skipped.

    Args:
        csv_file: path to the CSV file to read.
        column_name: name of the header column whose values are yielded
            (for the annotation check, the column holding reaction SMILES).

    Yields:
        The value of ``column_name`` for every non-blank data row, in file order.

    Raises:
        RuntimeError: if the file has no header row, or if the header does
            not contain ``column_name``.
    """
    # newline="" is what the csv module expects: it lets the reader handle
    # line endings itself, including newlines embedded in quoted fields.
    with open(csv_file, newline="") as f:
        reader = csv.reader(f)

        # Do not let StopIteration escape from this generator: since PEP 479 it
        # would surface as an opaque "generator raised StopIteration" error.
        header = next(reader, None)
        if header is None:
            raise RuntimeError(f"{csv_file} is empty: no header row found")

        try:
            smiles_index = header.index(column_name)
        except ValueError as e:
            raise RuntimeError(f'No "{column_name}" column in {csv_file}') from e

        for row in reader:
            # csv.reader returns [] for blank lines; they carry no record.
            if not row:
                continue
            yield row[smiles_index]
def _print_annotation_report(sections) -> None:
    """Print a count summary followed by a detailed listing for each section.

    Args:
        sections: sequence of ``(title, molecules)`` pairs, where ``molecules``
            is a list of SMILES strings (possibly with repeats).
    """
    # Summary block: title, total number of entries, number of distinct entries.
    for title, molecules in sections:
        print(f"{title} {len(molecules)} {len(set(molecules))}")

    # Detail block: an underlined heading per section, then its distinct
    # molecules in sorted order, with the "~" fragment bonds shown as ".".
    for title, molecules in sections:
        underline = "=" * len(title)
        print(f"\n{title}\n{underline}")
        for smi in sorted(set(molecules)):
            print(smi.replace("~", "."))


@click.command()
@click.option("--csv_file", required=True)
@click.option(
    "--column_name", required=True, help="Column containing the reaction SMILES"
)
def main(csv_file: str, column_name: str) -> None:
    """Check for missing annotations: what is already annotated (accepted / rejected), what still needs to be annotated."""
    # Lazily stream the reaction SMILES out of the requested CSV column.
    reaction_smiles = iterate_rxn_smiles(csv_file, column_name)

    # The detector is built with an empty set — presumably meaning no molecules
    # are pre-excluded; TODO confirm against MissingAnnotationDetector.
    detector = MissingAnnotationDetector(set())
    required = list(
        detector.missing_in_reaction_smiles(reaction_smiles, fragment_bond="~")
    )

    info = AnnotationInfo(load_annotations_multiple(DEFAULT_ANNOTATION_FILES))

    # Split the flagged molecules by their status in the default annotation files.
    not_annotated = [smi for smi in required if not info.is_annotated(smi)]
    annotated = [smi for smi in required if info.is_annotated(smi)]
    accepted = [smi for smi in annotated if info.is_accepted(smi)]
    rejected = [smi for smi in annotated if info.is_rejected(smi)]

    _print_annotation_report(
        [
            ("requiring annotation", required),
            ("not annotated", not_annotated),
            ("annotated", annotated),
            ("accepted", accepted),
            ("rejected", rejected),
        ]
    )


if __name__ == "__main__":
    main()