| from itertools import chain |
| from typing import List, Optional, Tuple |
|
|
import numpy as np
from transformers import Pipeline


class RefSegPipeline(Pipeline):
    """Token-classification pipeline that segments a block of bibliographic
    references into individual references and labels each reference's fields."""

    labels = [
        'publisher', 'source', 'url', 'other', 'author', 'editor', 'lpage',
        'volume', 'year', 'issue', 'title', 'fpage', 'edition'
    ]
    # Expand each field into IOB2 tags ('B-author', 'I-author', ...) plus the outside tag 'O'.
    iob_labels = list(chain.from_iterable([['B-' + x, 'I-' + x] for x in labels])) + ['O']
    id2seg = dict(enumerate(iob_labels))
    id2ref = dict(enumerate(['B-ref', 'I-ref']))
    is_split_into_words = False

|
| def _sanitize_parameters(self, **kwargs): |
| if "id2seg" in kwargs: |
| self.id2seg = kwargs["id2seg"] |
| if "id2ref" in kwargs: |
| self.id2ref = kwargs["id2ref"] |
|
|
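        # Pipeline._sanitize_parameters splits kwargs into dicts for
        # preprocess, _forward, and postprocess; this pipeline stores its
        # parameters on the instance instead, so all three stay empty.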
| return {}, {}, {} |
|
|
    def preprocess(self, sentence, offset_mapping=None, split_into_words=True):
        tokens = sentence
        offsets = None
        if split_into_words:
            # Pre-tokenize so that each word's character offsets in the
            # original string can be recovered later.
            split_sentence = self.tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(sentence)
            tokens, offsets = zip(*split_sentence)
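        # Long inputs overflow into additional 512-token chunks; every chunk
        # flows through the model and is merged back in postprocessing.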
| model_inputs = self.tokenizer( |
| tokens, |
| return_offsets_mapping=True, |
| padding='max_length', |
| truncation=True, |
| max_length=512, |
| return_tensors="pt", |
| return_special_tokens_mask=True, |
| return_overflowing_tokens=True, |
| is_split_into_words=split_into_words |
| ) |
|
|
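        # A caller-supplied offset mapping overrides the tokenizer's.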
| if offset_mapping: |
| model_inputs["offset_mapping"] = offset_mapping |
|
|
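        # Carry the raw sentence and word offsets through to postprocessing.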
| model_inputs["sentence"] = sentence |
| model_inputs["token_offsets"] = offsets |
|
|
        return model_inputs


| def _forward(self, model_inputs): |
| special_tokens_mask = model_inputs.pop("special_tokens_mask") |
| offset_mapping = model_inputs.pop("offset_mapping", None) |
| sentence = model_inputs.pop("sentence") |
| token_offsets = model_inputs.pop("token_offsets") |
| overflow_mapping = model_inputs.pop("overflow_to_sample_mapping") |
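        # The model is assumed to return a pair of logit tensors:
        # (field-segmentation logits, reference-boundary logits).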
| if self.framework == "tf": |
| logits = self.model(model_inputs.data)[0] |
| else: |
| logits = self.model(**model_inputs)[0] |
|
|
| return { |
| "logits": logits, |
| "special_tokens_mask": special_tokens_mask, |
| "offset_mapping": offset_mapping, |
| "overflow_mapping": overflow_mapping, |
| "sentence": sentence, |
| "token_offsets": token_offsets, |
| **model_inputs, |
| } |
|
|
    def postprocess(self, model_outputs):
        ignore_labels = ["O"]
| logits_seg = model_outputs["logits"][0].numpy() |
| logits_ref = model_outputs["logits"][1].numpy() |
| sentence = model_outputs["sentence"] |
| token_offsets = model_outputs["token_offsets"] |
| input_ids = model_outputs["input_ids"] |
| special_tokens_mask = model_outputs["special_tokens_mask"] |
|
|
        offset_mapping = model_outputs["offset_mapping"]
|
|
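        # Numerically stable softmax over the label dimension for each head.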
| maxes_seg = np.max(logits_seg, axis=-1, keepdims=True) |
| shifted_exp_seg = np.exp(logits_seg - maxes_seg) |
| scores_seg = shifted_exp_seg / shifted_exp_seg.sum(axis=-1, keepdims=True) |
|
|
| maxes_ref = np.max(logits_ref, axis=-1, keepdims=True) |
| shifted_exp_ref = np.exp(logits_ref - maxes_ref) |
| scores_ref = shifted_exp_ref / shifted_exp_ref.sum(axis=-1, keepdims=True) |
|
|
| pre_entities = self.gather_pre_entities( |
| input_ids, scores_seg, scores_ref, offset_mapping, special_tokens_mask |
| ) |
| grouped_entities = self.aggregate(pre_entities, token_offsets, sentence) |
|
|
| cleaned_groups = [] |
| for group in grouped_entities: |
| entities = [ |
| entity |
| for entity in group |
| if entity.get("entity_group", None) not in ignore_labels |
| ] |
| if entities: |
| cleaned_groups.append(entities) |
| return { |
| "number_of_references": len(cleaned_groups), |
| "references": cleaned_groups, |
| } |
|
|
| def gather_pre_entities( |
| self, |
| input_ids: np.ndarray, |
| scores_seg: np.ndarray, |
| scores_ref: np.ndarray, |
| offset_mappings: Optional[List[Tuple[int, int]]], |
| special_tokens_masks: np.ndarray, |
| ) -> List[dict]: |
| """Fuse various numpy arrays into dicts with all the information needed for aggregation""" |
| pre_entities = [] |
        for input_id, offset_mapping, special_tokens_mask, s_seg, s_ref in zip(
                input_ids, offset_mappings, special_tokens_masks, scores_seg, scores_ref):
            for idx, iid in enumerate(input_id):
|
|
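                # Skip special tokens (BOS/EOS/padding); they carry no text.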
| if special_tokens_mask[idx]: |
| continue |
|
|
                word = self.tokenizer.convert_ids_to_tokens(int(iid))
| if offset_mapping is not None: |
| start_ind, end_ind = offset_mapping[idx] |
| if not isinstance(start_ind, int): |
| if self.framework == "pt": |
| start_ind = start_ind.item() |
| end_ind = end_ind.item() |
|
|
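                    # SentencePiece marks word-initial pieces with '\u2581';
                    # anything else is a continuation sub-word.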
| is_subword = not word.startswith('\u2581') |
|
|
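                    # UNK tokens never carry the word-start marker, so treat
                    # them as word starts rather than sub-words.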
                    if int(iid) == self.tokenizer.unk_token_id:
| is_subword = False |
| else: |
| start_ind = None |
| end_ind = None |
| is_subword = False |
|
|
| pre_entity = { |
| "word": word, |
| "scores_seg": s_seg[idx], |
| "scores_ref": s_ref[idx], |
| "start": start_ind, |
| "end": end_ind, |
| "index": idx, |
| "is_subword": is_subword, |
| } |
| pre_entities.append(pre_entity) |
| return pre_entities |
|
|
| def aggregate(self, pre_entities: List[dict], token_offsets: List[tuple], sentence: str) -> List[dict]: |
| entities = self.aggregate_words(pre_entities, token_offsets) |
|
|
| return self.group_entities(entities, sentence) |
|
|
| def aggregate_word(self, entities: List[dict], token_offset: tuple) -> dict: |
| word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities]) |
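        # The segmentation label comes from the word's first sub-token.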
| scores_seg = entities[0]["scores_seg"] |
| idx_seg = scores_seg.argmax() |
| score_seg = scores_seg[idx_seg] |
| entity_seg = self.id2seg[idx_seg] |
|
|
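        # The word continues a reference ('I-ref', index 1) only when every
        # sub-token votes I-ref; otherwise it begins a new one ('B-ref').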
| scores_ref = np.stack([entity["scores_ref"] for entity in entities]) |
| indices_ref = scores_ref.argmax(axis=1) |
| idx_ref = 1 if all(indices_ref) else 0 |
| entity_ref = self.id2ref[idx_ref] |
|
|
| new_entity = { |
| "entity_seg": entity_seg, |
| "score_seg": score_seg, |
| "entity_ref": entity_ref, |
| "word": word, |
| "start": entities[0]["start"] + token_offset[0], |
| "end": entities[-1]["end"] + token_offset[0], |
| } |
| return new_entity |
|
|
    def aggregate_words(self, entities: List[dict], token_offsets: List[tuple]) -> List[dict]:
        """
        Merge sub-token entities into word-level entities, forcing agreement on
        word boundaries: "micro|soft| com|pany" tagged B-ENT I-NAME I-ENT I-ENT
        becomes "microsoft company" tagged B-ENT I-ENT, since each word takes
        the label of its first sub-token.
        """
| word_entities = [] |
| word_group = None |
| idx = 0 |
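        # Each non-subword token starts a new word, closing the previous group.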
| for entity in entities: |
| if word_group is None: |
| word_group = [entity] |
| elif entity["is_subword"]: |
| word_group.append(entity) |
| else: |
| word_entities.append(self.aggregate_word(word_group, token_offsets[idx])) |
| word_group = [entity] |
| idx += 1 |
        # Flush the final word group (it is None only when `entities` is empty).
        if word_group is not None:
            word_entities.append(self.aggregate_word(word_group, token_offsets[idx]))
        return word_entities
|
|
    def group_entities(self, entities: List[dict], sentence: str) -> List[dict]:
        """
        Group adjacent word-level entities in two passes: first split the
        stream into individual references on B-ref boundaries, then merge
        adjacent words sharing a segmentation label within each reference.
        Args:
            entities (`dict`): The word-level entities predicted by the pipeline.
        """
| entity_chunk = [] |
| entity_chunk_disagg = [] |
|
|
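        # Pass 1: split the word stream into references at 'B-ref' boundaries.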
| for entity in entities: |
| if not entity_chunk_disagg: |
| entity_chunk_disagg.append(entity) |
| continue |
|
|
| bi_ref, tag_ref = self.get_tag(entity["entity_ref"]) |
| last_bi_ref, last_tag_ref = self.get_tag(entity_chunk_disagg[-1]["entity_ref"]) |
|
|
| if tag_ref == last_tag_ref and bi_ref != "B": |
| entity_chunk_disagg.append(entity) |
| else: |
| entity_chunk.append(entity_chunk_disagg) |
| entity_chunk_disagg = [entity] |
|
|
| if entity_chunk_disagg: |
| entity_chunk.append(entity_chunk_disagg) |
|
|
| entity_chunks_all = [] |
|
|
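        # Pass 2: within each reference, merge adjacent words that share a
        # segmentation label.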
| for chunk in entity_chunk: |
|
|
| entity_groups = [] |
| entity_group_disagg = [] |
|
|
| for entity in chunk: |
| if not entity_group_disagg: |
| entity_group_disagg.append(entity) |
| continue |
|
|
| bi_seg, tag_seg = self.get_tag(entity["entity_seg"]) |
| last_bi_seg, last_tag_seg = self.get_tag(entity_group_disagg[-1]["entity_seg"]) |
|
|
| if tag_seg == last_tag_seg and bi_seg != "B": |
| entity_group_disagg.append(entity) |
| else: |
| entity_groups.append(self.group_sub_entities(entity_group_disagg, sentence)) |
| entity_group_disagg = [entity] |
|
|
| if entity_group_disagg: |
| entity_groups.append(self.group_sub_entities(entity_group_disagg, sentence)) |
|
|
| entity_chunks_all.append(entity_groups) |
|
|
| return entity_chunks_all |
|
|
| def group_sub_entities(self, entities: List[dict], sentence: str) -> dict: |
| """ |
| Group together the adjacent tokens with the same entity predicted. |
| Args: |
| entities (`dict`): The entities predicted by the pipeline. |
| """ |
| entity = entities[0]["entity_seg"].split("-")[-1] |
| scores = np.nanmean([entity["score_seg"] for entity in entities]) |
| start = min([entity["start"] for entity in entities]) |
| end = max([entity["end"] for entity in entities]) |
| word = sentence[start:end] |
|
|
|
|
|
|
| entity_group = { |
| "entity_group": entity, |
| "score": np.mean(scores), |
| "word": word, |
| "start": entities[0]["start"], |
| "end": entities[-1]["end"], |
| } |
| return entity_group |
|
|
| def get_tag(self, entity_name: str) -> Tuple[str, str]: |
| if entity_name.startswith("B-"): |
| bi = "B" |
| tag = entity_name[2:] |
| elif entity_name.startswith("I-"): |
| bi = "I" |
| tag = entity_name[2:] |
| else: |
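            # Labels without an IOB prefix are treated as continuations.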
| bi = "I" |
| tag = entity_name |
| return bi, tag |
|
|
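
# Usage sketch (illustrative only): the checkpoint path below is a
# placeholder, and a real checkpoint must supply the custom two-headed model
# this pipeline expects; AutoModelForTokenClassification is shown purely as a
# stand-in loader.
if __name__ == "__main__":
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    checkpoint = "path/to/refseg-checkpoint"  # hypothetical checkpoint
    pipe = RefSegPipeline(
        model=AutoModelForTokenClassification.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
    output = pipe("Smith, J. (2020). A Title. Publisher. Doe, A. (2019). Another Title. Press.")
    print(output["number_of_references"])
    for reference in output["references"]:
        print([(entity["entity_group"], entity["word"]) for entity in reference])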