import copy
import os
from argparse import Namespace
from itertools import repeat
from typing import Any, Iterable, Iterator, List, Optional
import attr
import onmt.opts as opts
from onmt.translate.translator import build_translator
from onmt.utils.misc import split_corpus
from onmt.utils.parse import ArgumentParser
from rxn.utilities.files import named_temporary_path
@attr.s(auto_attribs=True)
class TranslationResult:
    """
    Outcome of a translation with OpenNMT: a translated text and its score.
    """

    # Translated text.
    text: str
    # Score associated with the translated text.
    score: float
class RawTranslator:
    """
    Translator class that is very coupled to the internal OpenNMT implementation.

    The internal OpenNMT translator is built once, in the constructor, and is
    reused for every translation call. Sentences are handed over to OpenNMT
    through a temporary source file.
    """

    def __init__(self, opt: Namespace):
        """
        Args:
            opt: OpenNMT options from which the internal translator is built.
        """
        self.opt = opt

        # Empty inputs are not translated as such: a dummy sentence is
        # translated in their place, and the resulting predictions are then
        # replaced by empty strings carrying the score below.
        self.score_for_empty_input = -9999.9999
        self.dummy_string_for_empty_input = "C . C . C . C"

        # to avoid the creation of an unnecessary file
        # NOTE: the handle is passed on to the translator and is not closed here.
        out_file = open(os.devnull, "w")
        self.internal_translator = build_translator(
            self.opt, report_score=False, out_file=out_file
        )

    def translate_sentences_with_onmt(
        self, sentences: Iterable[str], **opt_updated_kwargs: Any
    ) -> Iterator[List[TranslationResult]]:
        """
        Do the translation (in tokenized format) with OpenNMT.

        Args:
            sentences: sentences to translate
            opt_updated_kwargs: values to update in the "opt" of the translator. The
                translator is not instantiated again from those values, therefore this
                only affects values that are used for translation, such as n_best.

        Yields:
            For each input sentence, the list of translation results obtained
            for it. For an empty input sentence, every result is an empty
            string with the score ``self.score_for_empty_input``.
        """
        # Work on a copy so that the options stored on the instance stay untouched.
        new_opt = copy.deepcopy(self.opt)
        for key, value in opt_updated_kwargs.items():
            setattr(new_opt, key, value)

        with named_temporary_path() as tmp_src, named_temporary_path() as tmp_output:
            new_opt.src = tmp_src
            new_opt.output = tmp_output

            # List to track which inputs were empty, for post-processing
            empty_input: List[bool] = []

            # Write source sentences to the temporary input file. The encoding
            # is set explicitly instead of relying on the locale default, so
            # that non-ASCII tokens are always written as UTF-8.
            with open(new_opt.src, "wt", encoding="utf-8") as f:
                for sentence in sentences:
                    # In order to avoid problems with batches full of empty string on GPUs,
                    # we write a dummy line instead of the empty string. These lines
                    # are post-processed again below to replace the predictions by
                    # empty strings.
                    is_empty = sentence == ""
                    if is_empty:
                        sentence = self.dummy_string_for_empty_input
                    f.write(f"{sentence}\n")
                    empty_input.append(is_empty)

            for translation_results, is_empty in zip(
                self.translate_with_onmt(new_opt), empty_input
            ):
                # For predictions corresponding to empty inputs, return
                # empty strings with the dedicated score; keep as many
                # results as were predicted for the dummy sentence.
                if is_empty:
                    yield [
                        TranslationResult("", self.score_for_empty_input)
                        for _ in translation_results
                    ]
                else:
                    yield translation_results

    def translate_with_onmt(self, opt) -> Iterator[List[TranslationResult]]:
        """
        Do the translation (in tokenized format) with OpenNMT.

        The source file (and the target file, if any) is read in shards of
        ``opt.shard_size`` lines, and each shard is translated in turn.

        Args:
            opt: args given to the main script

        Yields:
            One list of TranslationResults per pair of score list / translation
            list returned by the internal translator (used by the caller in
            this class as one list per source sentence).
        """
        # for some versions, it seems that n_best is not updated, we therefore do it manually here
        self.internal_translator.n_best = opt.n_best

        src_shards = split_corpus(opt.src, opt.shard_size)
        # Without a target file, pair every source shard with None.
        tgt_shards = (
            split_corpus(opt.tgt, opt.shard_size)
            if opt.tgt is not None
            else repeat(None)
        )

        for src_shard, tgt_shard in zip(src_shards, tgt_shards):
            shard_scores, shard_texts = self.internal_translator.translate(
                src=src_shard,
                tgt=tgt_shard,
                src_dir=opt.src_dir,
                batch_size=opt.batch_size,
                batch_type=opt.batch_type,
                attn_debug=opt.attn_debug,
            )
            for score_list, translation_list in zip(shard_scores, shard_texts):
                # Scores expose ".item()" to be converted to plain numbers.
                yield [
                    TranslationResult(text=t, score=s.item())
                    for s, t in zip(score_list, translation_list)
                ]
def get_onmt_opt(
    translation_model: Iterable[str],
    src_file: Optional[str] = None,
    output_file: Optional[str] = None,
    **kwargs: Any,
) -> Namespace:
    """
    Create the opt arguments by taking the defaults and overwriting a few values.

    Args:
        translation_model: Model(s) to use for translation
        src_file: Source file; a placeholder value is used if not given
        output_file: Output file; a placeholder value is used if not given
        kwargs: additional values to change in the resulting opt

    Returns:
        The parsed options, with the values from kwargs set on top, after
        validation by ``ArgumentParser.validate_translate_opts``.
    """
    # Some values are needed and must be parsed from args, other values can
    # simply be overwritten from the default ones
    src = src_file if src_file is not None else "(unused)"
    output = output_file if output_file is not None else "(unused)"

    # Build the argument list directly instead of formatting a command-line
    # string and splitting it on whitespace: that way, paths that contain
    # spaces are kept as single arguments.
    args = ["--model", *translation_model, "--src", src, "--output", output]

    parser = onmt_parser()
    opt = parser.parse_args(args)
    for key, value in kwargs.items():
        setattr(opt, key, value)
    ArgumentParser.validate_translate_opts(opt)
    return opt
def onmt_parser() -> ArgumentParser:
    """
    Build the OpenNMT argument parser (adapted from the OpenNMT-Py repo).

    The parser is given the description "translate.py" and is populated with
    the config options followed by the translation options.
    """
    translate_parser = ArgumentParser(description="translate.py")
    # Register the option groups in the same order as the original script.
    for register_options in (opts.config_opts, opts.translate_opts):
        register_options(translate_parser)
    return translate_parser