diff --git a/torch_brain/utils/stitcher.py b/torch_brain/utils/stitcher.py
index 2708bbb..bc711bb 100644
--- a/torch_brain/utils/stitcher.py
+++ b/torch_brain/utils/stitcher.py
@@ -278,6 +278,8 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx)
 
         token_sample_idx = torch.where(mask)[0]
 
+        curr_sample_ptr = self.sample_ptr
+
         for i in torch.unique(token_sample_idx):
             pred = output_values[i][readout_id]
             target = target_values[readout_id][token_sample_idx == i]
@@ -287,19 +289,21 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx)
             )
             subtask_idx = output_subtask_index[readout_id][token_sample_idx == i]
 
-            self.cache[self.sequence_index[self.sample_ptr]]["pred"][
+            self.cache[self.sequence_index[curr_sample_ptr]]["pred"][
                 readout_id
             ].append(pred.detach().cpu())
-            self.cache[self.sequence_index[self.sample_ptr]]["target"][
+            self.cache[self.sequence_index[curr_sample_ptr]]["target"][
                 readout_id
             ].append(target.detach().cpu())
-            self.cache[self.sequence_index[self.sample_ptr]]["timestamps"][
+            self.cache[self.sequence_index[curr_sample_ptr]]["timestamps"][
                 readout_id
             ].append(timestamps.detach().cpu())
-            self.cache[self.sequence_index[self.sample_ptr]]["subtask_index"][
+            self.cache[self.sequence_index[curr_sample_ptr]]["subtask_index"][
                 readout_id
             ].append(subtask_idx.detach().cpu())
+            curr_sample_ptr += 1
+
 
         # update counter then check if the cache should be flushed
         for i in range(len(output_values)):
             j = self.sequence_index[self.sample_ptr]
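For readers skimming the patch: before this change, `self.sample_ptr` stayed fixed while the loop cached every sample in the batch, so all of a batch's predictions were filed under the first sample's sequence index. The patch walks a local copy, `curr_sample_ptr`, and leaves `self.sample_ptr` to be advanced by the flush loop below. A minimal, self-contained sketch of the pattern follows; the toy `cache_batch`, `sequence_index`, and data are hypothetical stand-ins, not the stitcher's real API:

```python
from collections import defaultdict

# Toy stand-ins (hypothetical): sequence_index maps each validation sample
# to the recording sequence it was cut from; cache accumulates per-sequence
# outputs; sample_ptr is the shared read pointer across batches.
sequence_index = [0, 0, 1, 1]
cache = defaultdict(list)
sample_ptr = 0  # shared pointer, advanced only after a batch is cached


def cache_batch(batch_outputs):
    """Cache each sample of a batch under its own sequence index."""
    global sample_ptr
    # Walk a local copy so the shared pointer is not mutated mid-batch;
    # this mirrors the curr_sample_ptr introduced by the patch.
    curr_sample_ptr = sample_ptr
    for out in batch_outputs:
        cache[sequence_index[curr_sample_ptr]].append(out)
        curr_sample_ptr += 1  # next sample in the batch
    # Advance the shared pointer afterwards, like the
    # "update counter then check if the cache should be flushed" loop.
    sample_ptr += len(batch_outputs)


cache_batch(["a", "b"])  # samples 0-1 belong to sequence 0
cache_batch(["c", "d"])  # samples 2-3 belong to sequence 1
assert dict(cache) == {0: ["a", "b"], 1: ["c", "d"]}

# With the pre-patch behavior (indexing by the un-advanced shared pointer
# inside the loop), "c" and "d" would also land under sequence 0.
```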