import logging
from enum import Flag
from typing import Any, List, Optional, Tuple
from rxn.utilities.files import PathLike
from .model_introspection import get_model_rnn_size
# Module-level logger. The NullHandler silences the "No handlers could be
# found" warning when the host application has not configured logging.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
class RxnCommand(Flag):
    """
    Flag indicating which command(s) some parameters relate to.

    The base flags are T (train), C (continue training) and F (fine-tune);
    TC, TF, CF and TCF are their combinations. This enum allows for easily
    checking which commands some parameters relate to (see the Parameter
    and TrainingPlanner classes).
    """

    T = 1  # Train
    C = 2  # Continue training
    F = 4  # Fine-tune
    TC = T | C
    TF = T | F
    CF = C | F
    TCF = T | C | F
class Arg:
    """
    A single argument to be forwarded to the onmt_train command.

    Attributes:
        key: argument name (i.e. what is forwarded to onmt_train, without the dash).
        default: default value used for that argument in the RXN universe.
            ``None`` means the argument must always be provided explicitly;
            an empty string is used for boolean args that take no value.
        needed_for: the command(s) this argument is needed for (train,
            finetune, etc.).
    """

    def __init__(self, key: str, default: Any, needed_for: RxnCommand):
        self.key, self.default, self.needed_for = key, default, needed_for
# All the onmt_train arguments handled by OnmtTrainCommand, with the default
# value used in the RXN universe (None = must be given explicitly, "" = boolean
# flag taking no value) and the command(s) that need them.
ONMT_TRAIN_ARGS: List[Arg] = [
    Arg("accum_count", "4", RxnCommand.TCF),
    Arg("adam_beta1", "0.9", RxnCommand.TF),
    Arg("adam_beta2", "0.998", RxnCommand.TF),
    Arg("batch_size", None, RxnCommand.TCF),
    Arg("batch_type", "tokens", RxnCommand.TCF),
    Arg("data", None, RxnCommand.TCF),
    Arg("decay_method", "noam", RxnCommand.TF),
    Arg("decoder_type", "transformer", RxnCommand.T),
    Arg("dropout", None, RxnCommand.TCF),
    Arg("encoder_type", "transformer", RxnCommand.T),
    Arg("global_attention", "general", RxnCommand.T),
    Arg("global_attention_function", "softmax", RxnCommand.T),
    Arg("heads", None, RxnCommand.T),
    Arg("keep_checkpoint", "-1", RxnCommand.TCF),
    Arg("label_smoothing", "0.0", RxnCommand.TCF),
    Arg("layers", None, RxnCommand.T),
    Arg("learning_rate", None, RxnCommand.TF),
    Arg("max_generator_batches", "32", RxnCommand.TCF),
    Arg("max_grad_norm", "0", RxnCommand.TF),
    Arg("normalization", "tokens", RxnCommand.TCF),
    Arg("optim", "adam", RxnCommand.TF),
    Arg("param_init", "0", RxnCommand.T),
    Arg("param_init_glorot", "", RxnCommand.T),  # note: empty means "nothing"
    Arg("position_encoding", "", RxnCommand.T),  # note: empty means "nothing"
    Arg("report_every", "1000", RxnCommand.TCF),
    Arg("reset_optim", None, RxnCommand.CF),
    Arg("rnn_size", None, RxnCommand.TF),
    Arg("save_checkpoint_steps", "5000", RxnCommand.TCF),
    Arg("save_model", None, RxnCommand.TCF),
    Arg("seed", None, RxnCommand.TCF),
    Arg("self_attn_type", "scaled-dot", RxnCommand.T),
    Arg("share_embeddings", "", RxnCommand.T),  # note: empty means "nothing"
    Arg("train_from", None, RxnCommand.CF),
    Arg("train_steps", None, RxnCommand.TCF),
    Arg("transformer_ff", None, RxnCommand.T),
    Arg("valid_batch_size", "8", RxnCommand.TCF),
    Arg("warmup_steps", None, RxnCommand.TF),
    Arg("word_vec_size", None, RxnCommand.T),
]
class OnmtTrainCommand:
    """
    Class to build the onmt_train command for training models, continuing
    the training, or finetuning.

    Instances should normally be created via the ``train``,
    ``continue_training`` and ``finetune`` alternate constructors, which
    supply the keyword arguments expected for the corresponding command.
    """

    def __init__(
        self,
        command_type: RxnCommand,
        no_gpu: bool,
        data_weights: Tuple[int, ...],
        **kwargs: Any,
    ):
        """
        Args:
            command_type: which command (train / continue / finetune) is being built.
            no_gpu: whether to omit the GPU-related arguments.
            data_weights: weights for multi-task datasets; empty for a single dataset.
            kwargs: values for the onmt_train arguments, keyed by argument name.
        """
        self._command_type = command_type
        self._no_gpu = no_gpu
        self._data_weights = data_weights
        self._kwargs = kwargs

    def _build_cmd(self) -> List[str]:
        """
        Build the base command.

        Raises:
            ValueError: if a value was given for an argument that the command
                type does not need, or if a required argument is missing.
        """
        command = ["onmt_train"]
        for arg in ONMT_TRAIN_ARGS:
            arg_given = arg.key in self._kwargs

            if self._command_type not in arg.needed_for:
                # Check that the arg was not given; then go to the next argument.
                # Bug fix: the message previously interpolated the partially
                # built command list instead of the command type.
                if arg_given:
                    raise ValueError(
                        f'"{arg.key}" value given, but not necessary for {self._command_type}'
                    )
                continue

            # Case 1: something given (whether there was a default or not)
            if arg_given:
                value = str(self._kwargs[arg.key])
            # Case 2: default is None (i.e. a value is needed) but nothing was given
            elif arg.default is None:
                raise ValueError(f"No value given for {arg.key}")
            # Case 3: does not need value and nothing given
            else:
                value = str(arg.default)

            # Add the args to the command. Note: if the value is the empty string,
            # do not add anything (typically for boolean args)
            command.append(f"-{arg.key}")
            if value != "":
                command.append(value)

        command += self._args_for_gpu()
        command += self._args_for_data_weights()
        return command

    def _args_for_gpu(self) -> List[str]:
        # GPU rank 0 is used unless the GPU was explicitly disabled.
        if self._no_gpu:
            return []
        return ["-gpu_ranks", "0"]

    def _args_for_data_weights(self) -> List[str]:
        # Only relevant for multi-task training; no args when no weights given.
        if not self._data_weights:
            return []
        n_additional_datasets = len(self._data_weights) - 1
        data_ids = preprocessed_id_names(n_additional_datasets)
        return [
            "-data_ids",
            *data_ids,
            "-data_weights",
            *(str(weight) for weight in self._data_weights),
        ]

    def cmd(self) -> List[str]:
        """
        Return the "raw" command for executing onmt_train.
        """
        return self._build_cmd()

    def save_to_config_cmd(self, config_file: PathLike) -> List[str]:
        """
        Return the command for saving the config to a file.
        """
        return self._build_cmd() + ["-save_config", str(config_file)]

    @staticmethod
    def execute_from_config_cmd(config_file: PathLike) -> List[str]:
        """
        Return the command for executing onmt_train with values read from the config.
        """
        return ["onmt_train", "-config", str(config_file)]

    @classmethod
    def train(
        cls,
        batch_size: int,
        data: PathLike,
        dropout: float,
        heads: int,
        layers: int,
        learning_rate: float,
        rnn_size: int,
        save_model: PathLike,
        seed: int,
        train_steps: int,
        transformer_ff: int,
        warmup_steps: int,
        word_vec_size: int,
        no_gpu: bool,
        data_weights: Tuple[int, ...],
        keep_checkpoint: int = -1,
    ) -> "OnmtTrainCommand":
        """
        Instantiate the command for training a model from scratch.
        """
        return cls(
            command_type=RxnCommand.T,
            no_gpu=no_gpu,
            data_weights=data_weights,
            batch_size=batch_size,
            data=data,
            dropout=dropout,
            heads=heads,
            keep_checkpoint=keep_checkpoint,
            layers=layers,
            learning_rate=learning_rate,
            rnn_size=rnn_size,
            save_model=save_model,
            seed=seed,
            train_steps=train_steps,
            transformer_ff=transformer_ff,
            warmup_steps=warmup_steps,
            word_vec_size=word_vec_size,
        )

    @classmethod
    def continue_training(
        cls,
        batch_size: int,
        data: PathLike,
        dropout: float,
        save_model: PathLike,
        seed: int,
        train_from: PathLike,
        train_steps: int,
        no_gpu: bool,
        data_weights: Tuple[int, ...],
        keep_checkpoint: int = -1,
    ) -> "OnmtTrainCommand":
        """
        Instantiate the command for continuing the training of an existing
        model checkpoint (keeping the optimizer state: reset_optim="none").
        """
        return cls(
            command_type=RxnCommand.C,
            no_gpu=no_gpu,
            data_weights=data_weights,
            batch_size=batch_size,
            data=data,
            dropout=dropout,
            keep_checkpoint=keep_checkpoint,
            reset_optim="none",
            save_model=save_model,
            seed=seed,
            train_from=train_from,
            train_steps=train_steps,
        )

    @classmethod
    def finetune(
        cls,
        batch_size: int,
        data: PathLike,
        dropout: float,
        learning_rate: float,
        save_model: PathLike,
        seed: int,
        train_from: PathLike,
        train_steps: int,
        warmup_steps: int,
        no_gpu: bool,
        data_weights: Tuple[int, ...],
        report_every: int,
        save_checkpoint_steps: int,
        keep_checkpoint: int = -1,
        rnn_size: Optional[int] = None,
    ) -> "OnmtTrainCommand":
        """
        Instantiate the command for fine-tuning an existing model checkpoint
        (resetting the optimizer state: reset_optim="all").

        If rnn_size is not given, it is read from the model checkpoint.
        """
        if rnn_size is None:
            # In principle, the rnn_size should not be needed for finetuning. However,
            # when resetting the decay algorithm for the learning rate, this value
            # is necessary - and does not get it from the model checkpoint (OpenNMT bug).
            rnn_size = get_model_rnn_size(train_from)
            logger.info(f"Loaded the value of rnn_size from the model: {rnn_size}.")

        return cls(
            command_type=RxnCommand.F,
            no_gpu=no_gpu,
            data_weights=data_weights,
            batch_size=batch_size,
            data=data,
            dropout=dropout,
            keep_checkpoint=keep_checkpoint,
            learning_rate=learning_rate,
            reset_optim="all",
            rnn_size=rnn_size,
            save_model=save_model,
            seed=seed,
            train_from=train_from,
            train_steps=train_steps,
            warmup_steps=warmup_steps,
            report_every=report_every,
            save_checkpoint_steps=save_checkpoint_steps,
        )
def preprocessed_id_names(n_additional_sets: int) -> List[str]:
    """Get the names of the ids for the datasets used in multi-task training
    with OpenNMT.

    Args:
        n_additional_sets: how many sets there are in addition to the main set.
    """
    names = ["main_set"]
    names.extend(f"additional_set_{index}" for index in range(1, n_additional_sets + 1))
    return names