from benchmarkstt.schema import Schema
import logging
import json
from benchmarkstt.diff import Differ
from benchmarkstt.diff.core import RatcliffObershelp
from benchmarkstt.diff.formatter import format_diff
from benchmarkstt.metrics import Metric
from collections import namedtuple
import editdistance
logger = logging.getLogger(__name__)
# Per-tag totals of a diff: number of items that are equal, replaced,
# inserted and deleted (in that field order).
OpcodeCounts = namedtuple('OpcodeCounts', 'equal replace insert delete')
def traversible(schema, key=None):
    """
    Extract one field from every entry of a schema.

    :param schema: iterable of subscriptable entries (e.g. dicts)
    :param key: the field to read from each entry; 'item' when None
    :return: list with the value of ``key`` for each entry, in order
    """
    field = 'item' if key is None else key
    values = []
    for entry in schema:
        values.append(entry[field])
    return values
def get_opcode_counts(opcodes) -> OpcodeCounts:
    """
    Tally how many items fall under each opcode tag.

    Every opcode is a ``(tag, alo, ahi, blo, bhi)`` tuple. 'equal' and
    'delete' count the length of the a-range, 'insert' counts the length
    of the b-range. A 'replace' whose two ranges differ in length is
    split: the common length counts as replacements, the surplus counts
    as insertions (longer b-range) or deletions (longer a-range).
    Opcodes with any other tag are ignored.

    :param opcodes: iterable of (tag, alo, ahi, blo, bhi) tuples
    :return: the accumulated OpcodeCounts
    """
    equal = replace = insert = delete = 0
    for tag, alo, ahi, blo, bhi in opcodes:
        len_a = ahi - alo
        len_b = bhi - blo
        if tag == 'replace':
            # the overlapping part is substituted, the rest is surplus
            shared = min(len_a, len_b)
            replace += shared
            insert += len_b - shared
            delete += len_a - shared
        elif tag == 'equal':
            equal += len_a
        elif tag == 'delete':
            delete += len_a
        elif tag == 'insert':
            insert += len_b
    return OpcodeCounts(equal, replace, insert, delete)
def get_differ(a, b, differ_class: Differ):
    """
    Build a differ over the 'item' values of two schemas.

    :param a: first schema (iterable of entries with an 'item' field)
    :param b: second schema (iterable of entries with an 'item' field)
    :param differ_class: class to instantiate; RatcliffObershelp is used
                         when None is given
    :return: ``differ_class`` instantiated with the items of a and b
    """
    # HuntMcIlroy was once considered as the fallback here;
    # RatcliffObershelp is the one actually used.
    chosen = RatcliffObershelp if differ_class is None else differ_class
    return chosen(traversible(a), traversible(b))
class WordDiffs(Metric):
    """
    Present differences on a per-word basis

    :param dialect: Presentation format. Default is 'ansi'.
    :example dialect: 'html'
    :param differ_class: For future use.
    """

    def __init__(self, dialect=None, differ_class: Differ = None):
        self._dialect = dialect
        self._differ_class = differ_class

    def compare(self, ref: Schema, hyp: Schema):
        """
        Format the word-level differences between reference and hypothesis.

        :param ref: reference schema
        :param hyp: hypothesis schema
        :return: whatever ``format_diff`` produces for the chosen dialect
        """
        def join_words(words):
            # words are shown space separated, with one leading space
            return ' %s' % (' '.join(words),)

        word_differ = get_differ(ref, hyp, differ_class=self._differ_class)
        ref_words = traversible(ref)
        hyp_words = traversible(hyp)
        return format_diff(
            ref_words,
            hyp_words,
            word_differ.get_opcodes(),
            dialect=self._dialect,
            preprocessor=join_words,
        )
class WER(Metric):
    """
    Word Error Rate, basically defined as::

        insertions + deletions + substitutions
        --------------------------------------
              number of reference words

    See: https://en.wikipedia.org/wiki/Word_error_rate

    Calculates the WER using one of two algorithms:

    [Mode: 'strict' or 'hunt'] Insertions, deletions and
    substitutions are identified from the opcodes of a diff
    algorithm. Unless another ``differ_class`` is given, the
    RatcliffObershelp differ is used (see also
    https://docs.python.org/3/library/difflib.html). The 'hunt'
    mode applies 0.5 weight to insertions and deletions.

    [Mode: 'levenshtein'] In the context of WER, Levenshtein
    distance is the minimum edit distance computed at the
    word level. This implementation uses the Editdistance
    c++ implementation by Hiroyuki Tanaka:
    https://github.com/aflc/editdistance. See:
    https://en.wikipedia.org/wiki/Levenshtein_distance

    :param mode: 'strict' (default), 'hunt' or 'levenshtein'.
    :param differ_class: For future use.
    """

    # WER modes
    MODE_STRICT = 'strict'
    MODE_HUNT = 'hunt'
    MODE_LEVENSHTEIN = 'levenshtein'

    # Weight of each kind of error in the numerator of the rate.
    # 'hunt' mode overrides the first two on the instance.
    DEL_PENALTY = 1
    INS_PENALTY = 1
    SUB_PENALTY = 1

    def __init__(self, mode=None, differ_class: Differ = None):
        self._mode = mode
        # Set the differ in every mode (it is simply unused by
        # 'levenshtein'), so the instance never lacks the attribute.
        if differ_class is None:
            differ_class = RatcliffObershelp
        self._differ_class = differ_class
        if mode == self.MODE_HUNT:
            self.DEL_PENALTY = self.INS_PENALTY = .5

    def compare(self, ref: Schema, hyp: Schema) -> float:
        """
        Compute the word error rate of hypothesis against reference.

        :param ref: reference schema (entries with an 'item' field)
        :param hyp: hypothesis schema (entries with an 'item' field)
        :return: weighted number of errors divided by the number of
                 reference words; for an empty reference, 0 when there
                 are no errors and 1 otherwise
        """
        if self._mode == self.MODE_LEVENSHTEIN:
            ref_list = [i['item'] for i in ref]
            hyp_list = [i['item'] for i in hyp]
            total_ref = len(ref_list)
            if total_ref == 0:
                # avoid dividing by zero on an empty reference
                return 0 if len(hyp_list) == 0 else 1
            return editdistance.eval(ref_list, hyp_list) / total_ref

        diffs = get_differ(ref, hyp, differ_class=self._differ_class)
        counts = get_opcode_counts(diffs.get_opcodes())

        changes = counts.replace * self.SUB_PENALTY + \
            counts.delete * self.DEL_PENALTY + \
            counts.insert * self.INS_PENALTY
        # every reference word is either kept, replaced or deleted
        total = counts.equal + counts.replace + counts.delete
        if total == 0:
            return 1 if changes else 0
        return changes / total
class CER(Metric):
    """
    Character Error Rate, basically defined as::

        insertions + deletions + substitutions
        --------------------------------------
           number of reference characters

    Character error rate, CER, compares the differences
    between reference and hypothesis on a character level.
    A CER measure is usually lower than a WER measure, since
    words might differ in only one or a few characters, and
    still be classified as fully different.

    The CER metric might be useful as a perspective on the
    WER metric. Word endings might be less relevant if the
    text will be preprocessed with stemming, or minor
    spelling mistakes might be acceptable in certain
    situations. A CER metric might also be used to evaluate
    a source (an ASR) which outputs a stream of characters
    rather than words.

    Important: The current implementation of the CER metric
    ignores whitespace characters. A string like 'aa bb cc'
    will first be split into words, ['aa','bb','cc'], and
    then merged into a final string for evaluation: 'aabbcc'.

    :param mode: 'levenshtein' (default).
    :param differ_class: For future use.
    """

    # CER modes
    MODE_LEVENSHTEIN = 'levenshtein'

    def __init__(self, mode=None, differ_class=None):
        # differ_class is accepted but not used
        self._mode = self.MODE_LEVENSHTEIN if mode is None else mode

    def compare(self, ref: Schema, hyp: Schema):
        """
        Compute the character error rate of hypothesis against reference.

        :param ref: reference schema (entries with an 'item' field)
        :param hyp: hypothesis schema (entries with an 'item' field)
        :return: edit distance between the concatenated items divided by
                 the number of reference characters; for an empty
                 reference, 0 if the hypothesis is empty too, else 1
        :raises NotImplementedError: if the mode is not 'levenshtein'
        """
        reference = ''.join(word['item'] for word in ref)
        hypothesis = ''.join(word['item'] for word in hyp)

        if self._mode != self.MODE_LEVENSHTEIN:
            raise NotImplementedError('CER is only implemented for Levenshtein distance')

        n_chars = len(reference)
        if n_chars:
            return editdistance.eval(reference, hypothesis) / n_chars
        # empty reference: avoid dividing by zero
        return 1 if hypothesis else 0
class DiffCounts(Metric):
    """
    Get the amount of differences between reference and hypothesis
    """

    MODE_LEVENSHTEIN = 'levenshtein'

    def __init__(self, mode=None, differ_class: Differ = None):
        self._mode = mode
        self._differ_class = RatcliffObershelp if differ_class is None else differ_class

    def compare(self, ref: Schema, hyp: Schema) -> OpcodeCounts:
        """
        Count equal, replaced, inserted and deleted items.

        :param ref: reference schema
        :param hyp: hypothesis schema
        :return: OpcodeCounts derived from the differ's opcodes
        :raises NotImplementedError: if the mode is 'levenshtein'
        """
        if self._mode == self.MODE_LEVENSHTEIN:
            raise NotImplementedError('diffcounts is not implemented for Levenshtein distance')

        opcodes = get_differ(ref, hyp, differ_class=self._differ_class).get_opcodes()
        return get_opcode_counts(opcodes)
class BEER(Metric):
    """
    Bag of Entities Error Rate, BEER, is defined as the error rate per entity with a bag of words approach::

                        abs(ne_hyp - ne_ref)
        BEER (entity) = ----------------------
                            ne_ref

    - ne_hyp = number of detections of the entity in the hypothesis file
    - ne_ref = number of detections of the entity in the reference file

    The WA_BEER for a set of N entities is defined as the weighted average of the BEER for the set of
    entities::

        WA_BEER ([entity_1, ... entity_N]) = w_1*BEER (entity_1)*L_1/L + ... + w_N*BEER (entity_N)*L_N/L

    which is equivalent to::

                                             w_1*abs(ne_hyp_1 - ne_ref_1) + ... + w_N*abs(ne_hyp_N - ne_ref_N)
        WA_BEER ([entity_1, ... entity_N]) = ------------------------------------------------------------------
                                                                            L

    - L_1 = number of occurrences of entity 1 in the reference document
    - L = L_1 + ... + L_N

    the weights being normalised by the tool:

    - w_1 + ... + w_N = 1

    The input file defines the list of entities and the weight per entity, w_n. It is processed as a json file with the
    following structure::

        { "entity_1":W_1, "entity_2" : W_2, "entity_3" :W_3 .. }

    W_n being the non-normalized weight, the normalization of the weights is performed by the tool as::

                    W_n
        w_n = ---------------
              W_1 + ... +W_N

    The minimum value for weight being 0.
    """

    def __init__(self, entities_file=None):
        """
        Load the entities and their weights from a json file.

        A file that cannot be read or parsed does not raise here: the
        error text is stored and returned by :meth:`compare` instead.

        :param entities_file: path of a json file mapping each entity to
                              its (non-normalized) weight; optional
        """
        self._error_message = None
        self._entities = None
        # defined up front so get_weight() works before any weights are set
        self._weight = None
        if entities_file is not None:
            try:
                with open(entities_file) as f:
                    data = json.load(f)
                    self._entities = list(data.keys())
                    weight = list(data.values())
                    self.set_weight(weight)
            except (IOError, json.decoder.JSONDecodeError) as e:
                self._error_message = str(e)

    def get_weight(self):
        """Return the normalized weights (None if none were set)."""
        return self._weight

    def set_weight(self, weight):
        """
        Store the weights, normalized so that they sum to 1.

        Negative weights are first raised to 0.

        :param weight: list of non-normalized weights, one per entity
        """
        weight = [0 if w < 0 else w for w in weight]
        sw = sum(weight)
        if sw > 0:
            self._weight = [w / sw for w in weight]
        # if the sum of the weights is null, the wa_beer is null
        else:
            self._weight = weight

    def get_entities(self):
        """Return the list of entities (None if none were set)."""
        return self._entities

    def set_entities(self, entities):
        """Set the list of entities to look for."""
        self._entities = entities

    # find the positions of one entity
    # an entity can contain more than one word
    @staticmethod
    def __find_pattern(search_list, complex_entity):
        """
        Locate the non-overlapping occurrences of an entity in a word list.

        :param search_list: list of words to search in
        :param complex_entity: the entity, its words separated by spaces
        :return: one list of consecutive indexes of ``search_list`` per
                 occurrence found, in order of appearance
        """
        entity = ''.join(complex_entity).split(' ')
        le = len(entity)
        idx_found = []
        start = 0
        last_start = len(search_list) - le
        # Compare a window of le words at every candidate position. A
        # failed partial match only advances by one word, so an occurrence
        # beginning inside it (e.g. 'new york' in 'new new york') is found.
        while start <= last_start:
            if list(search_list[start:start + le]) == entity:
                idx_found.append(list(range(start, start + le)))
                # occurrences never overlap: resume after this one
                start += le
            else:
                start += 1
        return idx_found

    # generate a list containing the detected entities in list_parsed
    def __generate_list_entity(self, list_parsed):
        """
        List every detected entity, in order of appearance.

        :param list_parsed: list of words
        :return: list with one space-joined string per detected occurrence
        """
        list_entity = []
        index_entities = []
        entities = self._entities
        for entity in entities:
            index_entity = self.__find_pattern(list_parsed, entity)
            index_entities.extend(index_entity)
        # sort on the position of the first part of the entity
        index_entities.sort(key=lambda l: l[0])
        # copy the words of each occurrence found into the list
        for k_list in index_entities:
            list_entity.append(' '.join(list_parsed[k_list[0]:k_list[-1] + 1]))
        return list_entity

    # computes the BEER
    def compute_beer(self, list_hypothesis_entity, list_reference_entity):
        """
        Compute the BEER per entity and the weighted average BEER.

        :param list_hypothesis_entity: entities detected in the hypothesis
        :param list_reference_entity: entities detected in the reference
        :return: dict mapping each entity, plus the key 'w_av_beer', to a
                 dict with keys 'beer' and 'occurrence_ref'
        """
        beer = {}
        beer_av = 0
        entities = self._entities
        for idx, entity in enumerate(entities):
            count_hypothesis = list_hypothesis_entity.count(entity)
            count_ref = list_reference_entity.count(entity)
            # an entity absent from the reference gets a BEER of 0
            beer_entity = 0
            if count_ref != 0:
                beer_entity = round(abs(count_ref - count_hypothesis) / count_ref, 3)
            # accumulate the distance per entity
            beer_av += abs(count_ref - count_hypothesis) * self._weight[idx]
            beer[entity] = {'beer': beer_entity, 'occurrence_ref': count_ref}
        l_ref = len(list_reference_entity)
        if l_ref > 0:
            beer_av = round(beer_av / l_ref, 3)
        else:
            beer_av = 0
        beer['w_av_beer'] = {'beer': beer_av, 'occurrence_ref': l_ref}
        return beer

    def compare(self, ref: Schema, hyp: Schema):
        """
        Compute the BEER of hypothesis against reference.

        :param ref: reference schema (entries with an 'item' field)
        :param hyp: hypothesis schema (entries with an 'item' field)
        :return: the result of :meth:`compute_beer`, or a dict with the
                 single key 'Error' when no entities are available
        """
        if self._entities is None:
            if self._error_message:
                return {'Error': self._error_message}
            else:
                return {'Error': 'Missing .json input file'}
        # get the list of reference and hypothesis
        ref_list = [i['item'] for i in ref]
        hyp_list = [i['item'] for i in hyp]
        # extract the entities
        list_hypothesis_entity = self.__generate_list_entity(hyp_list)
        list_reference_entity = self.__generate_list_entity(ref_list)
        # compute the score
        wer_entity = self.compute_beer(list_hypothesis_entity, list_reference_entity)
        return wer_entity
# For a future version
# class ExternalMetric(LoadObjectProxy, Base):
# """
# Automatically loads an external metric class.
#
# :param name: The name of the metric to load (eg. mymodule.metrics.MyOwnMetricClass)
# """