diff --git a/flair/models/text_regression_model.py b/flair/models/text_regression_model.py index 9b4cc12c3a..351dc9bed8 100644 --- a/flair/models/text_regression_model.py +++ b/flair/models/text_regression_model.py @@ -64,7 +64,8 @@ def forward_loss(self, sentences: List[Sentence]) -> Tuple[torch.Tensor, int]: def _labels_to_tensor(self, sentences: List[Sentence]): indices = [ - torch.tensor([float(label.value) for label in sentence.labels], dtype=torch.float) for sentence in sentences + torch.tensor([float(label.value) for label in sentence.get_labels(self.label_name)], dtype=torch.float) + for sentence in sentences ] vec = torch.cat(indices, 0).to(flair.device)