Skip to content

Commit

Permalink
Label confidence pair post processor logits as tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
tomas-gajarsky committed Dec 14, 2023
1 parent d683d1d commit 3f8832d
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions facetorch/analyzer/predictor/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,20 +297,19 @@ def run(self, preds: torch.Tensor) -> List[Prediction]:
if isinstance(preds, tuple):
preds = preds[0]

# Convert tensor to numpy array once instead of in the loop
preds_np = preds.cpu().numpy()

# Use list comprehension instead of loop for creating pred_list
pred_list = [
Prediction(
pred_list = []
for i in range(preds.shape[0]):
preds_sample = preds[i]
preds_sample_list = preds_sample.cpu().numpy().tolist()
other_labels = {
label: preds_sample_list[j] + self.offsets[j]
for j, label in enumerate(self.labels)
}
pred = Prediction(
label="other",
logits=preds_sample,
other={
label: preds_np[i, j] + offset
for j, (label, offset) in enumerate(zip(self.labels, self.offsets))
},
other=other_labels,
)
for i, preds_sample in enumerate(preds_np)
]
pred_list.append(pred)

return pred_list

0 comments on commit 3f8832d

Please sign in to comment.