rxn.metrics.scripts.reorder_retro_predictions_class_token.reorder_retro_predictions_class_token
- rxn.metrics.scripts.reorder_retro_predictions_class_token.reorder_retro_predictions_class_token(ground_truth_file, predictions_file, confidences_file, fwd_predictions_file, classes_predictions_file, n_class_tokens)
Reorder the retro predictions generated from a class-token model.
For each sample x, N inputs are created by prepending each of the N class tokens to x, where N is the number of class tokens used. The retro predictions are originally ordered as follows:
- ‘[0] x’ -> top1 prediction(‘[0] x’)
  -> top2 prediction(‘[0] x’) …
- ‘[1] x’ -> top1 prediction(‘[1] x’)
  -> top2 prediction(‘[1] x’) …
- …
- ‘[N-1] x’ -> top1 prediction(‘[N-1] x’)
  -> top2 prediction(‘[N-1] x’) …
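The layout above amounts to index arithmetic on the flat predictions file. The helper below, `group_flat_predictions`, is hypothetical and not part of rxn.metrics; it assumes the flat list is ordered by sample, then by class token, then by rank, and that every input ‘[t] x’ has the same number k of predictions. Under these assumptions, the rank-r prediction for ‘[t] x’ of sample s sits at index `(s * N + t) * k + r`.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def group_flat_predictions(
    flat: Sequence[T], n_samples: int, n_class_tokens: int
) -> List[List[List[T]]]:
    """Split a flat list into nested[sample][class_token][rank] (hypothetical helper).

    Assumes all predictions for '[0] x' come first, then '[1] x', and so on,
    and that every (sample, class token) pair has the same top-k size.
    """
    n_groups = n_samples * n_class_tokens
    if n_groups == 0 or len(flat) % n_groups != 0:
        raise ValueError("number of predictions is not a multiple of samples * tokens")
    top_k = len(flat) // n_groups
    nested: List[List[List[T]]] = []
    for s in range(n_samples):
        per_token: List[List[T]] = []
        for t in range(n_class_tokens):
            # first line of the block holding the top-k predictions for '[t] x'
            start = (s * n_class_tokens + t) * top_k
            per_token.append(list(flat[start:start + top_k]))
        nested.append(per_token)
    return nested
```

For example, eight lines for two samples and two class tokens split into top-2 blocks: `group_flat_predictions(["a0", "a1", "b0", "b1", "c0", "c1", "d0", "d1"], 2, 2)` gives `[[["a0", "a1"], ["b0", "b1"]], [["c0", "c1"], ["d0", "d1"]]]`.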
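Using the log-likelihood of each prediction, we reorder the predictions across class tokens to remove the dependency on the class token: for each rank, the N predictions of that rank are sorted by log-likelihood. The new predictions for x are therefore: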
- x -> sorted([top1 prediction(‘[i] x’) for i in range(N)])
  -> sorted([top2 prediction(‘[i] x’) for i in range(N)]) …
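A minimal sketch of this rank-wise reordering for a single sample x, assuming the per-token predictions and their log-likelihoods are already grouped as nested lists and that a higher log-likelihood means a better prediction. `reorder_sample` is a hypothetical name, not the library's API; the documented function also takes forward-prediction and class-prediction files, whose handling this sketch does not model.

```python
from typing import List, Sequence, Tuple


def reorder_sample(
    predictions: Sequence[Sequence[str]],
    confidences: Sequence[Sequence[float]],
) -> List[Tuple[str, float]]:
    """Reorder one sample's predictions across class tokens (hypothetical sketch).

    predictions[i][k] is the top-(k+1) prediction for the input '[i] x' and
    confidences[i][k] is its log-likelihood (assumed: higher is better).
    Returns the N top-1 predictions sorted by log-likelihood, then the
    N top-2 predictions sorted by log-likelihood, and so on.
    """
    n_tokens = len(predictions)
    top_k = len(predictions[0]) if n_tokens else 0
    reordered: List[Tuple[str, float]] = []
    for k in range(top_k):
        # collect the rank-k prediction of every class token
        rank_k = [(predictions[i][k], confidences[i][k]) for i in range(n_tokens)]
        # best log-likelihood first; ranks are never mixed with each other
        rank_k.sort(key=lambda pair: pair[1], reverse=True)
        reordered.extend(rank_k)
    return reordered
```

With predictions `[["A1", "A2"], ["B1", "B2"]]` and log-likelihoods `[[-3.0, -4.0], [-0.1, -1.5]]`, the result is B1, A1, B2, A2. Because sorting happens within each rank, B2 (-1.5) stays after A1 (-3.0) even though its log-likelihood is higher.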
- Parameters
ground_truth_file (Union[str, Path])
predictions_file (Union[str, Path])
confidences_file (Union[str, Path])
fwd_predictions_file (Union[str, Path])
classes_predictions_file (Union[str, Path])
n_class_tokens (int)
- Return type
None