Update text_augment.py
huu4ontocord authored Mar 12, 2022
1 parent d2d06b8 commit a8b91a3
Showing 1 changed file with 1 addition and 133 deletions.
134 changes: 1 addition & 133 deletions text_augment.py
@@ -399,130 +399,6 @@ def deserialize_doc(doc):
return [deserialize_doc(doc) for doc in docs]
return docs

@staticmethod
def get_lang_groups(src_lang):
    """ we use langid because it's pretty fast but it has difficulties in low resource languages
    langid can sometimes mistake languages that are in the same group. that is ok for our purpose as
    we mainly use the langid check to confirm the labels from other models. """
    lang_groups=[src_lang]
    if src_lang in ('ig', 'sn', 'ny', 'st', 'zu', 'xh', 'rw', 'sw', 'yo'):
        lang_groups = ['ig', 'sn', 'ny', 'st', 'zu', 'xh', 'rw', 'sw', 'yo']
    elif src_lang in ('mr', 'ne', 'hi', ):
        lang_groups = ['mr', 'ne', 'hi', ]
    elif src_lang in ('pt', 'gl'):
        lang_groups = ['pt','gl','la' ]
    elif src_lang in ('fr', 'br'):
        lang_groups = ['fr','la', 'br' ]
    elif src_lang in ('es', 'oc', 'ca', 'eu', 'an', 'gl' ):
        lang_groups = ['es', 'oc', 'ca', 'eu', 'an', 'gl', 'la' ]
    elif src_lang in ('arz', 'ar', 'fa', 'ur', 'az', 'azb', 'ckb' ):
        lang_groups = ['arz', 'ar', 'fa', 'ur', 'az', 'azb', 'ckb' ]
    elif src_lang in ('id', 'ms', ):
        lang_groups = ['id', 'ms',]
    elif src_lang in ('as', 'bn', 'bpy'):
        lang_groups = ['as', 'bn', 'bpy']
    elif src_lang in ('af', 'nl', ):
        lang_groups = ['af', 'nl',]
    elif src_lang in ('bo', 'dz', ):
        lang_groups = ['bo', 'dz',]
    elif src_lang in ('bs', 'hr', ):
        lang_groups = ['bs', 'hr',]
    elif src_lang in ('bxr', 'mn', ):
        lang_groups = ['bxr', 'mn',]
    elif src_lang in ('ceb', 'tl', ):
        lang_groups = ['ceb', 'tl',]
    elif src_lang in ('cs', 'sk', ):
        lang_groups = ['cs', 'sk',]
    elif src_lang in ('da', 'no', ):
        lang_groups = ['da', 'no',]
    elif src_lang in ('eml', 'wa', ):
        lang_groups = ['eml', 'wa',]
    elif src_lang in ('de', 'lb', 'pl', 'dsb'):
        lang_groups = ['de', 'lb', 'pl', 'dsb']
    elif src_lang in ('av', 'ru', 'bg', 'ba', 'kk', 'uk', 'be', 'ce', 'cv'):
        lang_groups = ['av', 'ru', 'bg', 'ba', 'kk', 'uk', 'be', 'ce', 'cv']
    return set(lang_groups)
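For illustration, a minimal sketch of how this grouping can be used to confirm a langid label; the `detected` value below is a hypothetical example, not output produced by the method itself:

# Minimal usage sketch: accept a langid label when it falls in the source language's group.
groups = TextAugment.get_lang_groups("pt")   # {'pt', 'gl', 'la'}
detected = "gl"                              # hypothetical langid.classify() result
confirmed = detected == "pt" or detected in groups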

@staticmethod
def check_good_sentence(s, src_lang, stopwords, show_err=False, lang_groups=[], ret_score=False, stopword_ratio_cutoff=0.06, bannedwords=None, flagged_words=None, badword_ratio_cutoff=0.15, junk_ratio=0.16, max_badword_len=5):
    #basic dejunk
    # for flagged_words, only filter out if the ratio is exceeded AND there exists one banned word
    if bannedwords is None:
        bannedwords = banned_words.get(src_lang, banned_words['default'])
    default_bannedwords = banned_words['default']
    s = s.lower().strip()
    if not s:
        return False
    jr = len([s2 for s2 in s if s2 in junk])/len(s)
    if jr >= junk_ratio:
        return False
    if src_lang in ("ja", "ko", "zh"):
        sArr = s
    else:
        sArr = [s2.strip(special_char) for s2 in s.lower().split() if s2.strip(special_char)]
    if len(sArr) == 0:
        return False
    bad_score = 0.0
    if flagged_words:
        if src_lang not in ("ja", "ko", "zh") and len([s2 for s2 in sArr if s2 in flagged_words])/len(sArr) > badword_ratio_cutoff:
            if any(s2 for s2 in sArr if s2 in bannedwords) or any(s2 for s2 in sArr if s2 in default_bannedwords):
                #print ('bw', len([s2 for s2 in sArr if s2 in flagged_words])/len(sArr))
                return False
            else:
                bad_score = len([s2 for s2 in sArr if s2 in flagged_words])/len(sArr)
        if src_lang in ("ja", "ko", "zh"):
            badword_ratio_cutoff /= 100
            len_s = len(s)
            bad_cnt = 0
            total_cnt = 0
            for i in range(len_s):
                for j in range(i+1,min(len_s, i+max_badword_len)):
                    if s[i:j] in flagged_words:
                        bad_cnt += 1
                    total_cnt += 1
            bad_score = (bad_cnt/total_cnt)
            if bad_score > badword_ratio_cutoff:
                for bword in bannedwords:
                    if bword in s:
                        return False
                for bword in default_bannedwords:
                    if bword in s:
                        return False

    #stopword check
    if stopwords:
        #TODO: catch multi word with spaces
        if src_lang not in ("ja", "ko", "zh") and len([s2 for s2 in sArr if s2 in stopwords])/len(sArr) < stopword_ratio_cutoff:
            #print ('sw', len([s2 for s2 in sArr if s2 in stopwords])/len(sArr))
            return False
        if src_lang in ("ja", "ko", "zh"):
            if src_lang == "zh":
                max_stoword = TextAugment.max_stoword_len_zh
            elif src_lang == "ko":
                max_stoword = TextAugment.max_stoword_len_ko
            elif src_lang == "ja":
                max_stoword = TextAugment.max_stoword_len_ja
            len_s = len(s)
            stop_cnt = 0
            total_cnt = 0
            for i in range(len_s):
                for j in range(i+1,min(len_s, i+max_stoword)):
                    if s[i:j] in stopwords:
                        stop_cnt += 1
                    total_cnt += 1
            #print ('stopword', (stop_cnt/total_cnt) )
            if (stop_cnt/total_cnt) < stopword_ratio_cutoff:
                return False
    #langid check
    try:
        lang = langid.classify(s)[0]
    except:
        return True
    if show_err and lang != src_lang and lang not in lang_groups:
        logger.info ((src_lang, lang))
    if ret_score: return lang == src_lang or lang in lang_groups, bad_score
    return lang == src_lang or lang in lang_groups
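A minimal sketch of how this filter might be invoked, mirroring the filtering step in process_ner further down; it assumes the module-level stopwords and flagged_words dictionaries (keyed by language code) used elsewhere in this file, and the sample sentence is only illustrative:

# Hypothetical call; stopwords/flagged_words are assumed to be per-language dicts.
src_lang = "es"
lang_groups = TextAugment.get_lang_groups(src_lang)
stopwords1 = set(stopwords.get(src_lang, []))
flagged_words1 = set(s for s in flagged_words.get(src_lang, []) if len(s) < 5)
keep = TextAugment.check_good_sentence("una frase de ejemplo", src_lang, stopwords=stopwords1,
                                        lang_groups=lang_groups, flagged_words=flagged_words1)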

#WIP - we can use this question generation method to extract people, places, and things, and potentially age/date, AND to get a relationship between a person and PII info
def generate_questions_answers_rel(self, docs, chunks, src_lang, default_answers=[], text_key=None, ner_key=None, rel_key=None, signal='qg_rel', weight=1.0):
    answers = {}
@@ -2123,7 +1999,6 @@ def process_ner(self,
regex_weight=1.5,
backtrans_weight=0.9,
do_docs_trim_for_person=False,
do_docs_filter=False,
do_qg_rel=False,
do_kenlm = True,
cutoff=None,
@@ -2176,10 +2051,7 @@

flagged_words1 = set([s for s in flagged_words.get(src_lang, []) if len(s) < 5])
stopwords1 = set(stopwords.get(src_lang, []))
if do_docs_filter:
    lang_groups=TextAugment.get_lang_groups(src_lang)
    docs = [doc for doc in docs if self.check_good_sentence(doc[f'{src_lang}_text'], src_lang, lang_groups=lang_groups, stopwords=stopwords1, flagged_words=flagged_words1)]


if cutoff is not None and cutoff > 0 and len(docs) > cutoff:
    docs = docs[:cutoff]
len_docs = len(docs)
@@ -2618,7 +2490,6 @@ def singleprocess_ner(infile,
regex_weight=1.5,
backtrans_weight=0.9,
do_docs_trim_for_person=False,
do_docs_filter=False,
do_qg_rel=False,
do_kenlm = True,
cutoff=None,
@@ -2674,7 +2545,6 @@ def singleprocess_ner(infile,
regex_weight=regex_weight,
backtrans_weight=backtrans_weight,
do_docs_trim_for_person=do_docs_trim_for_person,
do_docs_filter=do_docs_filter,
do_qg_rel=do_qg_rel,
do_kenlm = do_kenlm,
cutoff=cutoff,
@@ -2709,7 +2579,6 @@ def multiprocess_ner(infile,
regex_weight=1.5,
backtrans_weight=0.9,
do_docs_trim_for_person=False,
do_docs_filter=False,
do_qg_rel=False,
do_kenlm = True,
cutoff=None,
@@ -2770,7 +2639,6 @@ def multiprocess_ner(infile,
regex_weight=regex_weight,
backtrans_weight=backtrans_weight,
do_docs_trim_for_person=do_docs_trim_for_person,
do_docs_filter=do_docs_filter,
do_qg_rel=do_qg_rel,
do_kenlm = do_kenlm,
cutoff=cutoff,
