From 17df903ef1724916ec56dcd9b9f188a7ccc9e07e Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Fri, 10 Jan 2025 05:32:00 -0500 Subject: [PATCH 1/2] add new functionality to BPE lexicon creation --- lexicon/bpe.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/lexicon/bpe.py b/lexicon/bpe.py index 563f0020..75af820e 100644 --- a/lexicon/bpe.py +++ b/lexicon/bpe.py @@ -1,6 +1,7 @@ __all__ = ["CreateBPELexiconJob"] import subprocess as sp +import logging import os import sys from typing import List, Optional, Set, Union @@ -19,6 +20,8 @@ class CreateBPELexiconJob(Job): This job is still in experimental state, and only tested with Flashlight BPE decoding """ + __sis_hash_exclude__ = {"skip_unk_lemmas": False, "add_all_bpe_phonemes": True, "additional_words": None} + def __init__( self, base_lexicon_path: tk.Path, @@ -28,6 +31,9 @@ def __init__( unk_label: str = "UNK", vocab_blacklist: Optional[Union[List[str], Set[str]]] = None, keep_special_lemmas: bool = True, + skip_unk_lemmas: bool = False, + add_all_bpe_phonemes: bool = True, + additional_words: Optional[tk.Path] = None, ): """ :param base_lexicon_path: base lexicon (can be phoneme based) to take the lemmas from @@ -41,6 +47,12 @@ def __init__( usually yes for RASR search and no for Flashlight search. The phonemes of the special lemmas will also be kept, therefore make sure there is no overlap with the BPE vocab. + :param skip_unk_lemmas: whether simply skip lemmas out of the BPE vocab + useful if you set vocab_blacklist + :param add_all_bpe_phonemes: If set to True, all BPE vocab will be added to lexicon phonemes, + otherwise, only phonemes appear in lexicon lemma will be added to the lexicon. + :param additional_words: Aside from vocab specified in base_lexicon, we might want to convert some other words, + e.g. untranslatable words by a g2p model in case of g2p-augmented lexicon """ self.base_lexicon_path = base_lexicon_path self.bpe_codes = bpe_codes @@ -53,6 +65,9 @@ def __init__( # convert list to set for faster "in" check self.vocab_blacklist = set(vocab_blacklist) self.keep_special_lemmas = keep_special_lemmas + self.skip_unk_lemmas = skip_unk_lemmas + self.add_all_bpe_phonemes = add_all_bpe_phonemes + self.additional_words = additional_words self.out_lexicon = self.output_path("lexicon.xml.gz", cached=True) @@ -85,19 +100,32 @@ def _fill_vocab_and_lexicon(self): vocab_file.write(symbol + " -1\n") symbol = symbol.replace(".", "_") vocab.add(symbol) - lexicon.add_phoneme(symbol.replace(".", "_")) + if self.add_all_bpe_phonemes: + lexicon.add_phoneme(symbol.replace(".", "_")) return vocab, lexicon + def _fill_additional_words(self): + additional_words_list = set() + if self.additional_words is not None: + with util.uopen(self.additional_words.get_path(), "rt") as f: + for line in f: + line = line.strip() + additional_words_list.update(line) + return sorted(additional_words_list) + def run(self): base_lexicon = Lexicon() base_lexicon.load(self.base_lexicon_path) lm_tokens, special_lemmas = self._fill_lm_tokens(base_lexicon) + additional_words_list = self._fill_additional_words() with util.uopen("words", "wt") as f: for t in lm_tokens: f.write(f"{t}\n") + for t in additional_words_list: + f.write(f"{t}\n") vocab, lexicon = self._fill_vocab_and_lexicon() @@ -127,15 +155,32 @@ def run(self): with util.uopen("bpes", "rt") as bpe_file: bpe_tokens = [line.strip() for line in bpe_file] - w2b = {w: b for w, b in zip(lm_tokens, bpe_tokens)} + w2b = {w: b for w, b in zip(lm_tokens + additional_words_list, bpe_tokens)} + used_vocab = set() for lemma in base_lexicon.lemmata: if lemma.special: continue for orth in lemma.orth: bpe_pron = " ".join([token if token in vocab else self.unk_label for token in w2b[orth].split()]) + if self.skip_unk_lemmas and self.unk_label in bpe_pron.split(): + logging.info(f"Lemma {orth} is skipped due to unknown BPE vocab.") + continue + used_vocab.update(set(bpe_pron.split())) lexicon.add_lemma(Lemma([orth], [bpe_pron.replace(".", "_")], lemma.synt, lemma.eval)) + for word in additional_words_list: + bpe_pron = " ".join([token if token in vocab else self.unk_label for token in w2b[word].split()]) + if self.skip_unk_lemmas and self.unk_label in bpe_pron.split(): + logging.info(f"Lemma {word} is skipped due to unknown BPE vocab.") + continue + used_vocab.update(set(bpe_pron.split())) + lexicon.add_lemma(Lemma([word], [bpe_pron.replace(".", "_")])) + + if not self.add_all_bpe_phonemes: + for symbol in sorted(used_vocab): + lexicon.add_phoneme(symbol.replace(".", "_")) + elem = lexicon.to_xml() tree = ET.ElementTree(elem) util.write_xml(self.out_lexicon.get_path(), tree) From 8a0d00ed56f2cfc1e7c8393cbd38b18dc15ca870 Mon Sep 17 00:00:00 2001 From: Haotian Wu Date: Tue, 21 Jan 2025 11:20:32 -0500 Subject: [PATCH 2/2] sharing more code suggested by Albert --- lexicon/bpe.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/lexicon/bpe.py b/lexicon/bpe.py index 75af820e..03ba4f42 100644 --- a/lexicon/bpe.py +++ b/lexicon/bpe.py @@ -111,21 +111,22 @@ def _fill_additional_words(self): with util.uopen(self.additional_words.get_path(), "rt") as f: for line in f: line = line.strip() - additional_words_list.update(line) + additional_words_list.add(line) return sorted(additional_words_list) def run(self): base_lexicon = Lexicon() base_lexicon.load(self.base_lexicon_path) - lm_tokens, special_lemmas = self._fill_lm_tokens(base_lexicon) additional_words_list = self._fill_additional_words() + print(additional_words_list) + for w in additional_words_list: + base_lexicon.add_lemma(Lemma([w], None)) # add empty lemmata with only orth for additional words + lm_tokens, special_lemmas = self._fill_lm_tokens(base_lexicon) with util.uopen("words", "wt") as f: for t in lm_tokens: f.write(f"{t}\n") - for t in additional_words_list: - f.write(f"{t}\n") vocab, lexicon = self._fill_vocab_and_lexicon() @@ -155,7 +156,7 @@ def run(self): with util.uopen("bpes", "rt") as bpe_file: bpe_tokens = [line.strip() for line in bpe_file] - w2b = {w: b for w, b in zip(lm_tokens + additional_words_list, bpe_tokens)} + w2b = {w: b for w, b in zip(lm_tokens, bpe_tokens)} used_vocab = set() for lemma in base_lexicon.lemmata: @@ -169,14 +170,6 @@ def run(self): used_vocab.update(set(bpe_pron.split())) lexicon.add_lemma(Lemma([orth], [bpe_pron.replace(".", "_")], lemma.synt, lemma.eval)) - for word in additional_words_list: - bpe_pron = " ".join([token if token in vocab else self.unk_label for token in w2b[word].split()]) - if self.skip_unk_lemmas and self.unk_label in bpe_pron.split(): - logging.info(f"Lemma {word} is skipped due to unknown BPE vocab.") - continue - used_vocab.update(set(bpe_pron.split())) - lexicon.add_lemma(Lemma([word], [bpe_pron.replace(".", "_")])) - if not self.add_all_bpe_phonemes: for symbol in sorted(used_vocab): lexicon.add_phoneme(symbol.replace(".", "_"))