| import copy |
| import regex |
| import string |
| import unicodedata |
| from typing import List |
| from collections import Counter |
| from ..core.logging import logger |
|
|
| |
| |
| |
|
|
| |
def normalize_answer(s : str) -> str:
    """Canonicalize an answer string for SQuAD-style comparison.

    Applied in order: lower-case the string, drop every character in
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, then collapse whitespace runs into single spaces and trim
    both ends.
    """
    punctuation = set(string.punctuation)
    lowered = s.lower()
    without_punct = "".join(c for c in lowered if c not in punctuation)
    without_articles = regex.sub(r'\b(a|an|the)\b', ' ', without_punct)
    # split()/join both collapses internal runs and strips leading/trailing space.
    return " ".join(without_articles.split())
|
|
def exact_match_score(prediction : str, ground_truth : str) -> float:
    """Return 1.0 when both strings are equal after ``normalize_answer``, else 0.0.

    Raises:
        AssertionError: if ``ground_truth`` is not a ``str``.
    """
    assert isinstance(ground_truth, str), f"ground_truth must be a string, but got {type(ground_truth)}"
    is_same = normalize_answer(prediction) == normalize_answer(ground_truth)
    return 1.0 if is_same else 0.0
|
|
def ems(prediction : str, ground_truths : List[str]) -> float:
    """Best ``exact_match_score`` of ``prediction`` over all ``ground_truths``.

    Raises:
        AssertionError: if ``ground_truths`` is not a list.
        ValueError: if ``ground_truths`` is empty (``max`` of an empty sequence).
    """
    assert isinstance(ground_truths, list), f"ground_truths must be a list, but got {type(ground_truths)}"
    return max(
        exact_match_score(prediction, candidate)
        for candidate in ground_truths
    )
|
|
| |
def f1_score(prediction : str, ground_truth: str) -> float:
    """Token-level F1 between ``prediction`` and ``ground_truth``.

    Both strings are canonicalized with ``normalize_answer`` and split on
    whitespace. Overlap is a multiset intersection (``Counter`` ``&``), so a
    repeated token only counts as many times as it appears on both sides.

    The special answers "yes", "no" and "noanswer" get no partial credit:
    if either side normalizes to one of them and the sides differ, the
    score is 0.0.

    Returns:
        The F1 score in [0.0, 1.0], always as a ``float``.

    Raises:
        AssertionError: if ``ground_truth`` is not a ``str``.
    """
    assert isinstance(ground_truth, str), f"ground_truth must be a string, but got {type(ground_truth)}"
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    special_answers = ('yes', 'no', 'noanswer')
    if normalized_prediction != normalized_ground_truth and (
        normalized_prediction in special_answers
        or normalized_ground_truth in special_answers
    ):
        return 0.0

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    # Also covers an empty token list on either side, so the divisions below
    # never divide by zero.
    if num_same == 0:
        return 0.0
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return (2 * precision * recall) / (precision + recall)
|
|
|
|
| def _normalize(text): |
| return unicodedata.normalize('NFD', text) |
|
|
class Tokenizer:
    """Base class for tokenizers.

    Subclasses provide ``tokenize``, which should return a ``Tokens`` object,
    and may override ``shutdown`` to release resources. ``shutdown`` is
    invoked from ``__del__`` when the instance is finalized.
    """

    def tokenize(self, text):
        """Split ``text`` into tokens; subclasses must implement this."""
        raise NotImplementedError

    def shutdown(self):
        """Release any held resources. The base implementation does nothing."""
        return None

    def __del__(self):
        # Tie cleanup to object finalization.
        self.shutdown()
|
|
|
|
class Tokens(object):
    """A list of tokenized text.

    Each element of ``data`` is a tuple indexed by the class constants:
    ``TEXT`` (token text), ``TEXT_WS`` (token text plus the characters up to
    the next token, as produced by ``SimpleTokenizer``), ``SPAN``
    (``(start, end)`` character offsets), and, only when the matching
    annotator ran, ``POS``, ``LEMMA`` and ``NER``.
    """

    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5

    def __init__(self, data, annotators, opts=None):
        self.data = data
        # Names of the annotations present in ``data`` ("pos", "lemma", "ner").
        self.annotators = annotators
        # Extra options; currently only "non_ent" is read (by entity_groups).
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a shallow copy restricted to tokens ``[i, j)``."""
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i:j]
        return new_tokens

    def untokenize(self):
        """Return the original text (whitespace reinserted), stripped at both ends."""
        return "".join([t[self.TEXT_WS] for t in self.data]).strip()

    def words(self, uncased=False):
        """Return the text of each token.

        Args:
            uncased: lower-case each token's text.
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Return the ``[start, end)`` character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Return each token's part-of-speech tag, or None if not annotated."""
        if "pos" not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Return each token's lemma, or None if not annotated."""
        if "lemma" not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Return each token's NER tag, or None if not annotated."""
        if "ner" not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Return all n-grams of length 1 through ``n``.

        N-grams are ordered by start position, and for each start from
        shortest to longest.

        Args:
            n: upper limit of n-gram length.
            uncased: lower-case token text before building n-grams.
            filter_fn: optional predicate called with an n-gram as a list of
                words; n-grams for which it returns a truthy value are DROPPED.
            as_strings: if True, return each n-gram as its words joined by a
                single space; if False, return ``(start, end)`` token-index
                pairs (``end`` exclusive) instead.
        """
        words = self.words(uncased)
        spans = [
            (start, end + 1)
            for start in range(len(words))
            for end in range(start, min(start + n, len(words)))
            if not (filter_fn and filter_fn(words[start:end + 1]))
        ]
        if as_strings:
            return [" ".join(words[start:end]) for start, end in spans]
        return spans

    def entity_groups(self):
        """Group consecutive tokens that share the same NER tag.

        Tokens tagged with ``opts["non_ent"]`` (default "O") are skipped.

        Returns:
            A list of ``(text, tag)`` pairs, or None when NER annotations are
            missing or there are no tokens.
        """
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get("non_ent", "O")
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            if ner_tag != non_ent:
                # Extend the run while the tag stays the same.
                start = idx
                while idx < len(entities) and entities[idx] == ner_tag:
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups
|
|
|
|
class SimpleTokenizer(Tokenizer):
    """Regex tokenizer: alphanumeric runs (letters, digits, marks) become one
    token each, and every other non-whitespace, non-control character is its
    own token. It only tokenizes; no POS/lemma/NER annotations are produced.
    """

    ALPHA_NUM = r"[\p{L}\p{N}\p{M}]+"
    NON_WS = r"[^\p{Z}\p{C}]"

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes). A non-empty value
                is ignored and a warning is logged.
        """
        self._regexp = regex.compile(
            "(%s)|(%s)" % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE | regex.UNICODE | regex.MULTILINE,
        )
        requested = kwargs.get("annotators")
        # ``annotators=None`` is documented as valid; len(None) would raise,
        # so test truthiness instead.
        if requested:
            # Eager %-formatting kept on purpose: the logger's formatting
            # style is defined elsewhere and not visible here.
            logger.warning(
                "%s only tokenizes! Skipping annotators: %s" % (type(self).__name__, requested)
            )
        self.annotators = set()

    def tokenize(self, text):
        """Tokenize ``text`` into a ``Tokens`` object.

        Each token tuple is ``(token, token_with_trailing_text, (start, end))``.
        The middle field spans from the token's start to the start of the next
        token, or to the token's own end for the last token.
        """
        matches = list(self._regexp.finditer(text))
        data = []
        for i, match in enumerate(matches):
            span = match.span()
            end_ws = matches[i + 1].start() if i + 1 < len(matches) else span[1]
            data.append((match.group(), text[span[0]:end_ws], span))
        return Tokens(data, self.annotators)
|
|
|
|
def regex_match(text, pattern):
    """Return True if ``pattern`` is found anywhere in ``text``.

    The pattern is compiled with IGNORECASE, UNICODE and MULTILINE flags.
    A pattern that fails to compile yields False instead of raising.
    """
    try:
        compiled = regex.compile(pattern, flags=regex.IGNORECASE | regex.UNICODE | regex.MULTILINE)
    except Exception:
        # Best-effort: a malformed pattern counts as "no match". Catching
        # Exception (not BaseException) lets KeyboardInterrupt/SystemExit
        # propagate.
        return False
    return compiled.search(text) is not None
|
|
|
|
| |
def has_answer(answers, text, match_type="string") -> bool:
    """Check if the text contains an answer string.

    If ``match_type`` is "string", the text and each answer are NFD-normalized,
    tokenized with ``SimpleTokenizer`` and lower-cased; an answer matches when
    its token sequence appears contiguously in the text's tokens. Answers that
    yield no tokens (e.g. empty or whitespace-only) are skipped.
    If ``match_type`` is "regex", each NFD-normalized answer is used as a
    pattern and searched for in the whole normalized text via ``regex_match``.
    Any other ``match_type`` returns False.

    Args:
        answers: iterable of answer strings (regex patterns in "regex" mode).
        text: the text to search.
        match_type: "string" or "regex".

    Returns:
        True as soon as one answer matches, otherwise False.
    """
    text = _normalize(text)

    if match_type == "string":
        tokenizer = SimpleTokenizer()
        text_tokens = tokenizer.tokenize(text).words(uncased=True)
        for single_answer in answers:
            answer_tokens = tokenizer.tokenize(_normalize(single_answer)).words(uncased=True)
            if not answer_tokens:
                # An empty token list equals text_tokens[0:0], so without this
                # guard a blank answer would match every text.
                continue
            width = len(answer_tokens)
            for i in range(len(text_tokens) - width + 1):
                if text_tokens[i:i + width] == answer_tokens:
                    return True

    elif match_type == "regex":
        for single_answer in answers:
            if regex_match(text, _normalize(single_answer)):
                return True

    return False
|
|
def acc_score(prediction : str, ground_truths : List[str]) -> float:
    """Return 1.0 if any ground truth is found in ``prediction``, else 0.0.

    Matching is done by ``has_answer`` in "string" (token-sequence) mode.

    Raises:
        AssertionError: if ``ground_truths`` is not a list.
    """
    assert isinstance(ground_truths, list), f"ground_truths must be a list, but got {type(ground_truths)}"
    found = has_answer(answers=ground_truths, text=prediction, match_type="string")
    return 1.0 if found else 0.0