
Commit

added some things
ASR committed Jan 24, 2024
1 parent 3bddb03 commit 1ed6d46
Showing 5 changed files with 65 additions and 10 deletions.
16 changes: 12 additions & 4 deletions nemo/collections/asr/metrics/wer_bpe.py
@@ -138,7 +138,7 @@ class CTCBPEDecoding(AbstractCTCDecoding):
tokenizer: NeMo tokenizer object, which inherits from TokenizerSpec.
"""

def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None):
def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None,lang=None):

if blank_id is None:
blank_id = tokenizer.tokenizer.vocab_size
@@ -149,13 +149,21 @@ def __init__(self, decoding_cfg, tokenizer: TokenizerSpec, blank_id = None):
# Finalize Beam Search Decoding framework
if isinstance(self.decoding, ctc_beam_decoding.AbstractBeamCTCInfer):
if hasattr(self.tokenizer.tokenizer, 'get_vocab'):
vocab_dict = self.tokenizer.tokenizer.get_vocab()
if isinstance(self.tokenizer.tokenizer, DummyTokenizer): # AggregateTokenizer.DummyTokenizer
if lang is None:
vocab_dict = self.tokenizer.tokenizer.get_vocab()
else:
vocab_dict = self.tokenizer.tokenizers_dict['hi'].tokenizer.get_vocab()
print(vocab_dict)
# breakpoint()
if isinstance(self.tokenizer.tokenizer, DummyTokenizer): # or decoding_cfg.tokenizer_type == "multilingual": # AggregateTokenizer.DummyTokenizer
vocab = vocab_dict
else:
vocab = list(vocab_dict.keys())
self.decoding.set_vocabulary(vocab)
self.decoding.set_tokenizer(tokenizer)
if lang is not None:
self.decoding.set_tokenizer(self.tokenizer.tokenizers_dict['hi'])
else:
self.decoding.set_tokenizer(self.tokenizer)
else:
logging.warning("Could not resolve the vocabulary of the tokenizer !")

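For context: the new lang argument lets CTCBPEDecoding pull the vocabulary and tokenizer of one language out of an aggregate tokenizer before wiring up beam-search decoding (the hunk above currently hardcodes the 'hi' sub-tokenizer). A minimal, hypothetical sketch of that selection logic; resolve_lang_tokenizer is an illustrative helper, not part of the commit or of the NeMo API:

from typing import Optional

def resolve_lang_tokenizer(aggregate_tokenizer, lang: Optional[str] = None):
    """Return (tokenizer, vocab) for lang, falling back to the combined vocabulary."""
    sub_tokenizers = getattr(aggregate_tokenizer, "tokenizers_dict", {})
    if lang is None or lang not in sub_tokenizers:
        # No language requested (or not an aggregate tokenizer): use the combined vocab.
        tokenizer = aggregate_tokenizer
        vocab_dict = aggregate_tokenizer.tokenizer.get_vocab()
    else:
        # Pick the monolingual sub-tokenizer registered under lang.
        tokenizer = sub_tokenizers[lang]
        vocab_dict = tokenizer.tokenizer.get_vocab()
    # Beam-search decoders expect a token list rather than a token-to-id mapping.
    vocab = list(vocab_dict.keys()) if isinstance(vocab_dict, dict) else list(vocab_dict)
    return tokenizer, vocab

In the hunk above the same idea is inlined in __init__, with the sub-tokenizer key fixed to 'hi' regardless of the lang value.
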
11 changes: 8 additions & 3 deletions nemo/collections/asr/models/ctc_bpe_models.py
@@ -91,7 +91,12 @@ def __init__(self, cfg: DictConfig, trainer=None):
with open_dict(self.cfg):
self.cfg.decoding = decoding_cfg
if (self.tokenizer_type == "agg" or self.tokenizer_type == "multilingual") and "multisoftmax" in cfg.decoder:
self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
if decoding_cfg.strategy == 'pyctcdecode':
# create separate decoders for each language
# self.decoding = [CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()),lang=l) for l in self.tokenizer.tokenizers_dict.keys()]
self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()),lang='any')
else:
self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer, blank_id=self.decoder._num_classes//len(self.tokenizer.tokenizers_dict.keys()))
else:
self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer)

@@ -140,7 +145,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
drop_last=config.get('drop_last', False),
shuffle=config['shuffle'],
num_workers=config.get('num_workers', 0),
pin_memory=config.get('pin_memory', False),
pin_memory=config.get('pin_memory', False)
)
else:
return torch.utils.data.DataLoader(
@@ -149,7 +154,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
collate_fn=collate_fn,
drop_last=config.get('drop_last', False),
num_workers=config.get('num_workers', 0),
pin_memory=config.get('pin_memory', False),
pin_memory=config.get('pin_memory', False)
)

def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
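For reference, the blank_id passed in the first hunk above splits the decoder's total output count evenly across the languages of the aggregate tokenizer, so the value is presumably the size of one per-language output block. A tiny self-contained illustration with made-up numbers (nothing below comes from the commit or from NeMo):

# Illustrative values only; in the model they come from decoder._num_classes
# and tokenizer.tokenizers_dict.
num_classes = 2056                 # total decoder outputs over all language blocks (assumed)
languages = ["hi", "en"]           # sub-tokenizer keys of the aggregate tokenizer (assumed)

blank_id = num_classes // len(languages)   # per-language block size, used as the blank index
print(blank_id)                            # 1028
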
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/ctc_bpe_multisoftmax_models.py
@@ -50,6 +50,8 @@ def __init__(self, cfg: DictConfig, trainer=None):

# Initialize a dummy vocabulary
vocabulary = self.tokenizer.tokenizer.get_vocab()
print(vocabulary)
breakpoint()

# Set the new vocabulary
with open_dict(cfg):
1 change: 1 addition & 0 deletions nemo/collections/common/parts/preprocessing/collections.py
@@ -161,6 +161,7 @@ def __init__(
if lang is not None:
text_tokens = parser(text, lang)
else:
print(audio_file)
raise ValueError("lang required in manifest when using aggregate tokenizers")
else:
text_tokens = parser(text)
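The ValueError above fires when a manifest entry has no language tag while an aggregate tokenizer is in use; the added print(audio_file) simply surfaces which entry triggered it. As a rough illustration (the field values are invented and the exact schema should be checked against the NeMo manifest documentation), an entry that satisfies the check would carry a lang key next to the usual fields:

import json

# Hypothetical manifest entry; path, duration and text are made up.
entry = {
    "audio_filepath": "/data/hi/utt_0001.wav",
    "duration": 3.2,
    "text": "नमस्ते दुनिया",
    "lang": "hi",  # the key the parser branch above requires for aggregate tokenizers
}
print(json.dumps(entry, ensure_ascii=False))  # one JSON object per manifest line
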
45 changes: 42 additions & 3 deletions nemo/utils/exp_manager.py
@@ -50,6 +50,40 @@
from nemo.utils.model_utils import uninject_model_parallel_rank


class MinStepsCallback(EarlyStopping):
def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 3,
verbose: bool = False, mode: str = 'auto', strict: bool = True,
min_steps: int = 5000, check_finite: bool = True, stopping_threshold: Optional[float] = None,
divergence_threshold: Optional[float] = None,check_on_train_epoch_end: Optional[bool] = None,
log_rank_zero_only: bool = False
):
self.min_steps = min_steps
super().__init__(monitor=monitor, min_delta=min_delta, patience=patience,
verbose=verbose, mode=mode, strict=strict, check_finite=check_finite,
stopping_threshold=stopping_threshold,divergence_threshold=divergence_threshold,
check_on_train_epoch_end=check_on_train_epoch_end,log_rank_zero_only=log_rank_zero_only)

def _run_early_stopping_check(self, trainer: pytorch_lightning.Trainer) -> None:
if trainer.global_step > self.min_steps:
return super()._run_early_stopping_check(trainer)
else:
return False, f"Yet to reach the minimum steps {trainer.global_step}"

@dataclass
class MinStepsCallbackParams:
monitor: str = "val_loss" # The metric that early stopping should consider.
mode: str = "min" # inform early stopping whether to look for increase or decrease in monitor.
min_delta: float = 0.001 # smallest change to consider as improvement.
patience: int = 10 # how many (continuous) validation cycles to wait with no improvement and stopping training.
verbose: bool = True
strict: bool = True
check_finite: bool = True
stopping_threshold: Optional[float] = None
divergence_threshold: Optional[float] = None
check_on_train_epoch_end: Optional[bool] = None
log_rank_zero_only: bool = False
min_steps: int = 5000

class NotFoundError(NeMoBaseException):
""" Raised when a file or folder is not found"""

@@ -170,6 +204,8 @@ class ExpManagerConfig:
ema: Optional[EMAParams] = EMAParams()
# Wall clock time limit
max_time_per_run: Optional[str] = None
early_stopping_with_min_steps: Optional[bool] = False
early_stopping_with_min_steps_params: Optional[MinStepsCallbackParams] = MinStepsCallbackParams()


class TimingCallback(Callback):
@@ -436,10 +472,13 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
every_n_steps=cfg.ema.every_n_steps,
)
trainer.callbacks.append(ema_callback)
if cfg.early_stopping_with_min_steps:
min_steps_cb = MinStepsCallback(**cfg.early_stopping_with_min_steps_params)
trainer.callbacks.append(min_steps_cb)

if cfg.create_early_stopping_callback:
early_stop_callback = EarlyStopping(**cfg.early_stopping_callback_params)
trainer.callbacks.append(early_stop_callback)
# if cfg.create_early_stopping_callback:
# early_stop_callback = EarlyStopping(**cfg.early_stopping_callback_params)
# trainer.callbacks.append(early_stop_callback)

if cfg.create_checkpoint_callback:
configure_checkpointing(
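Putting the new pieces together: MinStepsCallback defers any early-stopping decision until trainer.global_step exceeds min_steps, and the two new ExpManagerConfig fields let it be switched on from the experiment config (the last hunk also comments out the stock EarlyStopping wiring). A hypothetical usage sketch; the keys mirror the dataclasses added above, the rest is illustrative and only the relevant part of the exp_manager config is shown:

from omegaconf import OmegaConf

exp_manager_cfg = OmegaConf.create(
    {
        "early_stopping_with_min_steps": True,
        "early_stopping_with_min_steps_params": {
            "monitor": "val_loss",
            "mode": "min",
            "min_delta": 0.001,
            "patience": 10,
            "min_steps": 5000,  # no early stopping before 5000 global steps
        },
    }
)
# exp_manager(trainer, exp_manager_cfg) would then append a MinStepsCallback to
# trainer.callbacks in place of the commented-out EarlyStopping callback.
print(OmegaConf.to_yaml(exp_manager_cfg))
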
