rxn.metrics.scripts.reorder_retro_predictions_class_token.reorder_retro_predictions_class_token

rxn.metrics.scripts.reorder_retro_predictions_class_token.reorder_retro_predictions_class_token(ground_truth_file, predictions_file, confidences_file, fwd_predictions_file, classes_predictions_file, n_class_tokens)

Reorder the retro predictions generated by a class-token model.

For each sample x, N samples are created, one per class token, where N is the number of class tokens used. The retro predictions are originally ordered as follows:

    '[0] x'   -> top1 prediction('[0] x')
              -> top2 prediction('[0] x')
              …
    '[1] x'   -> top1 prediction('[1] x')
              -> top2 prediction('[1] x')
              …
    …
    '[N-1] x' -> top1 prediction('[N-1] x')
              -> top2 prediction('[N-1] x')
              …
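The layout above can be made concrete with a minimal sketch. It assumes that the predictions for each sample occupy one contiguous block of `n_class_tokens * topk` lines, ordered first by class token and then by rank, as in the listing. The helper names `tokenize_sample` and `flat_index` and the `topk` parameter are hypothetical, for illustration only; they are not part of the rxn-metrics API.

```python
from __future__ import annotations


def tokenize_sample(x: str, n_class_tokens: int) -> list[str]:
    """Prepend each class token '[i]' to the sample x (hypothetical helper)."""
    return [f"[{i}] {x}" for i in range(n_class_tokens)]


def flat_index(sample: int, token: int, rank: int, n_class_tokens: int, topk: int) -> int:
    """Line index of the rank-th (0-based) prediction for '[token] x_sample'.

    Assumes per-sample blocks of n_class_tokens * topk lines, ordered by
    class token first and by rank second (hypothetical helper).
    """
    block_size = n_class_tokens * topk
    return sample * block_size + token * topk + rank


# Two class tokens, top-3 predictions per token.
print(tokenize_sample("CCO", 2))                     # ['[0] CCO', '[1] CCO']
print(flat_index(0, 1, 1, n_class_tokens=2, topk=3))  # top2 prediction('[1] x') -> 4
```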

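Using the log-likelihood of each prediction, the predictions are reordered rank by rank across the class tokens, which removes the dependency on the class token: the top-1 predictions of all class tokens come first, sorted by decreasing log-likelihood, followed by the top-2 predictions, and so on. The new predictions for x are therefore: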

    x -> sorted([top1 prediction('[i] x') for i in range(n_class_tokens)])
      -> sorted([top2 prediction('[i] x') for i in range(n_class_tokens)])
      …
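The rank-wise reordering can be sketched as follows. This is a minimal illustration, not the rxn-metrics implementation: it assumes the per-sample, token-major layout of the first listing, treats the confidences as log-likelihoods (higher means more likely), and returns a permutation of line indices rather than sorted strings. The function names `reorder_indices` and `apply_permutation` and the `topk` parameter are hypothetical.

```python
from __future__ import annotations


def reorder_indices(confidences: list[float], n_class_tokens: int, topk: int) -> list[int]:
    """Return a permutation of line indices that reorders predictions rank by rank.

    For every sample block and every rank k, the rank-k predictions of all
    class tokens are grouped and sorted by decreasing log-likelihood.
    """
    block_size = n_class_tokens * topk
    if len(confidences) % block_size != 0:
        raise ValueError("number of lines is not a multiple of n_class_tokens * topk")

    permutation: list[int] = []
    for start in range(0, len(confidences), block_size):
        for rank in range(topk):
            # Line indices of the rank-th prediction for '[0] x', '[1] x', ...
            same_rank = [start + token * topk + rank for token in range(n_class_tokens)]
            # Highest log-likelihood first; ties keep the class-token order.
            same_rank.sort(key=lambda i: confidences[i], reverse=True)
            permutation.extend(same_rank)
    return permutation


def apply_permutation(values: list[str], permutation: list[int]) -> list[str]:
    """Reorder any line-aligned list with a permutation from reorder_indices."""
    return [values[i] for i in permutation]


# One sample, two class tokens, top-2 predictions per token.
predictions = ["A", "B", "C", "D"]  # '[0] x': A, B  |  '[1] x': C, D
confidences = [-0.5, -1.2, -0.1, -2.0]

perm = reorder_indices(confidences, n_class_tokens=2, topk=2)
print(apply_permutation(predictions, perm))  # ['C', 'A', 'B', 'D']
```

In this sketch, computing indices instead of sorting the prediction strings directly means the same permutation can be applied to other line-aligned lists (for example forward or class predictions), so they stay aligned with the reordered retro predictions.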

Parameters
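  • ground_truth_file (Union[str, Path]) – path to the ground-truth file.

  • predictions_file (Union[str, Path]) – path to the file of retro predictions, in the order described above.

  • confidences_file (Union[str, Path]) – path to the file with the log-likelihood of each retro prediction.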

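  • fwd_predictions_file (Union[str, Path]) – path to the file of forward predictions, one per retro prediction.

  • classes_predictions_file (Union[str, Path]) – path to the file of class predictions, one per retro prediction.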

  • n_class_tokens (int) – number of class tokens N used by the model.

Return type

None