Skip to content

Commit

Permalink
Merge pull request #783 from ufal/chrf_fix
Browse files Browse the repository at this point in the history
fix issue #782
  • Loading branch information
jindrahelcl authored Jan 6, 2019
2 parents b4ea78c + 98f3b3b commit 2c71059
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 35 deletions.
65 changes: 30 additions & 35 deletions neuralmonkey/evaluators/chrf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Dict
from typeguard import check_argument_types
import numpy as np
from neuralmonkey.evaluators.evaluator import Evaluator

# pylint: disable=invalid-name
Expand All @@ -25,7 +26,6 @@ def __init__(self,
super().__init__(name)

self.n = n
self.max_ord = n
self.beta_2 = beta**2

self.ignored = [] # type: List[str]
Expand All @@ -37,11 +37,11 @@ def score_instance(self,
reference: List[str]) -> float:
hyp_joined = " ".join(hypothesis)
hyp_chars = [x for x in list(hyp_joined) if x not in self.ignored]
hyp_ngrams = self._get_ngrams(hyp_chars, self.n)
hyp_ngrams = _get_ngrams(hyp_chars, self.n)

ref_joined = " ".join(reference)
ref_chars = [x for x in list(ref_joined) if x not in self.ignored]
ref_ngrams = self._get_ngrams(ref_chars, self.n)
ref_ngrams = _get_ngrams(ref_chars, self.n)

if not hyp_chars or not ref_chars:
if "".join(hyp_chars) == "".join(ref_chars):
Expand All @@ -58,48 +58,43 @@ def score_instance(self,
/ ((self.beta_2 * precision) + recall))

def chr_r(self, hyp_ngrams: NGramDicts, ref_ngrams: NGramDicts) -> float:
recall = 0.0
count_all = np.zeros(self.n)
count_matched = np.zeros(self.n)
for m in range(1, self.n + 1):
count_all = 0
count_matched = 0
for ngr in ref_ngrams[m - 1]:
ref_count = ref_ngrams[m - 1][ngr]
count_all += ref_count
count_all[m - 1] += ref_count
if ngr in hyp_ngrams[m - 1]:
count_matched += min(ref_count, hyp_ngrams[m - 1][ngr])
# Catch division by zero
if count_all != 0.0:
recall += count_matched / count_all
return recall / float(self.max_ord)
count_matched[m - 1] += min(
ref_count, hyp_ngrams[m - 1][ngr])
return np.mean(np.divide(
count_matched, count_all, out=np.ones_like(count_all),
where=(count_all != 0)))

def chr_p(self, hyp_ngrams: NGramDicts, ref_ngrams: NGramDicts) -> float:
precision = 0.0
count_all = np.zeros(self.n)
count_matched = np.zeros(self.n)
for m in range(1, self.n + 1):
count_all = 0
count_matched = 0
for ngr in hyp_ngrams[m - 1]:
hyp_count = hyp_ngrams[m - 1][ngr]
count_all += hyp_count
count_all[m - 1] += hyp_count
if ngr in ref_ngrams[m - 1]:
count_matched += min(hyp_count, ref_ngrams[m - 1][ngr])
# Catch division by zero
if count_all != 0.0:
precision += count_matched / count_all

return precision / float(self.max_ord)

def _get_ngrams(self, tokens: List[str], n: int) -> NGramDicts:
if len(tokens) < n:
self.max_ord = len(tokens)

ngr_dicts = []
for m in range(1, n + 1):
ngr_dict = {} # type: Dict[str, int]
for i in range(m, len(tokens)):
ngr = "".join(tokens[i - m:i])
ngr_dict[ngr] = ngr_dict.setdefault(ngr, 0) + 1
ngr_dicts.append(ngr_dict)
return ngr_dicts
count_matched[m - 1] += min(
hyp_count, ref_ngrams[m - 1][ngr])
return np.mean(np.divide(
count_matched, count_all, out=np.ones_like(count_all),
where=(count_all != 0)))


def _get_ngrams(tokens: List[str], n: int) -> NGramDicts:
    """Count the n-grams of every order from 1 to ``n`` in ``tokens``.

    Arguments:
        tokens: The sequence of string items to take n-grams from.
        n: The highest n-gram order to count.

    Returns:
        A list with ``n`` dictionaries. The dictionary at index ``m - 1``
        maps each n-gram of order ``m`` (its items joined into one string)
        to its number of occurrences. Orders longer than ``tokens`` yield
        an empty dictionary.
    """
    length = len(tokens)
    per_order = []
    for order in range(1, n + 1):
        counts = {}  # type: Dict[str, int]
        # Slide a window of `order` items over the sequence; the last
        # window starts at `length - order`, so the final n-gram is included.
        for start in range(length - order + 1):
            gram = "".join(tokens[start:start + order])
            counts[gram] = counts.get(gram, 0) + 1
        per_order.append(counts)
    return per_order


# pylint: disable=invalid-name
Expand Down
54 changes: 54 additions & 0 deletions neuralmonkey/tests/test_chrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python3.5


import unittest

from neuralmonkey.evaluators.chrf import ChrFEvaluator, _get_ngrams
from neuralmonkey.tests.test_bleu import DECODED, REFERENCE


# Character sequence whose n-gram counts are listed in NGRAMS.
TOKENS = ["a", "b", "a"]
# Expected n-gram counts for the sequence "a", "b", "a", one dict per
# order 1..4; the sequence has only three items, so the 4-gram dict is empty.
NGRAMS = [
{"a": 2, "b": 1},
{"ab": 1, "ba": 1},
{"aba": 1},
{}]

# Evaluator constructed with its default arguments, plus shortcuts to its
# chr_p and chr_r methods.
FUNC = ChrFEvaluator()
FUNC_P = FUNC.chr_p
FUNC_R = FUNC.chr_r


class TestChrF(unittest.TestCase):
    """Checks for the chrF evaluator and the ``_get_ngrams`` helper."""

    def test_empty_decoded(self):
        # Recall == 0.0
        blank_sentences = [[] for _ in DECODED]
        score = FUNC(blank_sentences, REFERENCE)
        self.assertEqual(score, 0.0)

    def test_empty_reference(self):
        # Precision == 0.0
        # NOTE(review): the empty sentences are passed in the same (first)
        # argument position as in test_empty_decoded -- confirm that this
        # argument order is what the test name intends.
        blank_sentences = [[] for _ in REFERENCE]
        score = FUNC(blank_sentences, DECODED)
        self.assertEqual(score, 0.0)

    def test_identical(self):
        score = FUNC(REFERENCE, REFERENCE)
        self.assertEqual(score, 1.0)

    def test_empty_sentence(self):
        # Extend the data by one pair: ["something"] is appended to the
        # decoded list and an empty sentence to the reference list.
        references = REFERENCE + [[]]
        outputs = DECODED + [["something"]]
        result = FUNC(outputs, references)
        # NOTE(review): delta=10 accepts any value within 10 of 0.38;
        # presumably a much tighter tolerance was intended -- confirm.
        self.assertAlmostEqual(result, 0.38, delta=10)

    def test_chrf(self):
        result = FUNC(DECODED, REFERENCE)
        # NOTE(review): delta=10 accepts any value within 10 of 0.46;
        # presumably a much tighter tolerance was intended -- confirm.
        self.assertAlmostEqual(result, 0.46, delta=10)

    def test_get_ngrams(self):
        counted = _get_ngrams(["a", "b", "a"], 4)
        # One dictionary per requested order, even when an order is longer
        # than the input sequence.
        self.assertEqual(len(counted), 4)
        for actual, expected in zip(counted, NGRAMS):
            self.assertDictEqual(actual, expected)


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
unittest.main()

0 comments on commit 2c71059

Please sign in to comment.