Skip to content

Commit

Permalink
fix bug in multitask stitcher
Browse files Browse the repository at this point in the history
  • Loading branch information
mazabou committed Jan 2, 2025
1 parent 7a59bb5 commit 014d5db
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions torch_brain/utils/stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx)

token_sample_idx = torch.where(mask)[0]

curr_sample_ptr = self.sample_ptr

for i in torch.unique(token_sample_idx):
pred = output_values[i][readout_id]
target = target_values[readout_id][token_sample_idx == i]
Expand All @@ -287,19 +289,21 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx)
)
subtask_idx = output_subtask_index[readout_id][token_sample_idx == i]

self.cache[self.sequence_index[self.sample_ptr]]["pred"][
self.cache[self.sequence_index[curr_sample_ptr]]["pred"][
readout_id
].append(pred.detach().cpu())
self.cache[self.sequence_index[self.sample_ptr]]["target"][
self.cache[self.sequence_index[curr_sample_ptr]]["target"][
readout_id
].append(target.detach().cpu())
self.cache[self.sequence_index[self.sample_ptr]]["timestamps"][
self.cache[self.sequence_index[curr_sample_ptr]]["timestamps"][
readout_id
].append(timestamps.detach().cpu())
self.cache[self.sequence_index[self.sample_ptr]]["subtask_index"][
self.cache[self.sequence_index[curr_sample_ptr]]["subtask_index"][
readout_id
].append(subtask_idx.detach().cpu())

curr_sample_ptr += 1

# update counter then check if the cache should be flushed
for i in range(len(output_values)):
j = self.sequence_index[self.sample_ptr]
Expand Down

0 comments on commit 014d5db

Please sign in to comment.