diff --git a/bin/rouge_cmd.py b/bin/rouge_cmd.py index 56593e9..9def2ff 100755 --- a/bin/rouge_cmd.py +++ b/bin/rouge_cmd.py @@ -10,6 +10,8 @@ def main(): parser.add_argument('-f', '--file', help="File mode", action='store_true') parser.add_argument('-a', '--avg', help="Average mode", action='store_true') + parser.add_argument('--ignore_empty', action='store_true', + help="Ignore empty hypothesis") parser.add_argument('hypothesis', type=str, help='Text of file path') parser.add_argument('reference', type=str, help='Text or file path') @@ -20,7 +22,8 @@ def main(): assert(os.path.isfile(ref)) files_rouge = FilesRouge(hyp, ref) - scores = files_rouge.get_scores(avg=args.avg) + scores = files_rouge.get_scores(avg=args.avg, + ignore_empty=args.ignore_empty) print(json.dumps(scores, indent=2)) else: diff --git a/rouge/rouge.py b/rouge/rouge.py index a6e8fa4..5ecc135 100644 --- a/rouge/rouge.py +++ b/rouge/rouge.py @@ -31,7 +31,7 @@ def line_count(path): self.ref_path = ref_path self.batch_lines = batch_lines - def get_scores(self, avg=False): + def get_scores(self, avg=False, ignore_empty=False): """Calculate ROUGE scores between each pair of lines (hyp_file[i], ref_file[i]). Args: @@ -46,7 +46,8 @@ def get_scores(self, avg=False): with io.open(ref_path, encoding="utf-8", mode="r") as ref_file: refs = [line[:-1] for line in ref_file] - return self.rouge.get_scores(hyps, refs, avg=avg) + return self.rouge.get_scores(hyps, refs, avg=avg, + ignore_empty=ignore_empty) class Rouge: @@ -74,10 +75,16 @@ def __init__(self, metrics=None, stats=None): if s not in Rouge.AVAILABLE_STATS: raise ValueError("Unknown stat '%s'" % s) - def get_scores(self, hyps, refs, avg=False): + def get_scores(self, hyps, refs, avg=False, ignore_empty=False): if isinstance(hyps, six.string_types): hyps, refs = [hyps], [refs] + if ignore_empty: + # Filter out hyps of 0 length + hyps_and_refs = zip(hyps, refs) + hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] + hyps, refs = zip(*hyps_and_refs) + assert(type(hyps) == type(refs)) assert(len(hyps) == len(refs))