From 1ed6d46eb05d0994fb894fd5439d4c04a4c4a775 Mon Sep 17 00:00:00 2001
From: ASR
Date: Wed, 24 Jan 2024 10:44:35 +0530
Subject: [PATCH] Add language-aware CTC beam-search decoding for aggregate
 tokenizers and a min-steps early-stopping callback

---
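Notes:

- CTCBPEDecoding (wer_bpe.py) gains an optional lang argument. When it is set
  and a beam-search decoder is in use, the decoding vocabulary and tokenizer
  are taken from the aggregate tokenizer's tokenizers_dict (currently
  hard-wired to the 'hi' entry) rather than from the aggregate tokenizer
  itself. The CTC BPE model passes lang='any' when the decoding strategy is
  'pyctcdecode' and an aggregate/multilingual tokenizer is combined with a
  multisoftmax decoder.
- exp_manager.py adds MinStepsCallback, an EarlyStopping subclass that skips
  the early-stopping check until trainer.global_step exceeds min_steps, along
  with MinStepsCallbackParams and the new ExpManagerConfig fields
  early_stopping_with_min_steps and early_stopping_with_min_steps_params. The
  existing create_early_stopping_callback branch is commented out, so the new
  path replaces it.

The snippet below is a usage sketch only and is not part of the patch; it
shows one way the new exp_manager fields could be exercised from a training
script. The experiment directory, run name, and parameter values are
illustrative.

    import pytorch_lightning as pl
    from omegaconf import OmegaConf

    from nemo.utils.exp_manager import exp_manager

    # Let exp_manager own logging and checkpointing, as the standard NeMo
    # training scripts do.
    trainer = pl.Trainer(max_steps=100000, logger=False, enable_checkpointing=False)

    exp_cfg = OmegaConf.create(
        {
            "exp_dir": "/tmp/asr_experiments",      # illustrative
            "name": "ctc_bpe_multilingual",         # illustrative
            "early_stopping_with_min_steps": True,  # new flag from this patch
            "early_stopping_with_min_steps_params": {
                "monitor": "val_loss",
                "mode": "min",
                "patience": 10,
                # MinStepsCallback only starts running the early-stopping
                # check once trainer.global_step exceeds this value.
                "min_steps": 5000,
            },
        }
    )
    # exp_manager merges this with the ExpManagerConfig defaults and appends
    # MinStepsCallback(**early_stopping_with_min_steps_params) to trainer.callbacks.
    exp_manager(trainer, exp_cfg)
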
 nemo/collections/asr/metrics/wer_bpe.py       | 16 +++++--
 nemo/collections/asr/models/ctc_bpe_models.py | 11 +++--
 .../asr/models/ctc_bpe_multisoftmax_models.py |  2 +
 .../common/parts/preprocessing/collections.py |  1 +
 nemo/utils/exp_manager.py                     | 45 +++++++++++++++++--
 5 files changed, 65 insertions(+), 10 deletions(-)

diff --git a/nemo/collections/asr/metrics/wer_bpe.py b/nemo/collections/asr/metrics/wer_bpe.py
index 3e3ee1923..93ad0de75 100644
--- a/nemo/collections/asr/metrics/wer_bpe.py
+++ b/nemo/collections/asr/metrics/wer_bpe.py
@@ -138,7 +138,7 @@ class CTCBPEDecoding(AbstractCTCDecoding):
         tokenizer: NeMo tokenizer object, which inherits from TokenizerSpec.
     """
 
-    def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None):
+    def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None, lang=None):
         if blank_id is None:
             blank_id = tokenizer.tokenizer.vocab_size
 
@@ -149,13 +149,21 @@ def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None):
         # Finalize Beam Search Decoding framework
         if isinstance(self.decoding, ctc_beam_decoding.AbstractBeamCTCInfer):
             if hasattr(self.tokenizer.tokenizer, 'get_vocab'):
-                vocab_dict = self.tokenizer.tokenizer.get_vocab()
-                if isinstance(self.tokenizer.tokenizer, DummyTokenizer):  # AggregateTokenizer.DummyTokenizer
+                if lang is None:
+                    vocab_dict = self.tokenizer.tokenizer.get_vocab()
+                else:
+                    vocab_dict = self.tokenizer.tokenizers_dict['hi'].tokenizer.get_vocab()
+                print(vocab_dict)
+                # breakpoint()
+                if isinstance(self.tokenizer.tokenizer, DummyTokenizer):  # or decoding_cfg.tokenizer_type == "multilingual": # AggregateTokenizer.DummyTokenizer
                     vocab = vocab_dict
                 else:
                     vocab = list(vocab_dict.keys())
                 self.decoding.set_vocabulary(vocab)
-                self.decoding.set_tokenizer(tokenizer)
+                if lang is not None:
+                    self.decoding.set_tokenizer(self.tokenizer.tokenizers_dict['hi'])
+                else:
+                    self.decoding.set_tokenizer(self.tokenizer)
             else:
                 logging.warning("Could not resolve the vocabulary of the tokenizer !")
 
diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py
index 14614a1b7..a9577f2bd 100644
--- a/nemo/collections/asr/models/ctc_bpe_models.py
+++ b/nemo/collections/asr/models/ctc_bpe_models.py
@@ -91,7 +91,12 @@ def __init__(self, cfg: DictConfig, trainer=None):
         with open_dict(self.cfg):
             self.cfg.decoding = decoding_cfg
         if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in cfg.decoder:
-            self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
+            if decoding_cfg.strategy == 'pyctcdecode':
+                # create separate decoders for each language
+                # self.decoding = [CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()), lang=l) for l in self.tokenizer.tokenizers_dict.keys()]
+                self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()), lang='any')
+            else:
+                self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
         else:
             self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer)
 
@@ -140,7 +145,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
                 drop_last=config.get('drop_last', False),
                 shuffle=config['shuffle'],
                 num_workers=config.get('num_workers', 0),
-                pin_memory=config.get('pin_memory', False),
+                pin_memory=config.get('pin_memory', False)
             )
         else:
             return torch.utils.data.DataLoader(
@@ -149,7 +154,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
                 collate_fn=collate_fn,
                 drop_last=config.get('drop_last', False),
                 num_workers=config.get('num_workers', 0),
-                pin_memory=config.get('pin_memory', False),
+                pin_memory=config.get('pin_memory', False)
             )
 
     def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
diff --git a/nemo/collections/asr/models/ctc_bpe_multisoftmax_models.py b/nemo/collections/asr/models/ctc_bpe_multisoftmax_models.py
index 6cd608c42..bc53ad7fd 100644
--- a/nemo/collections/asr/models/ctc_bpe_multisoftmax_models.py
+++ b/nemo/collections/asr/models/ctc_bpe_multisoftmax_models.py
@@ -50,6 +50,8 @@ def __init__(self, cfg: DictConfig, trainer=None):
 
         # Initialize a dummy vocabulary
         vocabulary = self.tokenizer.tokenizer.get_vocab()
+        print(vocabulary)
+        breakpoint()
 
         # Set the new vocabulary
         with open_dict(cfg):
diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py
index 5c3c35990..3abb14c16 100644
--- a/nemo/collections/common/parts/preprocessing/collections.py
+++ b/nemo/collections/common/parts/preprocessing/collections.py
@@ -161,6 +161,7 @@ def __init__(
                 if lang is not None:
                     text_tokens = parser(text, lang)
                 else:
+                    print(audio_file)
                     raise ValueError("lang required in manifest when using aggregate tokenizers")
             else:
                 text_tokens = parser(text)
diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py
index af9610da3..6e0a08653 100644
--- a/nemo/utils/exp_manager.py
+++ b/nemo/utils/exp_manager.py
@@ -50,6 +50,40 @@
 from nemo.utils.model_utils import uninject_model_parallel_rank
 
 
+class MinStepsCallback(EarlyStopping):
+    def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 3,
+                 verbose: bool = False, mode: str = 'auto', strict: bool = True,
+                 min_steps: int = 5000, check_finite: bool = True, stopping_threshold: Optional[float] = None,
+                 divergence_threshold: Optional[float] = None, check_on_train_epoch_end: Optional[bool] = None,
+                 log_rank_zero_only: bool = False
+                 ):
+        self.min_steps = min_steps
+        super().__init__(monitor=monitor, min_delta=min_delta, patience=patience,
+                         verbose=verbose, mode=mode, strict=strict, check_finite=check_finite,
+                         stopping_threshold=stopping_threshold, divergence_threshold=divergence_threshold,
+                         check_on_train_epoch_end=check_on_train_epoch_end, log_rank_zero_only=log_rank_zero_only)
+
+    def _run_early_stopping_check(self, trainer: pytorch_lightning.Trainer) -> None:
+        if trainer.global_step > self.min_steps:
+            return super()._run_early_stopping_check(trainer)
+        else:
+            return False, f"Yet to reach the minimum steps {trainer.global_step}"
+
+@dataclass
+class MinStepsCallbackParams:
+    monitor: str = "val_loss"  # The metric that early stopping should consider.
+    mode: str = "min"  # inform early stopping whether to look for an increase or decrease in the monitored metric.
+    min_delta: float = 0.001  # smallest change to consider as an improvement.
+    patience: int = 10  # how many (continuous) validation cycles to wait with no improvement before stopping training.
+    verbose: bool = True
+    strict: bool = True
+    check_finite: bool = True
+    stopping_threshold: Optional[float] = None
+    divergence_threshold: Optional[float] = None
+    check_on_train_epoch_end: Optional[bool] = None
+    log_rank_zero_only: bool = False
+    min_steps: int = 5000
+
 class NotFoundError(NeMoBaseException):
     """ Raised when a file or folder is not found"""
 
@@ -170,6 +204,8 @@ class ExpManagerConfig:
     ema: Optional[EMAParams] = EMAParams()
     # Wall clock time limit
     max_time_per_run: Optional[str] = None
+    early_stopping_with_min_steps: Optional[bool] = False
+    early_stopping_with_min_steps_params: Optional[MinStepsCallbackParams] = MinStepsCallbackParams()
 
 
 class TimingCallback(Callback):
@@ -436,10 +472,13 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
             every_n_steps=cfg.ema.every_n_steps,
         )
         trainer.callbacks.append(ema_callback)
+    if cfg.early_stopping_with_min_steps:
+        min_steps_cb = MinStepsCallback(**cfg.early_stopping_with_min_steps_params)
+        trainer.callbacks.append(min_steps_cb)
 
-    if cfg.create_early_stopping_callback:
-        early_stop_callback = EarlyStopping(**cfg.early_stopping_callback_params)
-        trainer.callbacks.append(early_stop_callback)
+    # if cfg.create_early_stopping_callback:
+    #     early_stop_callback = EarlyStopping(**cfg.early_stopping_callback_params)
+    #     trainer.callbacks.append(early_stop_callback)
 
     if cfg.create_checkpoint_callback:
         configure_checkpointing(