import logging
from pathlib import Path
from typing import List, Tuple, Union
import click
from rxn.utilities.containers import chunker
from rxn.utilities.files import dump_list_to_file, load_list_from_file
from rxn.metrics.metrics_files import RetroFiles
from rxn.metrics.utils import get_sequence_multiplier
# Module-level logger. A no-op NullHandler is attached here; actual log output
# is configured by the caller (e.g. logging.basicConfig() in main() below).
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
def reorder_retro_predictions_class_token(
    ground_truth_file: Union[str, Path],
    predictions_file: Union[str, Path],
    confidences_file: Union[str, Path],
    fwd_predictions_file: Union[str, Path],
    classes_predictions_file: Union[str, Path],
    n_class_tokens: int,
) -> None:
    """
    Reorder the retro-predictions generated from a class-token model.

    For each sample x, N samples are created where N is the number of class
    tokens used. The retro predictions are originally ordered like e.g.:

        '[0] x' -> top1 prediction('[0] x')
                -> top2 prediction('[0] x')
                ...
        '[1] x' -> top1 prediction('[1] x')
                -> top2 prediction('[1] x')
                ...
        ...
        '[N] x' -> top1 prediction('[N] x')
                -> top2 prediction('[N] x')
                ...

    Starting from the log likelihood on each prediction we reorder them
    token-wise to remove the token dependency. So the new predictions for x
    will be:

        x -> sorted([top1 prediction('[i] x') for i in number_class_tokens])
          -> sorted([top2 prediction('[i] x') for i in number_class_tokens])
          ...

    Within each rank, the entries are sorted by decreasing confidence (the
    confidence strings are parsed with ``float``). The confidences, forward
    predictions and classes predictions follow their prediction. The four
    reordered lists are written to the paths given by
    ``RetroFiles.reordered(...)`` applied to the respective input path.

    Args:
        ground_truth_file: file with the ground truth; used, together with
            the predictions, to determine the number of predictions per sample.
        predictions_file: file with the retro predictions.
        confidences_file: file with one confidence per prediction.
        fwd_predictions_file: file with one forward prediction per prediction.
        classes_predictions_file: file with one class prediction per prediction.
        n_class_tokens: number of class tokens used to generate the predictions.

    Raises:
        ValueError: if ``n_class_tokens`` is not strictly positive, if the
            per-prediction files do not all have as many lines as the
            predictions file, or if the number of predictions per sample is
            not an exact multiple of ``n_class_tokens``.
    """
    if n_class_tokens <= 0:
        raise ValueError(
            f"The number of class tokens must be strictly positive, got {n_class_tokens}."
        )

    logger.info(
        'Reordering file "%s", based on %s class tokens.',
        predictions_file,
        n_class_tokens,
    )

    # We load the files
    ground_truth = load_list_from_file(ground_truth_file)
    predictions = load_list_from_file(predictions_file)
    confidences = load_list_from_file(confidences_file)
    fwd_predictions = load_list_from_file(fwd_predictions_file)
    classes_predictions = load_list_from_file(classes_predictions_file)

    # zip() stops at the shortest input: without this check, a length mismatch
    # would silently drop or misalign rows instead of being reported.
    per_prediction_values = {
        "confidences": confidences,
        "forward predictions": fwd_predictions,
        "classes predictions": classes_predictions,
    }
    for label, values in per_prediction_values.items():
        if len(values) != len(predictions):
            raise ValueError(
                f"The number of {label} ({len(values)}) does not match "
                f"the number of predictions ({len(predictions)})."
            )

    # Get the exact multiplier
    multiplier = get_sequence_multiplier(
        ground_truth=ground_truth, predictions=predictions
    )
    if multiplier % n_class_tokens != 0:
        raise ValueError(
            f"The number of predictions ('{multiplier}') is not an exact "
            f"multiple of the number of class tokens '({n_class_tokens})'"
        )
    # Exact integer division (guaranteed by the check above).
    topx_per_class_token = multiplier // n_class_tokens

    # One row per prediction: (prediction, confidence, fwd prediction, class prediction)
    rows = zip(predictions, confidences, fwd_predictions, classes_predictions)

    # We will reorder the predictions class-token wise using the confidence
    reordered_rows: List[Tuple[str, str, str, str]] = []
    for sample_rows in chunker(rows, chunk_size=multiplier):
        # Split the rows of one sample into one chunk per class token; this
        # does not depend on the rank, so it is done once per sample.
        rows_per_class_token = list(
            chunker(sample_rows, chunk_size=topx_per_class_token)
        )
        for topn in range(topx_per_class_token):
            # For each class token take the topn prediction, and sort them
            # by decreasing confidence (index 1 of the row).
            topn_rows = [class_rows[topn] for class_rows in rows_per_class_token]
            reordered_rows.extend(
                sorted(topn_rows, key=lambda row: float(row[1]), reverse=True)
            )

    # Write each column of the reordered rows next to its input file.
    input_file_per_column = (
        predictions_file,
        confidences_file,
        fwd_predictions_file,
        classes_predictions_file,
    )
    for column, input_file in enumerate(input_file_per_column):
        dump_list_to_file(
            [row[column] for row in reordered_rows],
            RetroFiles.reordered(input_file),
        )
@click.command()
@click.option(
    "--ground_truth_file",
    "-g",
    required=True,
    help="File with ground truth.",
)
@click.option(
    "--predictions_file",
    "-p",
    required=True,
    help="File with the predictions.",
)
@click.option(
    "--confidences_file",
    "-l",
    required=True,
    help="File with the confidences.",
)
@click.option(
    "--fwd_predictions_file",
    "-f",
    required=True,
    help="File with the forward predictions.",
)
@click.option(
    "--classes_predictions_file",
    "-c",
    required=True,
    help="File with the classes predictions.",
)
@click.option(
    "--n_class_tokens",
    "-n",
    required=True,
    type=int,
    help="Number of class tokens.",
)
def main(
    ground_truth_file: str,
    predictions_file: str,
    confidences_file: str,
    fwd_predictions_file: str,
    classes_predictions_file: str,
    n_class_tokens: int,
) -> None:
    # Command-line entry point: set up logging at INFO level, then forward all
    # the options unchanged to reorder_retro_predictions_class_token().
    # (Documented with comments rather than a docstring, since click would
    # display a docstring in the --help output.)
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )

    # The actual work lives in a separate function so that it can also be
    # called directly from Python code, without going through the CLI.
    reorder_retro_predictions_class_token(
        ground_truth_file=ground_truth_file,
        predictions_file=predictions_file,
        confidences_file=confidences_file,
        fwd_predictions_file=fwd_predictions_file,
        classes_predictions_file=classes_predictions_file,
        n_class_tokens=n_class_tokens,
    )
if __name__ == "__main__":
    # Run the click command when this file is executed as a script.
    main()