import logging
import re
from itertools import count
from pathlib import Path
from typing import List, Optional
from rxn.utilities.files import PathLike
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
[docs]class ModelFiles:
"""
Class to make it easy to get the names/paths of the trained OpenNMT models.
"""
ONMT_CONFIG_FILE = "config_{idx}.yml"
MODEL_PREFIX = "model"
MODEL_STEP_PATTERN = re.compile(r"^model_step_(\d+)\.pt$")
[docs] def __init__(self, model_dir: PathLike):
# Directly converting to an absolute path
self.model_dir = Path(model_dir).resolve()
# Create the directory if it does not exist yet
self.model_dir.mkdir(parents=True, exist_ok=True)
@property
def model_prefix(self) -> Path:
"""Absolute path to the model prefix; during training, OpenNMT will
append "_step_10000.pt" to it (or other step numbers)."""
return self.model_dir / ModelFiles.MODEL_PREFIX
[docs] def next_config_file(self) -> Path:
"""Get the next available config file name."""
for idx in count(1):
config_file = self.model_dir / ModelFiles.ONMT_CONFIG_FILE.format(idx=idx)
if not config_file.exists():
return config_file
return Path() # Note: in order to satisfy mypy. This is never reached.
[docs] def get_checkpoints(self) -> List[Path]:
"""Get the checkpoints contained in the directory, sorted by step number."""
steps_and_models = [
(self._get_checkpoint_step(path), path) for path in self.model_dir.iterdir()
]
steps_and_models = [
(step, path) for step, path in steps_and_models if step is not None
]
# Sort, from low checkpoint to high checkpoint
steps_and_models.sort()
return [model for _, model in steps_and_models]
[docs] def get_last_checkpoint(self) -> Path:
"""Get the last checkpoint matching the naming including the step number.
Raises:
RuntimeError: no model is found in the expected directory.
"""
models = self.get_checkpoints()
if not models:
raise RuntimeError(f'No model found in "{self.model_dir}"')
return models[-1]
@staticmethod
def _get_checkpoint_step(path: Path) -> Optional[int]:
"""Get the step from the path of a given model. None if no match."""
match = ModelFiles.MODEL_STEP_PATTERN.match(path.name)
if match is None:
return None
return int(match.group(1))
[docs]class OnmtPreprocessedFiles:
"""
Class to make it easy to get the names/paths of the OpenNMT-preprocessed files.
"""
PREFIX = "preprocessed"
[docs] def __init__(self, preprocessed_dir: PathLike):
# Directly converting to an absolute path
self.preprocessed_dir = Path(preprocessed_dir).resolve()
# Create the directory if it does not exist yet
self.preprocessed_dir.mkdir(parents=True, exist_ok=True)
@property
def preprocess_prefix(self) -> Path:
"""Absolute path to the prefix for the preprocessed files; during preprocessing,
OpenNMT will append ".train.0.pt", ".valid.0.pt", ".vocab.pt", etc."""
return self.preprocessed_dir / OnmtPreprocessedFiles.PREFIX
@property
def vocab_file(self) -> Path:
return self.preprocess_prefix.with_suffix(".vocab.pt")
[docs]class RxnPreprocessingFiles:
"""
Class to make it easy to get the names/paths of the files generated during data preprocessing.
This assumes that the default paths were used when calling rxn-data-pipeline.
"""
FILENAME_ROOT = "data"
[docs] def __init__(self, processed_data_dir: PathLike):
# Directly converting to an absolute path
self.processed_data_dir = Path(processed_data_dir).resolve()
def _add_extension(self, extension: str) -> Path:
"""
Helper function get the path of the file produced with the given extension.
Args:
extension: extension to add
Returns:
Path to the file with the given extension.
"""
if not extension.startswith("."):
extension = "." + extension
return self.processed_data_dir / (
RxnPreprocessingFiles.FILENAME_ROOT + extension
)
@property
def standardized_csv(self) -> Path:
return self._add_extension("standardized.csv")
@property
def processed_csv(self) -> Path:
return self._add_extension("processed.csv")
def get_processed_csv_for_split(self, split: str) -> Path:
split = self._validate_split(split)
return self._add_extension(f"processed.{split}.csv")
@property
def processed_train_csv(self) -> Path:
return self.get_processed_csv_for_split("train")
@property
def processed_validation_csv(self) -> Path:
return self.get_processed_csv_for_split("validation")
@property
def processed_test_csv(self) -> Path:
return self.get_processed_csv_for_split("test")
def get_precursors_for_split(self, split: str) -> Path:
split = self._validate_split(split)
return self._add_extension(f"processed.{split}.precursors_tokens")
def get_products_for_split(self, split: str) -> Path:
split = self._validate_split(split)
return self._add_extension(f"processed.{split}.products_tokens")
@property
def train_precursors(self) -> Path:
return self.get_precursors_for_split("train")
@property
def train_products(self) -> Path:
return self.get_products_for_split("train")
@property
def validation_precursors(self) -> Path:
return self.get_precursors_for_split("validation")
@property
def validation_products(self) -> Path:
return self.get_products_for_split("validation")
@property
def test_precursors(self) -> Path:
return self.get_precursors_for_split("test")
@property
def test_products(self) -> Path:
return self.get_products_for_split("test")
def get_context_tags_for_split(self, split: str) -> Path:
split = self._validate_split(split)
return self._add_extension(f"processed.{split}.context.tagged")
def get_context_src_for_split(self, split: str) -> Path:
split = self._validate_split(split)
return self._add_extension(f"processed.{split}.context.src")
def get_context_tgt_for_split(self, split: str) -> Path:
split = self._validate_split(split)
return self._add_extension(f"processed.{split}.context.tgt")
[docs] @staticmethod
def augmented(data_path: Path) -> Path:
"""Get the path for the augmented version of a data file."""
return data_path.with_name(data_path.name + ".augmented")
def _validate_split(self, split: str) -> str:
if split == "train":
return "train"
if split == "valid" or split == "validation":
return "validation"
if split == "test":
return "test"
raise ValueError(f'Unsupported split: "{split}"')
[docs] def get_src_file(self, split: str, model_task: str) -> Path:
"""Get the source file for the given task.
Note: the file is tokenized for the forward and retro tasks, but not
for the context task.
"""
if model_task == "forward":
return self.get_precursors_for_split(split)
if model_task == "retro":
return self.get_products_for_split(split)
if model_task == "context":
return self.get_context_src_for_split(split)
raise ValueError(f'Unsupported model task: "{model_task}"')
[docs] def get_tgt_file(self, split: str, model_task: str) -> Path:
"""Get the target file for the given task.
Note: the file is tokenized for the forward and retro tasks, but not
for the context task.
"""
if model_task == "forward":
return self.get_products_for_split(split)
if model_task == "retro":
return self.get_precursors_for_split(split)
if model_task == "context":
return self.get_context_tgt_for_split(split)
raise ValueError(f'Unsupported model task: "{model_task}"')