From d34bfd096353d057ae0e39aaa9d4c041f47f7278 Mon Sep 17 00:00:00 2001 From: Sheldon Roberts Date: Fri, 16 Aug 2024 17:23:40 -0700 Subject: [PATCH 1/7] Add DeepNCMClassifier model Add tests for DeepNCMClassifier Remove old test Add multi label support Add type hints and doc strings --- flair/models/__init__.py | 2 + flair/models/deepncm_classification_model.py | 455 ++++++++++++++++++ flair/trainers/plugins/__init__.py | 2 + .../functional/deepncm_trainer_plugin.py | 41 ++ tests/models/test_deepncm_classifier.py | 167 +++++++ 5 files changed, 667 insertions(+) create mode 100644 flair/models/deepncm_classification_model.py create mode 100644 flair/trainers/plugins/functional/deepncm_trainer_plugin.py create mode 100644 tests/models/test_deepncm_classifier.py diff --git a/flair/models/__init__.py b/flair/models/__init__.py index e75daf074b..bf3651078a 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -1,3 +1,4 @@ +from .deepncm_classification_model import DeepNCMClassifier from .entity_linker_model import SpanClassifier from .entity_mention_linking import EntityMentionLinker from .language_model import LanguageModel @@ -37,4 +38,5 @@ "TextClassifier", "TextRegressor", "MultitaskModel", + "DeepNCMClassifier", ] diff --git a/flair/models/deepncm_classification_model.py b/flair/models/deepncm_classification_model.py new file mode 100644 index 0000000000..b942e28919 --- /dev/null +++ b/flair/models/deepncm_classification_model.py @@ -0,0 +1,455 @@ +import logging +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import torch +from tqdm import tqdm + +import flair +from flair.data import Dictionary, Sentence +from flair.datasets import DataLoader, FlairDatapointDataset +from flair.embeddings import DocumentEmbeddings +from flair.embeddings.base import load_embeddings +from flair.nn import Classifier + +log = logging.getLogger("flair") + + +class DeepNCMClassifier(Classifier[Sentence]): + """Deep Nearest Class Mean (DeepNCM) Classifier for text classification tasks. + + This model combines deep learning with the Nearest Class Mean (NCM) approach. + It uses document embeddings to represent text, optionally applies an encoder, + and classifies based on the nearest class prototype in the embedded space. + + The model supports various methods for updating class prototypes during training, + making it adaptable to different learning scenarios. + + This implementation is based on the research paper: + Guerriero, S., Caputo, B., & Mensink, T. (2018). DeepNCM: Deep Nearest Class Mean Classifiers. + In International Conference on Learning Representations (ICLR) 2018 Workshop. + URL: https://openreview.net/forum?id=rkPLZ4JPM + """ + + def __init__( + self, + embeddings: DocumentEmbeddings, + label_dictionary: Dictionary, + label_type: str, + encoding_dim: Optional[int] = None, + alpha: float = 0.9, + mean_update_method: Literal["online", "condensation", "decay"] = "online", + use_encoder: bool = True, + multi_label: bool = False, + multi_label_threshold: float = 0.5, + ): + """Initialize a DeepNCMClassifier. + + Args: + embeddings: Document embeddings to use for encoding text. + label_dictionary: Dictionary containing the label vocabulary. + label_type: The type of label to predict. + encoding_dim: The dimensionality of the encoded embeddings (default is the same as the input embeddings). + alpha: The decay factor for updating class prototypes (default is 0.9). + mean_update_method: The method for updating class prototypes ('online', 'condensation', or 'decay'). 
+ use_encoder: Whether to apply an encoder to the input embeddings (default is True). + multi_label: Whether to predict multiple labels per sentence (default is False). + multi_label_threshold: The threshold for multi-label prediction (default is 0.5). + """ + super().__init__() + + self.embeddings = embeddings + self.label_dictionary = label_dictionary + self._label_type = label_type + self.alpha = alpha + self.mean_update_method = mean_update_method + self.use_encoder = use_encoder + self.multi_label = multi_label + self.multi_label_threshold = multi_label_threshold + self.num_classes = len(label_dictionary) + self.embedding_dim = embeddings.embedding_length + + if use_encoder: + self.encoding_dim = encoding_dim or self.embedding_dim + else: + self.encoding_dim = self.embedding_dim + + self._validate_parameters() + + if self.use_encoder: + self.encoder = torch.nn.Sequential( + torch.nn.Linear(self.embedding_dim, self.encoding_dim * 2), + torch.nn.ReLU(), + torch.nn.Linear(self.encoding_dim * 2, self.encoding_dim), + ) + else: + self.encoder = torch.nn.Sequential(torch.nn.Identity()) + + self.loss_function = ( + torch.nn.BCEWithLogitsLoss(reduction="sum") + if self.multi_label + else torch.nn.CrossEntropyLoss(reduction="sum") + ) + + self.class_prototypes = torch.nn.Parameter( + torch.nn.functional.normalize(torch.randn(self.num_classes, self.encoding_dim)), requires_grad=False + ) + self.class_counts = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=False) + self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) + self.prototype_update_counts = torch.zeros(self.num_classes).to(flair.device) + self.to(flair.device) + + def _validate_parameters(self) -> None: + """Validate the input parameters.""" + assert 0 <= self.alpha <= 1, "alpha must be in the range [0, 1]" + assert self.mean_update_method in [ + "online", + "condensation", + "decay", + ], f"Invalid mean_update_method: {self.mean_update_method}. Must be 'online', 'condensation', or 'decay'" + assert self.encoding_dim > 0, "encoding_dim must be greater than 0" + + def forward(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tensor: + """Encode the input sentences using embeddings and optional encoder. + + Args: + sentences: Input sentence or list of sentences. + + Returns: + torch.Tensor: Encoded representations of the input sentences. + """ + if not isinstance(sentences, list): + sentences = [sentences] + + self.embeddings.embed(sentences) + sentence_embeddings = torch.stack([sentence.get_embedding() for sentence in sentences]) + encoded_embeddings = self.encoder(sentence_embeddings) + + return encoded_embeddings + + def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor: + """Calculate distances between encoded embeddings and class prototypes. + + Args: + encoded_embeddings: Encoded representations of the input sentences. + + Returns: + torch.Tensor: Distances between encoded embeddings and class prototypes. + """ + return torch.cdist(encoded_embeddings, self.class_prototypes) + + def forward_loss(self, data_points: List[Sentence]) -> Tuple[torch.Tensor, int]: + """Compute the loss for a batch of sentences. + + Args: + data_points: A list of sentences. + + Returns: + Tuple[torch.Tensor, int]: The total loss and the number of sentences. 
+ """ + encoded_embeddings = self.forward(data_points) + labels = self._prepare_label_tensor(data_points) + distances = self._calculate_distances(encoded_embeddings) + loss = self.loss_function(-distances, labels) + self._calculate_prototype_updates(encoded_embeddings, labels) + + return loss, len(data_points) + + def _prepare_label_tensor(self, sentences: List[Sentence]) -> torch.Tensor: + """Prepare the label tensor for the given sentences. + + Args: + sentences: A list of sentences. + + Returns: + torch.Tensor: The label tensor for the given sentences. + """ + if self.multi_label: + return torch.tensor( + [ + [ + ( + 1 + if label + in [sentence_label.value for sentence_label in sentence.get_labels(self._label_type)] + else 0 + ) + for label in self.label_dictionary.get_items() + ] + for sentence in sentences + ], + dtype=torch.float, + device=flair.device, + ) + else: + return torch.tensor( + [ + self.label_dictionary.get_idx_for_item(sentence.get_label(self._label_type).value) + for sentence in sentences + ], + dtype=torch.long, + device=flair.device, + ) + + def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: torch.Tensor) -> None: + """Calculate updates for class prototypes based on the current batch. + + Args: + encoded_embeddings: Encoded representations of the input sentences. + labels: True labels for the input sentences. + """ + one_hot = ( + labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_classes).float() + ) + + updates = torch.matmul(one_hot.t(), encoded_embeddings) + counts = one_hot.sum(dim=0) + mask = counts > 0 + self.prototype_updates[mask] += updates[mask] + self.prototype_update_counts[mask] += counts[mask] + + def update_prototypes(self) -> None: + """Apply accumulated updates to class prototypes.""" + with torch.no_grad(): + update_mask = self.prototype_update_counts > 0 + if update_mask.any(): + if self.mean_update_method in ["online", "condensation"]: + new_counts = self.class_counts[update_mask] + self.prototype_update_counts[update_mask] + self.class_prototypes[update_mask] = ( + self.class_counts[update_mask].unsqueeze(1) * self.class_prototypes[update_mask] + + self.prototype_updates[update_mask] + ) / new_counts.unsqueeze(1) + self.class_counts[update_mask] = new_counts + elif self.mean_update_method == "decay": + new_prototypes = self.prototype_updates[update_mask] / self.prototype_update_counts[ + update_mask + ].unsqueeze(1) + self.class_prototypes[update_mask] = ( + self.alpha * self.class_prototypes[update_mask] + (1 - self.alpha) * new_prototypes + ) + self.class_counts[update_mask] += self.prototype_update_counts[update_mask] + + # Reset prototype updates + self.prototype_updates = torch.zeros_like(self.class_prototypes, device=flair.device) + self.prototype_update_counts = torch.zeros(self.num_classes, device=flair.device) + + def predict( + self, + sentences: Union[List[Sentence], Sentence], + mini_batch_size: int = 32, + return_probabilities_for_all_classes: bool = False, + verbose: bool = False, + label_name: Optional[str] = None, + return_loss: bool = False, + embedding_storage_mode: str = "none", + ) -> Union[List[Sentence], Tuple[float, int]]: + """Predict classes for a list of sentences. + + Args: + sentences: A list of sentences or a single sentence. + mini_batch_size: Size of mini batches during prediction. + return_probabilities_for_all_classes: Whether to return probabilities for all classes. + verbose: If True, show progress bar during prediction. 
+ label_name: The name of the label to use for prediction. + return_loss: If True, compute and return loss. + embedding_storage_mode: The mode for storing embeddings ('none', 'cpu', or 'gpu'). + + Returns: + Union[List[Sentence], Tuple[float, int]]: + if return_loss is True, returns a tuple of total loss and total number of sentences; + otherwise, returns the list of sentences with predicted labels. + """ + with torch.no_grad(): + if not isinstance(sentences, list): + sentences = [sentences] + if not sentences: + return sentences + + label_name = label_name or self.label_type + Sentence.set_context_for_sentences(sentences) + + filtered_sentences = [sent for sent in sentences if len(sent) > 0] + reordered_sentences = sorted(filtered_sentences, key=len, reverse=True) + + if len(reordered_sentences) == 0: + return sentences + + dataloader = DataLoader( + dataset=FlairDatapointDataset(reordered_sentences), + batch_size=mini_batch_size, + ) + + if verbose: + progress_bar = tqdm(dataloader) + progress_bar.set_description("Predicting") + dataloader = progress_bar + + total_loss = 0.0 + total_sentences = 0 + + for batch in dataloader: + if not batch: + continue + + encoded_embeddings = self.forward(batch) + distances = self._calculate_distances(encoded_embeddings) + + if self.multi_label: + probabilities = torch.sigmoid(-distances) + else: + probabilities = torch.nn.functional.softmax(-distances, dim=1) + + if return_loss: + labels = self._prepare_label_tensor(batch) + loss = self.loss_function(-distances, labels) + total_loss += loss.item() + total_sentences += len(batch) + + for sentence_index, sentence in enumerate(batch): + sentence.remove_labels(label_name) + + if self.multi_label: + for label_index, probability in enumerate(probabilities[sentence_index]): + if probability > self.multi_label_threshold or return_probabilities_for_all_classes: + label_value = self.label_dictionary.get_item_for_index(label_index) + sentence.add_label(label_name, label_value, probability.item()) + else: + predicted_idx = torch.argmax(probabilities[sentence_index]) + label_value = self.label_dictionary.get_item_for_index(predicted_idx.item()) + sentence.add_label(label_name, label_value, probabilities[sentence_index, predicted_idx].item()) + + if return_probabilities_for_all_classes: + for label_index, probability in enumerate(probabilities[sentence_index]): + label_value = self.label_dictionary.get_item_for_index(label_index) + sentence.add_label(f"{label_name}_all", label_value, probability.item()) + + for sentence in batch: + sentence.clear_embeddings(embedding_storage_mode) + + if return_loss: + return total_loss, total_sentences + return sentences + + def _get_state_dict(self) -> Dict[str, Any]: + """Get the state dictionary of the model. + + Returns: + Dict[str, Any]: The state dictionary containing model parameters and configuration. 
+ """ + model_state = { + "embeddings": self.embeddings.save_embeddings(), + "label_dictionary": self.label_dictionary, + "label_type": self.label_type, + "encoding_dim": self.encoding_dim, + "alpha": self.alpha, + "mean_update_method": self.mean_update_method, + "use_encoder": self.use_encoder, + "multi_label": self.multi_label, + "multi_label_threshold": self.multi_label_threshold, + "class_prototypes": self.class_prototypes.cpu(), + "class_counts": self.class_counts.cpu(), + "encoder": self.encoder.state_dict(), + } + return model_state + + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs) -> "DeepNCMClassifier": + """Initialize the model from a state dictionary. + + Args: + state: The state dictionary containing model parameters and configuration. + **kwargs: Additional keyword arguments for model initialization. + + Returns: + DeepNCMClassifier: An instance of the model initialized with the given state. + """ + embeddings = state["embeddings"] + if isinstance(embeddings, dict): + embeddings = load_embeddings(embeddings) + + model = cls( + embeddings=embeddings, + label_dictionary=state["label_dictionary"], + label_type=state["label_type"], + encoding_dim=state["encoding_dim"], + alpha=state["alpha"], + mean_update_method=state["mean_update_method"], + use_encoder=state["use_encoder"], + multi_label=state.get("multi_label", False), + multi_label_threshold=state.get("multi_label_threshold", 0.5), + **kwargs, + ) + + if "encoder" in state: + model.encoder.load_state_dict(state["encoder"]) + if "class_prototypes" in state: + model.class_prototypes.data = state["class_prototypes"].to(flair.device) + if "class_counts" in state: + model.class_counts.data = state["class_counts"].to(flair.device) + + return model + + def get_prototype(self, class_name: str) -> torch.Tensor: + """Get the prototype vector for a given class name. + + Args: + class_name: The name of the class whose prototype vector is requested. + + Returns: + torch.Tensor: The prototype vector for the given class. + + Raises: + ValueError: If the class name is not found in the label dictionary. + """ + try: + class_idx = self.label_dictionary.get_idx_for_item(class_name) + except IndexError as exc: + raise ValueError(f"Class name '{class_name}' not found in the label dictionary") from exc + + return self.class_prototypes[class_idx].clone() + + def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> List[Tuple[str, float]]: + """Get the top_k closest prototype vectors to the given input vector using the configured distance metric. + + Args: + input_vector (torch.Tensor): The input vector to compare against prototypes. + top_k (int): The number of closest prototypes to return (default is 5). + + Returns: + List[Tuple[str, float]]: Each tuple contains (class_name, distance). 
+ """ + if input_vector.dim() != 1: + raise ValueError("Input vector must be a 1D tensor") + if input_vector.size(0) != self.class_prototypes.size(1): + raise ValueError( + f"Input vector dimension ({input_vector.size(0)}) does not match prototype dimension ({self.class_prototypes.size(1)})" + ) + + input_vector = input_vector.unsqueeze(0) + distances = self._calculate_distances(input_vector) + top_k_values, top_k_indices = torch.topk(distances.squeeze(), k=top_k, largest=False) + + nearest_prototypes = [] + for idx, value in zip(top_k_indices, top_k_values): + class_name = self.label_dictionary.get_item_for_index(idx.item()) + nearest_prototypes.append((class_name, value.item())) + + return nearest_prototypes + + @property + def label_type(self) -> str: + """Get the label type for this classifier.""" + return self._label_type + + def __str__(self) -> str: + """Get a string representation of the model. + + Returns: + str: A string describing the model architecture. + """ + return ( + f"DeepNCMClassifier(\n" + f" (embeddings): {self.embeddings}\n" + f" (encoder): {self.encoder}\n" + f" (prototypes): {self.class_prototypes.shape}\n" + f")" + ) diff --git a/flair/trainers/plugins/__init__.py b/flair/trainers/plugins/__init__.py index 373fdf969b..c3b1c1bab3 100644 --- a/flair/trainers/plugins/__init__.py +++ b/flair/trainers/plugins/__init__.py @@ -1,6 +1,7 @@ from .base import BasePlugin, Pluggable, TrainerPlugin, TrainingInterrupt from .functional.anneal_on_plateau import AnnealingPlugin from .functional.checkpoints import CheckpointPlugin +from .functional.deepncm_trainer_plugin import DeepNCMPlugin from .functional.linear_scheduler import LinearSchedulerPlugin from .functional.reduce_transformer_vocab import ReduceTransformerVocabPlugin from .functional.weight_extractor import WeightExtractorPlugin @@ -15,6 +16,7 @@ "AnnealingPlugin", "CheckpointPlugin", "ClearmlLoggerPlugin", + "DeepNCMPlugin", "LinearSchedulerPlugin", "WeightExtractorPlugin", "LogFilePlugin", diff --git a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py new file mode 100644 index 0000000000..2c4c0ccb49 --- /dev/null +++ b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py @@ -0,0 +1,41 @@ +import torch + +from flair.models import DeepNCMClassifier, MultitaskModel +from flair.trainers.plugins.base import TrainerPlugin + + +class DeepNCMPlugin(TrainerPlugin): + """Plugin for training DeepNCMClassifier. + + Handles both multitask and single-task scenarios. + """ + + def _process_models(self, operation: str): + """Process updates for all DeepNCMClassifier models in the trainer. 
+
+        Args:
+            operation (str): The operation to perform ('condensation' or 'update')
+        """
+        model = self.trainer.model
+
+        models = model.tasks.values() if isinstance(model, MultitaskModel) else [model]
+
+        for sub_model in models:
+            if isinstance(sub_model, DeepNCMClassifier):
+                if operation == "condensation" and sub_model.mean_update_method == "condensation":
+                    sub_model.class_counts.data = torch.ones_like(sub_model.class_counts)
+                elif operation == "update":
+                    sub_model.update_prototypes()
+
+    @TrainerPlugin.hook
+    def after_training_epoch(self, **kwargs):
+        """Reset class counts after each training epoch (only for the 'condensation' update method)."""
+        self._process_models("condensation")
+
+    @TrainerPlugin.hook
+    def after_training_batch(self, **kwargs):
+        """Apply accumulated prototype updates after each training batch."""
+        self._process_models("update")
+
+    def __str__(self) -> str:
+        return "DeepNCMPlugin"
diff --git a/tests/models/test_deepncm_classifier.py b/tests/models/test_deepncm_classifier.py
new file mode 100644
index 0000000000..3b76b6c0b9
--- /dev/null
+++ b/tests/models/test_deepncm_classifier.py
@@ -0,0 +1,167 @@
+import pytest
+import torch
+
+from flair.data import Sentence
+from flair.datasets import ClassificationCorpus
+from flair.embeddings import TransformerDocumentEmbeddings
+from flair.models import DeepNCMClassifier
+from flair.trainers import ModelTrainer
+from flair.trainers.plugins import DeepNCMPlugin
+from tests.model_test_utils import BaseModelTest
+
+
+class TestDeepNCMClassifier(BaseModelTest):
+    model_cls = DeepNCMClassifier
+    train_label_type = "class"
+    multiclass_prediction_labels = ["POSITIVE", "NEGATIVE"]
+    training_args = {
+        "max_epochs": 2,
+        "mini_batch_size": 4,
+        "learning_rate": 1e-5,
+    }
+
+    @pytest.fixture()
+    def embeddings(self):
+        return TransformerDocumentEmbeddings("distilbert-base-uncased", fine_tune=True)
+
+    @pytest.fixture()
+    def corpus(self, tasks_base_path):
+        return ClassificationCorpus(tasks_base_path / "imdb", label_type=self.train_label_type)
+
+    @pytest.fixture()
+    def multiclass_train_test_sentence(self):
+        return Sentence("This movie was great!")
+
+    def build_model(self, embeddings, label_dict, **kwargs):
+        model_args = {
+            "embeddings": embeddings,
+            "label_dictionary": label_dict,
+            "label_type": self.train_label_type,
+            "use_encoder": False,
+            "encoding_dim": 64,
+            "alpha": 0.95,
+        }
+        model_args.update(kwargs)
+        return self.model_cls(**model_args)
+
+    @pytest.mark.integration()
+    def test_train_load_use_classifier(
+        self, results_base_path, corpus, embeddings, example_sentence, train_test_sentence
+    ):
+        label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)
+
+        model = self.build_model(embeddings, label_dict, mean_update_method="condensation")
+
+        trainer = ModelTrainer(model, corpus)
+        trainer.fine_tune(
+            results_base_path, optimizer=torch.optim.AdamW, plugins=[DeepNCMPlugin()], **self.training_args
+        )
+
+        model.predict(train_test_sentence)
+
+        for label in train_test_sentence.get_labels(self.train_label_type):
+            assert label.value is not None
+            assert 0.0 <= label.score <= 1.0
+            assert isinstance(label.score, float)
+
+        del trainer, model, corpus
+
+        loaded_model = self.model_cls.load(results_base_path / "final-model.pt")
+
+        loaded_model.predict(example_sentence)
+        loaded_model.predict([example_sentence, self.empty_sentence])
+        loaded_model.predict([self.empty_sentence])
+
+    def test_get_prototype(self, corpus, embeddings):
+        label_dict = corpus.make_label_dictionary(label_type=self.train_label_type)
+        model = self.build_model(embeddings,
label_dict) + + prototype = model.get_prototype(next(iter(label_dict.get_items()))) + assert isinstance(prototype, torch.Tensor) + assert prototype.shape == (model.encoding_dim,) + + with pytest.raises(ValueError): + model.get_prototype("NON_EXISTENT_CLASS") + + def test_get_closest_prototypes(self, corpus, embeddings): + label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) + model = self.build_model(embeddings, label_dict) + input_vector = torch.randn(model.encoding_dim) + closest_prototypes = model.get_closest_prototypes(input_vector, top_k=2) + + assert len(closest_prototypes) == 2 + assert all(isinstance(item, tuple) and len(item) == 2 for item in closest_prototypes) + + with pytest.raises(ValueError): + model.get_closest_prototypes(torch.randn(model.encoding_dim + 1)) + + def test_forward_loss(self, corpus, embeddings): + label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) + model = self.build_model(embeddings, label_dict) + + sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")] + for sentence, label in zip(sentences, list(label_dict.get_items())[:2]): + sentence.add_label(self.train_label_type, label) + + loss, count = model.forward_loss(sentences) + assert isinstance(loss, torch.Tensor) + assert loss.item() > 0 + assert count == len(sentences) + + @pytest.mark.parametrize("mean_update_method", ["online", "condensation", "decay"]) + def test_mean_update_methods(self, corpus, embeddings, mean_update_method): + label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) + model = self.build_model(embeddings, label_dict, mean_update_method=mean_update_method) + + initial_prototypes = model.class_prototypes.clone() + + sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")] + for sentence, label in zip(sentences, list(label_dict.get_items())[:2]): + sentence.add_label(self.train_label_type, label) + + model.forward_loss(sentences) + model.update_prototypes() + + assert not torch.all(torch.eq(initial_prototypes, model.class_prototypes)) + + @pytest.mark.parametrize("mean_update_method", ["online", "condensation", "decay"]) + def test_deepncm_plugin(self, corpus, embeddings, mean_update_method): + label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) + model = self.build_model(embeddings, label_dict, mean_update_method=mean_update_method) + + trainer = ModelTrainer(model, corpus) + plugin = DeepNCMPlugin() + plugin.attach_to(trainer) + + initial_class_counts = model.class_counts.clone() + initial_prototypes = model.class_prototypes.clone() + + # Simulate training epoch + plugin.after_training_epoch() + + if mean_update_method == "condensation": + assert torch.all(model.class_counts == 1), "Class counts should be 1 for condensation method after epoch" + elif mean_update_method == "online": + assert torch.all( + torch.eq(model.class_counts, initial_class_counts) + ), "Class counts should not change for online method after epoch" + + # Simulate training batch + sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")] + for sentence, label in zip(sentences, list(label_dict.get_items())[:2]): + sentence.add_label(self.train_label_type, label) + model.forward_loss(sentences) + plugin.after_training_batch() + + assert not torch.all( + torch.eq(initial_prototypes, model.class_prototypes) + ), "Prototypes should be updated after a batch" + + if mean_update_method == "condensation": + assert torch.all( 
+ model.class_counts >= 1 + ), "Class counts should be >= 1 for condensation method after a batch" + elif mean_update_method == "online": + assert torch.all( + model.class_counts > initial_class_counts + ), "Class counts should increase for online method after a batch" From 213396c762492d9ffc77e82a234eb0eb1ecebf42 Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 8 Nov 2024 17:45:23 -0800 Subject: [PATCH 2/7] feat: change DeepNCM classifier to a decoder so it can be used with different model types. make small changes to DefaultClassifier forward_loss to pass label tensor when needed. update tests --- flair/models/__init__.py | 4 +- flair/models/deepncm_classification_model.py | 328 +++--------------- flair/nn/model.py | 11 +- .../functional/deepncm_trainer_plugin.py | 13 +- tests/models/test_deepncm_classifier.py | 61 ++-- 5 files changed, 103 insertions(+), 314 deletions(-) diff --git a/flair/models/__init__.py b/flair/models/__init__.py index bf3651078a..d9fca4a706 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -1,4 +1,4 @@ -from .deepncm_classification_model import DeepNCMClassifier +from .deepncm_classification_model import DeepNCMDecoder from .entity_linker_model import SpanClassifier from .entity_mention_linking import EntityMentionLinker from .language_model import LanguageModel @@ -38,5 +38,5 @@ "TextClassifier", "TextRegressor", "MultitaskModel", - "DeepNCMClassifier", + "DeepNCMDecoder", ] diff --git a/flair/models/deepncm_classification_model.py b/flair/models/deepncm_classification_model.py index b942e28919..ec3385a78a 100644 --- a/flair/models/deepncm_classification_model.py +++ b/flair/models/deepncm_classification_model.py @@ -1,20 +1,15 @@ import logging -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Literal, Optional import torch -from tqdm import tqdm import flair -from flair.data import Dictionary, Sentence -from flair.datasets import DataLoader, FlairDatapointDataset -from flair.embeddings import DocumentEmbeddings -from flair.embeddings.base import load_embeddings -from flair.nn import Classifier +from flair.data import Dictionary log = logging.getLogger("flair") -class DeepNCMClassifier(Classifier[Sentence]): +class DeepNCMDecoder(torch.nn.Module): """Deep Nearest Class Mean (DeepNCM) Classifier for text classification tasks. This model combines deep learning with the Nearest Class Mean (NCM) approach. @@ -32,47 +27,50 @@ class DeepNCMClassifier(Classifier[Sentence]): def __init__( self, - embeddings: DocumentEmbeddings, label_dictionary: Dictionary, - label_type: str, + embeddings_size: int, encoding_dim: Optional[int] = None, alpha: float = 0.9, mean_update_method: Literal["online", "condensation", "decay"] = "online", use_encoder: bool = True, - multi_label: bool = False, - multi_label_threshold: float = 0.5, - ): - """Initialize a DeepNCMClassifier. + multi_label: bool = False, # should get from the Model it belongs to + ) -> None: + """Initialize a DeepNCMDecoder. Args: - embeddings: Document embeddings to use for encoding text. - label_dictionary: Dictionary containing the label vocabulary. - label_type: The type of label to predict. encoding_dim: The dimensionality of the encoded embeddings (default is the same as the input embeddings). - alpha: The decay factor for updating class prototypes (default is 0.9). + alpha: The decay factor for updating class prototypes (default is 0.9). This only applies when mean_update_method is 'decay'. 
mean_update_method: The method for updating class prototypes ('online', 'condensation', or 'decay'). use_encoder: Whether to apply an encoder to the input embeddings (default is True). multi_label: Whether to predict multiple labels per sentence (default is False). - multi_label_threshold: The threshold for multi-label prediction (default is 0.5). """ + super().__init__() - self.embeddings = embeddings self.label_dictionary = label_dictionary - self._label_type = label_type + self._num_prototypes = len(label_dictionary) + self.alpha = alpha self.mean_update_method = mean_update_method self.use_encoder = use_encoder self.multi_label = multi_label - self.multi_label_threshold = multi_label_threshold - self.num_classes = len(label_dictionary) - self.embedding_dim = embeddings.embedding_length + + self.embedding_dim = embeddings_size if use_encoder: self.encoding_dim = encoding_dim or self.embedding_dim else: self.encoding_dim = self.embedding_dim + self.class_prototypes = torch.nn.Parameter( + torch.nn.functional.normalize(torch.randn(self._num_prototypes, self.encoding_dim)), requires_grad=False + ) + + self.class_counts = torch.nn.Parameter(torch.zeros(self._num_prototypes), requires_grad=False) + self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) + self.prototype_update_counts = torch.zeros(self._num_prototypes).to(flair.device) + self.to(flair.device) + self._validate_parameters() if self.use_encoder: @@ -84,22 +82,11 @@ def __init__( else: self.encoder = torch.nn.Sequential(torch.nn.Identity()) - self.loss_function = ( - torch.nn.BCEWithLogitsLoss(reduction="sum") - if self.multi_label - else torch.nn.CrossEntropyLoss(reduction="sum") - ) - - self.class_prototypes = torch.nn.Parameter( - torch.nn.functional.normalize(torch.randn(self.num_classes, self.encoding_dim)), requires_grad=False - ) - self.class_counts = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=False) - self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) - self.prototype_update_counts = torch.zeros(self.num_classes).to(flair.device) + # all parameters will be pushed internally to the specified device self.to(flair.device) def _validate_parameters(self) -> None: - """Validate the input parameters.""" + """Validate that the input parameters have valid and compatible values.""" assert 0 <= self.alpha <= 1, "alpha must be in the range [0, 1]" assert self.mean_update_method in [ "online", @@ -108,26 +95,13 @@ def _validate_parameters(self) -> None: ], f"Invalid mean_update_method: {self.mean_update_method}. Must be 'online', 'condensation', or 'decay'" assert self.encoding_dim > 0, "encoding_dim must be greater than 0" - def forward(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tensor: - """Encode the input sentences using embeddings and optional encoder. - - Args: - sentences: Input sentence or list of sentences. - - Returns: - torch.Tensor: Encoded representations of the input sentences. 
- """ - if not isinstance(sentences, list): - sentences = [sentences] - - self.embeddings.embed(sentences) - sentence_embeddings = torch.stack([sentence.get_embedding() for sentence in sentences]) - encoded_embeddings = self.encoder(sentence_embeddings) - - return encoded_embeddings + @property + def num_prototypes(self) -> int: + """The number of class prototypes.""" + return self.class_prototypes.size(0) def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor: - """Calculate distances between encoded embeddings and class prototypes. + """Calculate the squared Euclidean distance between encoded embeddings and class prototypes. Args: encoded_embeddings: Encoded representations of the input sentences. @@ -135,60 +109,7 @@ def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor Returns: torch.Tensor: Distances between encoded embeddings and class prototypes. """ - return torch.cdist(encoded_embeddings, self.class_prototypes) - - def forward_loss(self, data_points: List[Sentence]) -> Tuple[torch.Tensor, int]: - """Compute the loss for a batch of sentences. - - Args: - data_points: A list of sentences. - - Returns: - Tuple[torch.Tensor, int]: The total loss and the number of sentences. - """ - encoded_embeddings = self.forward(data_points) - labels = self._prepare_label_tensor(data_points) - distances = self._calculate_distances(encoded_embeddings) - loss = self.loss_function(-distances, labels) - self._calculate_prototype_updates(encoded_embeddings, labels) - - return loss, len(data_points) - - def _prepare_label_tensor(self, sentences: List[Sentence]) -> torch.Tensor: - """Prepare the label tensor for the given sentences. - - Args: - sentences: A list of sentences. - - Returns: - torch.Tensor: The label tensor for the given sentences. - """ - if self.multi_label: - return torch.tensor( - [ - [ - ( - 1 - if label - in [sentence_label.value for sentence_label in sentence.get_labels(self._label_type)] - else 0 - ) - for label in self.label_dictionary.get_items() - ] - for sentence in sentences - ], - dtype=torch.float, - device=flair.device, - ) - else: - return torch.tensor( - [ - self.label_dictionary.get_idx_for_item(sentence.get_label(self._label_type).value) - for sentence in sentences - ], - dtype=torch.long, - device=flair.device, - ) + return torch.cdist(encoded_embeddings, self.class_prototypes).pow(2) def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: torch.Tensor) -> None: """Calculate updates for class prototypes based on the current batch. @@ -198,7 +119,7 @@ def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: labels: True labels for the input sentences. 
""" one_hot = ( - labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_classes).float() + labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_prototypes).float() ) updates = torch.matmul(one_hot.t(), encoded_embeddings) @@ -230,163 +151,25 @@ def update_prototypes(self) -> None: # Reset prototype updates self.prototype_updates = torch.zeros_like(self.class_prototypes, device=flair.device) - self.prototype_update_counts = torch.zeros(self.num_classes, device=flair.device) - - def predict( - self, - sentences: Union[List[Sentence], Sentence], - mini_batch_size: int = 32, - return_probabilities_for_all_classes: bool = False, - verbose: bool = False, - label_name: Optional[str] = None, - return_loss: bool = False, - embedding_storage_mode: str = "none", - ) -> Union[List[Sentence], Tuple[float, int]]: - """Predict classes for a list of sentences. - - Args: - sentences: A list of sentences or a single sentence. - mini_batch_size: Size of mini batches during prediction. - return_probabilities_for_all_classes: Whether to return probabilities for all classes. - verbose: If True, show progress bar during prediction. - label_name: The name of the label to use for prediction. - return_loss: If True, compute and return loss. - embedding_storage_mode: The mode for storing embeddings ('none', 'cpu', or 'gpu'). - - Returns: - Union[List[Sentence], Tuple[float, int]]: - if return_loss is True, returns a tuple of total loss and total number of sentences; - otherwise, returns the list of sentences with predicted labels. - """ - with torch.no_grad(): - if not isinstance(sentences, list): - sentences = [sentences] - if not sentences: - return sentences - - label_name = label_name or self.label_type - Sentence.set_context_for_sentences(sentences) - - filtered_sentences = [sent for sent in sentences if len(sent) > 0] - reordered_sentences = sorted(filtered_sentences, key=len, reverse=True) - - if len(reordered_sentences) == 0: - return sentences - - dataloader = DataLoader( - dataset=FlairDatapointDataset(reordered_sentences), - batch_size=mini_batch_size, - ) - - if verbose: - progress_bar = tqdm(dataloader) - progress_bar.set_description("Predicting") - dataloader = progress_bar - - total_loss = 0.0 - total_sentences = 0 - - for batch in dataloader: - if not batch: - continue - - encoded_embeddings = self.forward(batch) - distances = self._calculate_distances(encoded_embeddings) - - if self.multi_label: - probabilities = torch.sigmoid(-distances) - else: - probabilities = torch.nn.functional.softmax(-distances, dim=1) + self.prototype_update_counts = torch.zeros(self.num_prototypes, device=flair.device) - if return_loss: - labels = self._prepare_label_tensor(batch) - loss = self.loss_function(-distances, labels) - total_loss += loss.item() - total_sentences += len(batch) + def forward(self, embedded: torch.Tensor, label_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass of the decoder, which calculates the scores as prototype distances. 
- for sentence_index, sentence in enumerate(batch): - sentence.remove_labels(label_name) - - if self.multi_label: - for label_index, probability in enumerate(probabilities[sentence_index]): - if probability > self.multi_label_threshold or return_probabilities_for_all_classes: - label_value = self.label_dictionary.get_item_for_index(label_index) - sentence.add_label(label_name, label_value, probability.item()) - else: - predicted_idx = torch.argmax(probabilities[sentence_index]) - label_value = self.label_dictionary.get_item_for_index(predicted_idx.item()) - sentence.add_label(label_name, label_value, probabilities[sentence_index, predicted_idx].item()) - - if return_probabilities_for_all_classes: - for label_index, probability in enumerate(probabilities[sentence_index]): - label_value = self.label_dictionary.get_item_for_index(label_index) - sentence.add_label(f"{label_name}_all", label_value, probability.item()) - - for sentence in batch: - sentence.clear_embeddings(embedding_storage_mode) - - if return_loss: - return total_loss, total_sentences - return sentences - - def _get_state_dict(self) -> Dict[str, Any]: - """Get the state dictionary of the model. - - Returns: - Dict[str, Any]: The state dictionary containing model parameters and configuration. + :param embedded: Embedded representations of the input sentences. + :param label_tensor: True labels for the input sentences as a tensor. + :return: Scores as a tensor of distances to class prototypes. """ - model_state = { - "embeddings": self.embeddings.save_embeddings(), - "label_dictionary": self.label_dictionary, - "label_type": self.label_type, - "encoding_dim": self.encoding_dim, - "alpha": self.alpha, - "mean_update_method": self.mean_update_method, - "use_encoder": self.use_encoder, - "multi_label": self.multi_label, - "multi_label_threshold": self.multi_label_threshold, - "class_prototypes": self.class_prototypes.cpu(), - "class_counts": self.class_counts.cpu(), - "encoder": self.encoder.state_dict(), - } - return model_state - - @classmethod - def _init_model_with_state_dict(cls, state, **kwargs) -> "DeepNCMClassifier": - """Initialize the model from a state dictionary. + encoded_embeddings = self.encoder(embedded) - Args: - state: The state dictionary containing model parameters and configuration. - **kwargs: Additional keyword arguments for model initialization. + distances = self._calculate_distances(encoded_embeddings) - Returns: - DeepNCMClassifier: An instance of the model initialized with the given state. 
- """ - embeddings = state["embeddings"] - if isinstance(embeddings, dict): - embeddings = load_embeddings(embeddings) - - model = cls( - embeddings=embeddings, - label_dictionary=state["label_dictionary"], - label_type=state["label_type"], - encoding_dim=state["encoding_dim"], - alpha=state["alpha"], - mean_update_method=state["mean_update_method"], - use_encoder=state["use_encoder"], - multi_label=state.get("multi_label", False), - multi_label_threshold=state.get("multi_label_threshold", 0.5), - **kwargs, - ) + if label_tensor is not None: + self._calculate_prototype_updates(encoded_embeddings, label_tensor) - if "encoder" in state: - model.encoder.load_state_dict(state["encoder"]) - if "class_prototypes" in state: - model.class_prototypes.data = state["class_prototypes"].to(flair.device) - if "class_counts" in state: - model.class_counts.data = state["class_counts"].to(flair.device) + scores = -distances - return model + return scores def get_prototype(self, class_name: str) -> torch.Tensor: """Get the prototype vector for a given class name. @@ -407,15 +190,15 @@ def get_prototype(self, class_name: str) -> torch.Tensor: return self.class_prototypes[class_idx].clone() - def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> List[Tuple[str, float]]: - """Get the top_k closest prototype vectors to the given input vector using the configured distance metric. + def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> list[tuple[str, float]]: + """Get the k closest prototype vectors to the given input vector using the configured distance metric. Args: input_vector (torch.Tensor): The input vector to compare against prototypes. top_k (int): The number of closest prototypes to return (default is 5). Returns: - List[Tuple[str, float]]: Each tuple contains (class_name, distance). + list[tuple[str, float]]: Each tuple contains (class_name, distance). """ if input_vector.dim() != 1: raise ValueError("Input vector must be a 1D tensor") @@ -434,22 +217,3 @@ def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> nearest_prototypes.append((class_name, value.item())) return nearest_prototypes - - @property - def label_type(self) -> str: - """Get the label type for this classifier.""" - return self._label_type - - def __str__(self) -> str: - """Get a string representation of the model. - - Returns: - str: A string describing the model architecture. 
- """ - return ( - f"DeepNCMClassifier(\n" - f" (embeddings): {self.embeddings}\n" - f" (encoder): {self.encoder}\n" - f" (prototypes): {self.class_prototypes.shape}\n" - f")" - ) diff --git a/flair/nn/model.py b/flair/nn/model.py index 03834afc76..69c51f7a5e 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import Counter from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, List, Optional, Tuple, Union import torch.nn from torch import Tensor @@ -778,8 +778,11 @@ def forward_loss(self, sentences: list[DT]) -> tuple[torch.Tensor, int]: # pass data points through network to get encoded data point tensor data_point_tensor = self._encode_data_points(sentences, data_points) - # decode - scores = self.decoder(data_point_tensor) + # decode, passing label tensor if needed, such as for prototype updates + if "label_tensor" in inspect.signature(self.decoder.forward).parameters: + scores = self.decoder(data_point_tensor, label_tensor) + else: + scores = self.decoder(data_point_tensor) # an optional masking step (no masking in most cases) scores = self._mask_scores(scores, data_points) @@ -814,7 +817,7 @@ def predict( label_name: Optional[str] = None, return_loss: bool = False, embedding_storage_mode: EmbeddingStorageMode = "none", - ): + ) -> Optional[Union[List[DT], Tuple[float, int]]]: """Predicts the class labels for the given sentences. The labels are directly added to the sentences. Args: diff --git a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py index 2c4c0ccb49..e5394debd2 100644 --- a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py +++ b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py @@ -1,6 +1,7 @@ import torch -from flair.models import DeepNCMClassifier, MultitaskModel +from flair.models import MultitaskModel +from flair.models.deepncm_classification_model import DeepNCMDecoder from flair.trainers.plugins.base import TrainerPlugin @@ -11,7 +12,7 @@ class DeepNCMPlugin(TrainerPlugin): """ def _process_models(self, operation: str): - """Process updates for all DeepNCMClassifier models in the trainer. + """Process updates for all DeepNCMDecoder decoders in the trainer. 
Args: operation (str): The operation to perform ('condensation' or 'update') @@ -21,11 +22,11 @@ def _process_models(self, operation: str): models = model.tasks.values() if isinstance(model, MultitaskModel) else [model] for sub_model in models: - if isinstance(sub_model, DeepNCMClassifier): - if operation == "condensation" and sub_model.mean_update_method == "condensation": - sub_model.class_counts.data = torch.ones_like(sub_model.class_counts) + if hasattr(sub_model, "decoder") and isinstance(sub_model.decoder, DeepNCMDecoder): + if operation == "condensation" and sub_model.decoder.mean_update_method == "condensation": + sub_model.decoder.class_counts.data = torch.ones_like(sub_model.decoder.class_counts) elif operation == "update": - sub_model.update_prototypes() + sub_model.decoder.update_prototypes() @TrainerPlugin.hook def after_training_epoch(self, **kwargs): diff --git a/tests/models/test_deepncm_classifier.py b/tests/models/test_deepncm_classifier.py index 3b76b6c0b9..b587a33142 100644 --- a/tests/models/test_deepncm_classifier.py +++ b/tests/models/test_deepncm_classifier.py @@ -4,14 +4,14 @@ from flair.data import Sentence from flair.datasets import ClassificationCorpus from flair.embeddings import TransformerDocumentEmbeddings -from flair.models import DeepNCMClassifier +from flair.models import DeepNCMDecoder, TextClassifier from flair.trainers import ModelTrainer from flair.trainers.plugins import DeepNCMPlugin from tests.model_test_utils import BaseModelTest -class TestDeepNCMClassifier(BaseModelTest): - model_cls = DeepNCMClassifier +class TestDeepNCMDecoder(BaseModelTest): + model_cls = TextClassifier train_label_type = "class" multiclass_prediction_labels = ["POSITIVE", "NEGATIVE"] training_args = { @@ -33,6 +33,7 @@ def multiclass_train_test_sentence(self): return Sentence("This movie was great!") def build_model(self, embeddings, label_dict, **kwargs): + model_args = { "embeddings": embeddings, "label_dictionary": label_dict, @@ -40,9 +41,27 @@ def build_model(self, embeddings, label_dict, **kwargs): "use_encoder": False, "encoding_dim": 64, "alpha": 0.95, + "mean_update_method": "online", } model_args.update(kwargs) - return self.model_cls(**model_args) + + deepncm_decoder = DeepNCMDecoder( + label_dictionary=model_args["label_dictionary"], + embeddings_size=model_args["embeddings"].embedding_length, + alpha=model_args["alpha"], + encoding_dim=model_args["encoding_dim"], + mean_update_method=model_args["mean_update_method"], + ) + + model = self.model_cls( + embeddings=model_args["embeddings"], + label_dictionary=model_args["label_dictionary"], + label_type=model_args["label_type"], + multi_label=model_args.get("multi_label", False), + decoder=deepncm_decoder, + ) + + return model @pytest.mark.integration() def test_train_load_use_classifier( @@ -76,24 +95,24 @@ def test_get_prototype(self, corpus, embeddings): label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) model = self.build_model(embeddings, label_dict) - prototype = model.get_prototype(next(iter(label_dict.get_items()))) + prototype = model.decoder.get_prototype(next(iter(label_dict.get_items()))) assert isinstance(prototype, torch.Tensor) - assert prototype.shape == (model.encoding_dim,) + assert prototype.shape == (model.decoder.encoding_dim,) with pytest.raises(ValueError): - model.get_prototype("NON_EXISTENT_CLASS") + model.decoder.get_prototype("NON_EXISTENT_CLASS") def test_get_closest_prototypes(self, corpus, embeddings): label_dict = 
corpus.make_label_dictionary(label_type=self.train_label_type) model = self.build_model(embeddings, label_dict) - input_vector = torch.randn(model.encoding_dim) - closest_prototypes = model.get_closest_prototypes(input_vector, top_k=2) + input_vector = torch.randn(model.decoder.encoding_dim) + closest_prototypes = model.decoder.get_closest_prototypes(input_vector, top_k=2) assert len(closest_prototypes) == 2 assert all(isinstance(item, tuple) and len(item) == 2 for item in closest_prototypes) with pytest.raises(ValueError): - model.get_closest_prototypes(torch.randn(model.encoding_dim + 1)) + model.decoder.get_closest_prototypes(torch.randn(model.decoder.encoding_dim + 1)) def test_forward_loss(self, corpus, embeddings): label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) @@ -113,16 +132,16 @@ def test_mean_update_methods(self, corpus, embeddings, mean_update_method): label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) model = self.build_model(embeddings, label_dict, mean_update_method=mean_update_method) - initial_prototypes = model.class_prototypes.clone() + initial_prototypes = model.decoder.class_prototypes.clone() sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")] for sentence, label in zip(sentences, list(label_dict.get_items())[:2]): sentence.add_label(self.train_label_type, label) model.forward_loss(sentences) - model.update_prototypes() + model.decoder.update_prototypes() - assert not torch.all(torch.eq(initial_prototypes, model.class_prototypes)) + assert not torch.all(torch.eq(initial_prototypes, model.decoder.class_prototypes)) @pytest.mark.parametrize("mean_update_method", ["online", "condensation", "decay"]) def test_deepncm_plugin(self, corpus, embeddings, mean_update_method): @@ -133,17 +152,19 @@ def test_deepncm_plugin(self, corpus, embeddings, mean_update_method): plugin = DeepNCMPlugin() plugin.attach_to(trainer) - initial_class_counts = model.class_counts.clone() - initial_prototypes = model.class_prototypes.clone() + initial_class_counts = model.decoder.class_counts.clone() + initial_prototypes = model.decoder.class_prototypes.clone() # Simulate training epoch plugin.after_training_epoch() if mean_update_method == "condensation": - assert torch.all(model.class_counts == 1), "Class counts should be 1 for condensation method after epoch" + assert torch.all( + model.decoder.class_counts == 1 + ), "Class counts should be 1 for condensation method after epoch" elif mean_update_method == "online": assert torch.all( - torch.eq(model.class_counts, initial_class_counts) + torch.eq(model.decoder.class_counts, initial_class_counts) ), "Class counts should not change for online method after epoch" # Simulate training batch @@ -154,14 +175,14 @@ def test_deepncm_plugin(self, corpus, embeddings, mean_update_method): plugin.after_training_batch() assert not torch.all( - torch.eq(initial_prototypes, model.class_prototypes) + torch.eq(initial_prototypes, model.decoder.class_prototypes) ), "Prototypes should be updated after a batch" if mean_update_method == "condensation": assert torch.all( - model.class_counts >= 1 + model.decoder.class_counts >= 1 ), "Class counts should be >= 1 for condensation method after a batch" elif mean_update_method == "online": assert torch.all( - model.class_counts > initial_class_counts + model.decoder.class_counts > initial_class_counts ), "Class counts should increase for online method after a batch" From 649e68dfbfb509b2f38d14aafb215ade283b0364 Mon Sep 17 
00:00:00 2001 From: Matt Buchovecky Date: Wed, 18 Dec 2024 15:19:54 -0500 Subject: [PATCH 3/7] refactor: move DeepNCMDecoder to decoder.py --- flair/models/__init__.py | 2 - flair/models/deepncm_classification_model.py | 208 ----------------- flair/nn/__init__.py | 3 +- flair/nn/decoder.py | 217 +++++++++++++++++- .../functional/deepncm_trainer_plugin.py | 2 +- tests/models/test_deepncm_classifier.py | 3 +- 6 files changed, 221 insertions(+), 214 deletions(-) diff --git a/flair/models/__init__.py b/flair/models/__init__.py index d9fca4a706..e75daf074b 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -1,4 +1,3 @@ -from .deepncm_classification_model import DeepNCMDecoder from .entity_linker_model import SpanClassifier from .entity_mention_linking import EntityMentionLinker from .language_model import LanguageModel @@ -38,5 +37,4 @@ "TextClassifier", "TextRegressor", "MultitaskModel", - "DeepNCMDecoder", ] diff --git a/flair/models/deepncm_classification_model.py b/flair/models/deepncm_classification_model.py index ec3385a78a..be1b5788a0 100644 --- a/flair/models/deepncm_classification_model.py +++ b/flair/models/deepncm_classification_model.py @@ -9,211 +9,3 @@ log = logging.getLogger("flair") -class DeepNCMDecoder(torch.nn.Module): - """Deep Nearest Class Mean (DeepNCM) Classifier for text classification tasks. - - This model combines deep learning with the Nearest Class Mean (NCM) approach. - It uses document embeddings to represent text, optionally applies an encoder, - and classifies based on the nearest class prototype in the embedded space. - - The model supports various methods for updating class prototypes during training, - making it adaptable to different learning scenarios. - - This implementation is based on the research paper: - Guerriero, S., Caputo, B., & Mensink, T. (2018). DeepNCM: Deep Nearest Class Mean Classifiers. - In International Conference on Learning Representations (ICLR) 2018 Workshop. - URL: https://openreview.net/forum?id=rkPLZ4JPM - """ - - def __init__( - self, - label_dictionary: Dictionary, - embeddings_size: int, - encoding_dim: Optional[int] = None, - alpha: float = 0.9, - mean_update_method: Literal["online", "condensation", "decay"] = "online", - use_encoder: bool = True, - multi_label: bool = False, # should get from the Model it belongs to - ) -> None: - """Initialize a DeepNCMDecoder. - - Args: - encoding_dim: The dimensionality of the encoded embeddings (default is the same as the input embeddings). - alpha: The decay factor for updating class prototypes (default is 0.9). This only applies when mean_update_method is 'decay'. - mean_update_method: The method for updating class prototypes ('online', 'condensation', or 'decay'). - use_encoder: Whether to apply an encoder to the input embeddings (default is True). - multi_label: Whether to predict multiple labels per sentence (default is False). 
- """ - - super().__init__() - - self.label_dictionary = label_dictionary - self._num_prototypes = len(label_dictionary) - - self.alpha = alpha - self.mean_update_method = mean_update_method - self.use_encoder = use_encoder - self.multi_label = multi_label - - self.embedding_dim = embeddings_size - - if use_encoder: - self.encoding_dim = encoding_dim or self.embedding_dim - else: - self.encoding_dim = self.embedding_dim - - self.class_prototypes = torch.nn.Parameter( - torch.nn.functional.normalize(torch.randn(self._num_prototypes, self.encoding_dim)), requires_grad=False - ) - - self.class_counts = torch.nn.Parameter(torch.zeros(self._num_prototypes), requires_grad=False) - self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) - self.prototype_update_counts = torch.zeros(self._num_prototypes).to(flair.device) - self.to(flair.device) - - self._validate_parameters() - - if self.use_encoder: - self.encoder = torch.nn.Sequential( - torch.nn.Linear(self.embedding_dim, self.encoding_dim * 2), - torch.nn.ReLU(), - torch.nn.Linear(self.encoding_dim * 2, self.encoding_dim), - ) - else: - self.encoder = torch.nn.Sequential(torch.nn.Identity()) - - # all parameters will be pushed internally to the specified device - self.to(flair.device) - - def _validate_parameters(self) -> None: - """Validate that the input parameters have valid and compatible values.""" - assert 0 <= self.alpha <= 1, "alpha must be in the range [0, 1]" - assert self.mean_update_method in [ - "online", - "condensation", - "decay", - ], f"Invalid mean_update_method: {self.mean_update_method}. Must be 'online', 'condensation', or 'decay'" - assert self.encoding_dim > 0, "encoding_dim must be greater than 0" - - @property - def num_prototypes(self) -> int: - """The number of class prototypes.""" - return self.class_prototypes.size(0) - - def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor: - """Calculate the squared Euclidean distance between encoded embeddings and class prototypes. - - Args: - encoded_embeddings: Encoded representations of the input sentences. - - Returns: - torch.Tensor: Distances between encoded embeddings and class prototypes. - """ - return torch.cdist(encoded_embeddings, self.class_prototypes).pow(2) - - def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: torch.Tensor) -> None: - """Calculate updates for class prototypes based on the current batch. - - Args: - encoded_embeddings: Encoded representations of the input sentences. - labels: True labels for the input sentences. 
- """ - one_hot = ( - labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_prototypes).float() - ) - - updates = torch.matmul(one_hot.t(), encoded_embeddings) - counts = one_hot.sum(dim=0) - mask = counts > 0 - self.prototype_updates[mask] += updates[mask] - self.prototype_update_counts[mask] += counts[mask] - - def update_prototypes(self) -> None: - """Apply accumulated updates to class prototypes.""" - with torch.no_grad(): - update_mask = self.prototype_update_counts > 0 - if update_mask.any(): - if self.mean_update_method in ["online", "condensation"]: - new_counts = self.class_counts[update_mask] + self.prototype_update_counts[update_mask] - self.class_prototypes[update_mask] = ( - self.class_counts[update_mask].unsqueeze(1) * self.class_prototypes[update_mask] - + self.prototype_updates[update_mask] - ) / new_counts.unsqueeze(1) - self.class_counts[update_mask] = new_counts - elif self.mean_update_method == "decay": - new_prototypes = self.prototype_updates[update_mask] / self.prototype_update_counts[ - update_mask - ].unsqueeze(1) - self.class_prototypes[update_mask] = ( - self.alpha * self.class_prototypes[update_mask] + (1 - self.alpha) * new_prototypes - ) - self.class_counts[update_mask] += self.prototype_update_counts[update_mask] - - # Reset prototype updates - self.prototype_updates = torch.zeros_like(self.class_prototypes, device=flair.device) - self.prototype_update_counts = torch.zeros(self.num_prototypes, device=flair.device) - - def forward(self, embedded: torch.Tensor, label_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: - """Forward pass of the decoder, which calculates the scores as prototype distances. - - :param embedded: Embedded representations of the input sentences. - :param label_tensor: True labels for the input sentences as a tensor. - :return: Scores as a tensor of distances to class prototypes. - """ - encoded_embeddings = self.encoder(embedded) - - distances = self._calculate_distances(encoded_embeddings) - - if label_tensor is not None: - self._calculate_prototype_updates(encoded_embeddings, label_tensor) - - scores = -distances - - return scores - - def get_prototype(self, class_name: str) -> torch.Tensor: - """Get the prototype vector for a given class name. - - Args: - class_name: The name of the class whose prototype vector is requested. - - Returns: - torch.Tensor: The prototype vector for the given class. - - Raises: - ValueError: If the class name is not found in the label dictionary. - """ - try: - class_idx = self.label_dictionary.get_idx_for_item(class_name) - except IndexError as exc: - raise ValueError(f"Class name '{class_name}' not found in the label dictionary") from exc - - return self.class_prototypes[class_idx].clone() - - def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> list[tuple[str, float]]: - """Get the k closest prototype vectors to the given input vector using the configured distance metric. - - Args: - input_vector (torch.Tensor): The input vector to compare against prototypes. - top_k (int): The number of closest prototypes to return (default is 5). - - Returns: - list[tuple[str, float]]: Each tuple contains (class_name, distance). 
- """ - if input_vector.dim() != 1: - raise ValueError("Input vector must be a 1D tensor") - if input_vector.size(0) != self.class_prototypes.size(1): - raise ValueError( - f"Input vector dimension ({input_vector.size(0)}) does not match prototype dimension ({self.class_prototypes.size(1)})" - ) - - input_vector = input_vector.unsqueeze(0) - distances = self._calculate_distances(input_vector) - top_k_values, top_k_indices = torch.topk(distances.squeeze(), k=top_k, largest=False) - - nearest_prototypes = [] - for idx, value in zip(top_k_indices, top_k_values): - class_name = self.label_dictionary.get_item_for_index(idx.item()) - nearest_prototypes.append((class_name, value.item())) - - return nearest_prototypes diff --git a/flair/nn/__init__.py b/flair/nn/__init__.py index 1ceae91859..9ced1753c1 100644 --- a/flair/nn/__init__.py +++ b/flair/nn/__init__.py @@ -1,4 +1,4 @@ -from .decoder import LabelVerbalizerDecoder, PrototypicalDecoder +from .decoder import DeepNCMDecoder, LabelVerbalizerDecoder, PrototypicalDecoder from .dropout import LockedDropout, WordDropout from .model import Classifier, DefaultClassifier, Model @@ -9,5 +9,6 @@ "DefaultClassifier", "Model", "PrototypicalDecoder", + "DeepNCMDecoder", "LabelVerbalizerDecoder", ] diff --git a/flair/nn/decoder.py b/flair/nn/decoder.py index 48cdbf39b0..b5fc49ecf0 100644 --- a/flair/nn/decoder.py +++ b/flair/nn/decoder.py @@ -1,5 +1,5 @@ import logging -from typing import Optional +from typing import Literal, Optional import torch @@ -123,6 +123,221 @@ def forward(self, embedded): return scores +class DeepNCMDecoder(torch.nn.Module): + """Deep Nearest Class Mean (DeepNCM) Classifier for text classification tasks. + + This model combines deep learning with the Nearest Class Mean (NCM) approach. + It uses document embeddings to represent text, optionally applies an encoder, + and classifies based on the nearest class prototype in the embedded space. + + The model supports various methods for updating class prototypes during training, + making it adaptable to different learning scenarios. + + This implementation is based on the research paper: + Guerriero, S., Caputo, B., & Mensink, T. (2018). DeepNCM: Deep Nearest Class Mean Classifiers. + In International Conference on Learning Representations (ICLR) 2018 Workshop. + URL: https://openreview.net/forum?id=rkPLZ4JPM + """ + + def __init__( + self, + label_dictionary: Dictionary, + embeddings_size: int, + use_encoder: bool = True, + encoding_dim: Optional[int] = None, + alpha: float = 0.9, + mean_update_method: Literal["online", "condensation", "decay"] = "online", + multi_label: bool = False, # should get from the Model it belongs to + ) -> None: + """Initialize a DeepNCMDecoder. + + Args: + label_dictionary: Label dictionary from the corpus + embeddings_size: The dimensionality of the input embeddings, usually the same as the model embeddings + use_encoder: Whether to apply an encoder to the input embeddings (default is True). + encoding_dim: The dimensionality of the encoded embeddings if an encoder is used (default is the same as the input embeddings). + alpha: The decay factor for updating class prototypes (default is 0.9). This only applies when mean_update_method is 'decay'. + mean_update_method: The method for updating class prototypes ('online', 'condensation', or 'decay'). + online - + condensation - + decay - after every batch, + multi_label: Whether to predict multiple labels per sentence (default is False, and performs multi-class clsasification). 
+ """ + + super().__init__() + + self.label_dictionary = label_dictionary + self._num_prototypes = len(label_dictionary) + + self.alpha = alpha + self.mean_update_method = mean_update_method + self.use_encoder = use_encoder + self.multi_label = multi_label + + self.embedding_dim = embeddings_size + + if use_encoder: + self.encoding_dim = encoding_dim or self.embedding_dim + else: + self.encoding_dim = self.embedding_dim + + self.class_prototypes = torch.nn.Parameter( + torch.nn.functional.normalize(torch.randn(self._num_prototypes, self.encoding_dim)), requires_grad=False + ) + + self.class_counts = torch.nn.Parameter(torch.zeros(self._num_prototypes), requires_grad=False) + self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) + self.prototype_update_counts = torch.zeros(self._num_prototypes).to(flair.device) + self.to(flair.device) + + self._validate_parameters() + + if self.use_encoder: + self.encoder = torch.nn.Sequential( + torch.nn.Linear(self.embedding_dim, self.encoding_dim * 2), + torch.nn.ReLU(), + torch.nn.Linear(self.encoding_dim * 2, self.encoding_dim), + ) + else: + self.encoder = torch.nn.Sequential(torch.nn.Identity()) + + # all parameters will be pushed internally to the specified device + self.to(flair.device) + + def _validate_parameters(self) -> None: + """Validate that the input parameters have valid and compatible values.""" + assert 0 <= self.alpha <= 1, "alpha must be in the range [0, 1]" + assert self.mean_update_method in [ + "online", + "condensation", + "decay", + ], f"Invalid mean_update_method: {self.mean_update_method}. Must be 'online', 'condensation', or 'decay'" + assert self.encoding_dim > 0, "encoding_dim must be greater than 0" + + @property + def num_prototypes(self) -> int: + """The number of class prototypes.""" + return self.class_prototypes.size(0) + + def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor: + """Calculate the squared Euclidean distance between encoded embeddings and class prototypes. + + Args: + encoded_embeddings: Encoded representations of the input sentences. + + Returns: + torch.Tensor: Distances between encoded embeddings and class prototypes. + """ + return torch.cdist(encoded_embeddings, self.class_prototypes).pow(2) + + def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: torch.Tensor) -> None: + """Calculate updates for class prototypes based on the current batch. + + Args: + encoded_embeddings: Encoded representations of the input sentences. + labels: True labels for the input sentences. 
+ """ + one_hot = ( + labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_prototypes).float() + ) + + updates = torch.matmul(one_hot.t(), encoded_embeddings) + counts = one_hot.sum(dim=0) + mask = counts > 0 + self.prototype_updates[mask] += updates[mask] + self.prototype_update_counts[mask] += counts[mask] + + def update_prototypes(self) -> None: + """Apply accumulated updates to class prototypes.""" + with torch.no_grad(): + update_mask = self.prototype_update_counts > 0 + if update_mask.any(): + if self.mean_update_method in ["online", "condensation"]: + new_counts = self.class_counts[update_mask] + self.prototype_update_counts[update_mask] + self.class_prototypes[update_mask] = ( + self.class_counts[update_mask].unsqueeze(1) * self.class_prototypes[update_mask] + + self.prototype_updates[update_mask] + ) / new_counts.unsqueeze(1) + self.class_counts[update_mask] = new_counts + elif self.mean_update_method == "decay": + new_prototypes = self.prototype_updates[update_mask] / self.prototype_update_counts[ + update_mask + ].unsqueeze(1) + self.class_prototypes[update_mask] = ( + self.alpha * self.class_prototypes[update_mask] + (1 - self.alpha) * new_prototypes + ) + self.class_counts[update_mask] += self.prototype_update_counts[update_mask] + + # Reset prototype updates + self.prototype_updates = torch.zeros_like(self.class_prototypes, device=flair.device) + self.prototype_update_counts = torch.zeros(self.num_prototypes, device=flair.device) + + def forward(self, embedded: torch.Tensor, label_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass of the decoder, which calculates the scores as prototype distances. + + :param embedded: Embedded representations of the input sentences. + :param label_tensor: True labels for the input sentences as a tensor. + :return: Scores as a tensor of distances to class prototypes. + """ + encoded_embeddings = self.encoder(embedded) + + distances = self._calculate_distances(encoded_embeddings) + + if label_tensor is not None: + self._calculate_prototype_updates(encoded_embeddings, label_tensor) + + scores = -distances + + return scores + + def get_prototype(self, class_name: str) -> torch.Tensor: + """Get the prototype vector for a given class name. + + Args: + class_name: The name of the class whose prototype vector is requested. + + Returns: + torch.Tensor: The prototype vector for the given class. + + Raises: + ValueError: If the class name is not found in the label dictionary. + """ + try: + class_idx = self.label_dictionary.get_idx_for_item(class_name) + except IndexError as exc: + raise ValueError(f"Class name '{class_name}' not found in the label dictionary") from exc + + return self.class_prototypes[class_idx].clone() + + def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> list[tuple[str, float]]: + """Get the k closest prototype vectors to the given input vector using the configured distance metric. + + Args: + input_vector (torch.Tensor): The input vector to compare against prototypes. + top_k (int): The number of closest prototypes to return (default is 5). + + Returns: + list[tuple[str, float]]: Each tuple contains (class_name, distance). 
+ """ + if input_vector.dim() != 1: + raise ValueError("Input vector must be a 1D tensor") + if input_vector.size(0) != self.class_prototypes.size(1): + raise ValueError( + f"Input vector dimension ({input_vector.size(0)}) does not match prototype dimension ({self.class_prototypes.size(1)})" + ) + + input_vector = input_vector.unsqueeze(0) + distances = self._calculate_distances(input_vector) + top_k_values, top_k_indices = torch.topk(distances.squeeze(), k=top_k, largest=False) + + nearest_prototypes = [] + for idx, value in zip(top_k_indices, top_k_values): + class_name = self.label_dictionary.get_item_for_index(idx.item()) + nearest_prototypes.append((class_name, value.item())) + + return nearest_prototypes + + class LabelVerbalizerDecoder(torch.nn.Module): """A class for decoding labels using the idea of siamese networks / bi-encoders. This can be used for all classification tasks in flair. diff --git a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py index e5394debd2..981d413d61 100644 --- a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py +++ b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py @@ -1,7 +1,7 @@ import torch from flair.models import MultitaskModel -from flair.models.deepncm_classification_model import DeepNCMDecoder +from flair.nn import DeepNCMDecoder from flair.trainers.plugins.base import TrainerPlugin diff --git a/tests/models/test_deepncm_classifier.py b/tests/models/test_deepncm_classifier.py index b587a33142..5324f08fc3 100644 --- a/tests/models/test_deepncm_classifier.py +++ b/tests/models/test_deepncm_classifier.py @@ -4,7 +4,8 @@ from flair.data import Sentence from flair.datasets import ClassificationCorpus from flair.embeddings import TransformerDocumentEmbeddings -from flair.models import DeepNCMDecoder, TextClassifier +from flair.models import TextClassifier +from flair.nn import DeepNCMDecoder from flair.trainers import ModelTrainer from flair.trainers.plugins import DeepNCMPlugin from tests.model_test_utils import BaseModelTest From 46379c9142a52d36c910f6f00bc5066a40f24a31 Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Fri, 3 Jan 2025 14:25:46 +0100 Subject: [PATCH 4/7] Removed deprecated deepncm classifier file --- flair/models/deepncm_classification_model.py | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100644 flair/models/deepncm_classification_model.py diff --git a/flair/models/deepncm_classification_model.py b/flair/models/deepncm_classification_model.py deleted file mode 100644 index be1b5788a0..0000000000 --- a/flair/models/deepncm_classification_model.py +++ /dev/null @@ -1,11 +0,0 @@ -import logging -from typing import Literal, Optional - -import torch - -import flair -from flair.data import Dictionary - -log = logging.getLogger("flair") - - From cab51053956347c269f5e4a9802c848778a31cb9 Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Fri, 3 Jan 2025 14:26:13 +0100 Subject: [PATCH 5/7] Slightly refactored deepncm trainer plugin --- .../functional/deepncm_trainer_plugin.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py index 981d413d61..006396b760 100644 --- a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py +++ b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py @@ -1,3 +1,5 @@ +from collections.abc import Iterable + import torch from flair.models import 
MultitaskModel

@@ -11,32 +13,29 @@ class DeepNCMPlugin(TrainerPlugin):
     Handles both multitask and single-task scenarios.
     """

-    def _process_models(self, operation: str):
-        """Process updates for all DeepNCMDecoder decoders in the trainer.
-
-        Args:
-            operation (str): The operation to perform ('condensation' or 'update')
-        """
+    @property
+    def decoders(self) -> Iterable[DeepNCMDecoder]:
+        """Iterator over all DeepNCMDecoder decoders in the trainer."""
         model = self.trainer.model

         models = model.tasks.values() if isinstance(model, MultitaskModel) else [model]
         for sub_model in models:
             if hasattr(sub_model, "decoder") and isinstance(sub_model.decoder, DeepNCMDecoder):
-                if operation == "condensation" and sub_model.decoder.mean_update_method == "condensation":
-                    sub_model.decoder.class_counts.data = torch.ones_like(sub_model.decoder.class_counts)
-                elif operation == "update":
-                    sub_model.decoder.update_prototypes()
+                yield sub_model.decoder

     @TrainerPlugin.hook
     def after_training_epoch(self, **kwargs):
-        """Update prototypes after each training epoch."""
-        self._process_models("condensation")
+        """Reset class counts after each training epoch."""
+        for decoder in self.decoders:
+            if decoder.mean_update_method == "condensation":
+                decoder.class_counts.data = torch.ones_like(decoder.class_counts)

     @TrainerPlugin.hook
     def after_training_batch(self, **kwargs):
         """Update prototypes after each training batch."""
-        self._process_models("update")
+        for decoder in self.decoders:
+            decoder.update_prototypes()

     def __str__(self) -> str:
         return "DeepNCMPlugin"

From 5ae1508e44dbe19d0dc913240c35db205654cc0b Mon Sep 17 00:00:00 2001
From: Max Ploner
Date: Fri, 3 Jan 2025 14:26:22 +0100
Subject: [PATCH 6/7] Fixed formatting

---
 flair/nn/decoder.py | 1 -
 flair/nn/model.py   | 6 +++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/flair/nn/decoder.py b/flair/nn/decoder.py
index b5fc49ecf0..5499f03fe5 100644
--- a/flair/nn/decoder.py
+++ b/flair/nn/decoder.py
@@ -163,7 +163,6 @@ def __init__(
                 decay - after every batch, prototypes move towards the batch mean as an exponential moving average weighted by alpha.
             multi_label: Whether to predict multiple labels per sentence (default is False, and performs multi-class classification).
         """
-
         super().__init__()

diff --git a/flair/nn/model.py b/flair/nn/model.py
index 69c51f7a5e..cc38c56c5a 100644
--- a/flair/nn/model.py
+++ b/flair/nn/model.py
@@ -5,7 +5,7 @@
 from abc import ABC, abstractmethod
 from collections import Counter
 from pathlib import Path
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Optional, Union

 import torch.nn
 from torch import Tensor

@@ -780,7 +780,7 @@ def forward_loss(self, sentences: list[DT]) -> tuple[torch.Tensor, int]:

         # decode, passing label tensor if needed, such as for prototype updates
         if "label_tensor" in inspect.signature(self.decoder.forward).parameters:
-            scores = self.decoder(data_point_tensor, label_tensor)
+            scores = self.decoder(data_point_tensor, label_tensor=label_tensor)
         else:
             scores = self.decoder(data_point_tensor)

@@ -817,7 +817,7 @@ def predict(
         label_name: Optional[str] = None,
         return_loss: bool = False,
         embedding_storage_mode: EmbeddingStorageMode = "none",
-    ) -> Optional[Union[List[DT], Tuple[float, int]]]:
+    ) -> Optional[Union[list[DT], tuple[float, int]]]:
         """Predicts the class labels for the given sentences. The labels are directly added to the sentences.
Args:

From 8398be2abc194a735140771dfc1373742557ca90 Mon Sep 17 00:00:00 2001
From: Max Ploner
Date: Fri, 3 Jan 2025 15:25:58 +0100
Subject: [PATCH 7/7] Removed predict return types

The specified return types were overly restrictive (e.g. did not include
sequence labelling models)
---
 flair/nn/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flair/nn/model.py b/flair/nn/model.py
index 623907727b..7c10c82ee8 100644
--- a/flair/nn/model.py
+++ b/flair/nn/model.py
@@ -817,7 +817,7 @@ def predict(
         label_name: Optional[str] = None,
         return_loss: bool = False,
         embedding_storage_mode: EmbeddingStorageMode = "none",
-    ) -> Optional[Union[list[DT], tuple[float, int]]]:
+    ):
         """Predicts the class labels for the given sentences. The labels are directly added to the sentences.

        Args: