diff --git a/facetorch/analyzer/predictor/post.py b/facetorch/analyzer/predictor/post.py index 21da2cc..bf9f62d 100644 --- a/facetorch/analyzer/predictor/post.py +++ b/facetorch/analyzer/predictor/post.py @@ -297,20 +297,19 @@ def run(self, preds: torch.Tensor) -> List[Prediction]: if isinstance(preds, tuple): preds = preds[0] - # Convert tensor to numpy array once instead of in the loop - preds_np = preds.cpu().numpy() - - # Use list comprehension instead of loop for creating pred_list - pred_list = [ - Prediction( + pred_list = [] + for i in range(preds.shape[0]): + preds_sample = preds[i] + preds_sample_list = preds_sample.cpu().numpy().tolist() + other_labels = { + label: preds_sample_list[j] + self.offsets[j] + for j, label in enumerate(self.labels) + } + pred = Prediction( label="other", logits=preds_sample, - other={ - label: preds_np[i, j] + offset - for j, (label, offset) in enumerate(zip(self.labels, self.offsets)) - }, + other=other_labels, ) - for i, preds_sample in enumerate(preds_np) - ] + pred_list.append(pred) return pred_list