from __future__ import annotations

import csv
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import NamedTuple

import numpy as np
import spacy
import torch
from marisa_trie import Trie
from transformers import BatchEncoding, BertTokenizer, PreTrainedTokenizerBase

NONE_ID = "<None>"


@dataclass
class Mention:
    kb_id: str | None
    text: str
    start: int
    end: int
    link_count: int | None
    total_link_count: int | None
    doc_count: int | None

    @property
    def span(self) -> tuple[int, int]:
        return self.start, self.end

    @property
    def link_prob(self) -> float | None:
        if self.doc_count is None or self.total_link_count is None:
            return None
        elif self.doc_count > 0:
            return min(1.0, self.total_link_count / self.doc_count)
        else:
            return 0.0

    @property
    def prior_prob(self) -> float | None:
        if self.link_count is None or self.total_link_count is None:
            return None
        elif self.total_link_count > 0:
            return min(1.0, self.link_count / self.total_link_count)
        else:
            return 0.0

    def __repr__(self):
        return f"<Mention {self.text} -> {self.kb_id}>"


def get_tokenizer(language: str) -> spacy.tokenizer.Tokenizer:
    language_obj = spacy.blank(language)
    return language_obj.tokenizer


class DictionaryEntityLinker:
    def __init__(
        self,
        name_trie: Trie,
        kb_id_trie: Trie,
        data: np.ndarray,
        offsets: np.ndarray,
        max_mention_length: int,
        case_sensitive: bool,
        min_link_prob: float | None,
        min_prior_prob: float | None,
        min_link_count: int | None,
    ):
        self._name_trie = name_trie
        self._kb_id_trie = kb_id_trie
        self._data = data
        self._offsets = offsets
        self._max_mention_length = max_mention_length
        self._case_sensitive = case_sensitive

        self._min_link_prob = min_link_prob
        self._min_prior_prob = min_prior_prob
        self._min_link_count = min_link_count

        self._tokenizer = get_tokenizer("en")

    @staticmethod
    def load(
        data_dir: str,
        min_link_prob: float | None = None,
        min_prior_prob: float | None = None,
        min_link_count: int | None = None,
    ) -> "DictionaryEntityLinker":
        data = np.load(os.path.join(data_dir, "data.npy"))
        offsets = np.load(os.path.join(data_dir, "offsets.npy"))
        name_trie = Trie()
        name_trie.load(os.path.join(data_dir, "name.trie"))
        kb_id_trie = Trie()
        kb_id_trie.load(os.path.join(data_dir, "kb_id.trie"))

        with open(os.path.join(data_dir, "config.json")) as config_file:
            config = json.load(config_file)

        if min_link_prob is None:
            min_link_prob = config.get("min_link_prob", None)

        if min_prior_prob is None:
            min_prior_prob = config.get("min_prior_prob", None)

        if min_link_count is None:
            min_link_count = config.get("min_link_count", None)

        return DictionaryEntityLinker(
            name_trie=name_trie,
            kb_id_trie=kb_id_trie,
            data=data,
            offsets=offsets,
            max_mention_length=config["max_mention_length"],
            case_sensitive=config["case_sensitive"],
            min_link_prob=min_link_prob,
            min_prior_prob=min_prior_prob,
            min_link_count=min_link_count,
        )
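
    # Typical usage (sketch; "path/to/entity_linker" is a placeholder directory containing
    # the data.npy, offsets.npy, name.trie, kb_id.trie, and config.json files written by save()):
    #   linker = DictionaryEntityLinker.load("path/to/entity_linker", min_link_prob=0.1)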

    def detect_mentions(self, text: str) -> list[Mention]:
        tokens = self._tokenizer(text)
        # Character offsets at which tokens end; candidate mentions must end at one of them.
        end_offsets = frozenset(token.idx + len(token) for token in tokens)
        if not self._case_sensitive:
            text = text.lower()

        ret = []
        cur = 0
        for token in tokens:
            start = token.idx
            if cur > start:
                # Skip tokens already covered by a previously matched mention.
                continue

            # Try dictionary entries starting at this token, longest match first.
            for prefix in sorted(
                self._name_trie.prefixes(text[start : start + self._max_mention_length]),
                key=len,
                reverse=True,
            ):
                end = start + len(prefix)
                if end in end_offsets:
                    matched = False
                    mention_idx = self._name_trie[prefix]
                    data_start, data_end = self._offsets[mention_idx : mention_idx + 2]
                    for item in self._data[data_start:data_end]:
                        if item.size == 4:
                            kb_idx, link_count, total_link_count, doc_count = item
                        elif item.size == 1:
                            (kb_idx,) = item
                            link_count, total_link_count, doc_count = None, None, None
                        else:
                            raise ValueError("Unexpected data array format")

                        mention = Mention(
                            kb_id=self._kb_id_trie.restore_key(kb_idx),
                            text=prefix,
                            start=start,
                            end=end,
                            link_count=link_count,
                            total_link_count=total_link_count,
                            doc_count=doc_count,
                        )
                        if item.size == 1 or (
                            mention.link_prob >= self._min_link_prob
                            and mention.prior_prob >= self._min_prior_prob
                            and mention.link_count >= self._min_link_count
                        ):
                            ret.append(mention)
                            matched = True

                    # Greedy longest match: consume the matched span and move on.
                    if matched:
                        cur = end
                        break

        return ret
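
    # Example (hypothetical dictionary; the returned surface forms are lowercased when the
    # linker is case-insensitive, and the KB IDs depend on the knowledge base used):
    #   linker.detect_mentions("Tokyo is the capital of Japan")
    #   -> [<Mention tokyo -> Q1490>, <Mention japan -> Q17>]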

    def detect_mentions_batch(self, texts: list[str]) -> list[list[Mention]]:
        return [self.detect_mentions(text) for text in texts]

    def save(self, data_dir: str) -> None:
        """
        Save the entity linker data to the specified directory.

        Args:
            data_dir: Directory to save the entity linker data
        """
        os.makedirs(data_dir, exist_ok=True)

        # Candidate entity records and the per-name offsets into them
        np.save(os.path.join(data_dir, "data.npy"), self._data)
        np.save(os.path.join(data_dir, "offsets.npy"), self._offsets)

        # Tries mapping mention surface forms and KB IDs to integer indices
        self._name_trie.save(os.path.join(data_dir, "name.trie"))
        self._kb_id_trie.save(os.path.join(data_dir, "kb_id.trie"))

        # Configuration needed to reload the linker with identical behavior
        with open(os.path.join(data_dir, "config.json"), "w") as config_file:
            json.dump(
                {
                    "max_mention_length": self._max_mention_length,
                    "case_sensitive": self._case_sensitive,
                    "min_link_prob": self._min_link_prob,
                    "min_prior_prob": self._min_prior_prob,
                    "min_link_count": self._min_link_count,
                },
                config_file,
            )


def load_tsv_entity_vocab(file_path: str) -> dict[str, int]:
    vocab = {}
    with open(file_path, "r", encoding="utf-8") as file:
        reader = csv.reader(file, delimiter="\t")
        for row in reader:
            vocab[row[0]] = int(row[1])
    return vocab
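
# The file is expected to contain one "<kb_id>\t<index>" pair per line, for example
# (illustrative entries; the actual KB IDs depend on the entity vocabulary in use):
#   <None>  0
#   Q42     1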


def save_tsv_entity_vocab(file_path: str, entity_vocab: dict[str, int]) -> None:
    """
    Save entity vocabulary to a TSV file.

    Args:
        file_path: Path to save the entity vocabulary
        entity_vocab: Entity vocabulary to save
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="\t")
        for entity_id, idx in entity_vocab.items():
            writer.writerow([entity_id, idx])


class _Entity(NamedTuple):
    entity_id: int
    start: int
    end: int

    @property
    def length(self) -> int:
        return self.end - self.start


def preprocess_text(
    text: str,
    mentions: list[Mention] | None,
    title: str | None,
    title_mentions: list[Mention] | None,
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> dict[str, list[int]]:
    tokens = []
    entity_ids = []
    entity_position_ids = []
    if title is not None:
        if title_mentions is None:
            title_mentions = []

        title_tokens, title_entities = _tokenize_text_with_mentions(title, title_mentions, tokenizer, entity_vocab)
        tokens += title_tokens + [tokenizer.sep_token]
        for entity in title_entities:
            entity_ids.append(entity.entity_id)
            entity_position_ids.append(list(range(entity.start, entity.end)))

    if mentions is None:
        mentions = []

    # Token positions of entities in the main text are shifted past the title (and its [SEP]).
    entity_offset = len(tokens)
    text_tokens, text_entities = _tokenize_text_with_mentions(text, mentions, tokenizer, entity_vocab)
    tokens += text_tokens
    for entity in text_entities:
        entity_ids.append(entity.entity_id)
        entity_position_ids.append(list(range(entity.start + entity_offset, entity.end + entity_offset)))

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    return {
        "input_ids": input_ids,
        "entity_ids": entity_ids,
        "entity_position_ids": entity_position_ids,
    }
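
# Sketch of the resulting structure (the token and entity IDs here are hypothetical):
#   {
#       "input_ids": [2023, 2003, ...],        # word-piece IDs for "(title [SEP]) text", no special tokens yet
#       "entity_ids": [17, 512],               # indices into the entity vocabulary
#       "entity_position_ids": [[0, 1], [7]],  # token positions covered by each entity
#   }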


def _tokenize_text_with_mentions(
    text: str,
    mentions: list[Mention],
    tokenizer: PreTrainedTokenizerBase,
    entity_vocab: dict[str, int],
) -> tuple[list[str], list[_Entity]]:
    """
    Tokenize text while preserving mention boundaries and mapping entities.

    Args:
        text: Input text to tokenize
        mentions: List of detected mentions in the text
        tokenizer: Pre-trained tokenizer to use for tokenization
        entity_vocab: Mapping from entity KB IDs to entity vocabulary indices

    Returns:
        Tuple containing:
            - List of tokens from the tokenized text
            - List of _Entity objects with entity IDs and token positions
    """
    target_mentions = [mention for mention in mentions if mention.kb_id is not None and mention.kb_id in entity_vocab]
    split_char_positions = {mention.start for mention in target_mentions} | {mention.end for mention in target_mentions}

    # Tokenize the text piece by piece so that mention boundaries always fall on token boundaries.
    tokens: list[str] = []
    cur = 0
    char_to_token_mapping = {}
    for char_position in sorted(split_char_positions):
        target_text = text[cur:char_position]
        tokens += tokenizer.tokenize(target_text)
        char_to_token_mapping[char_position] = len(tokens)
        cur = char_position
    tokens += tokenizer.tokenize(text[cur:])

    entities = [
        _Entity(
            entity_vocab[mention.kb_id],
            char_to_token_mapping[mention.start],
            char_to_token_mapping[mention.end],
        )
        for mention in target_mentions
    ]
    return tokens, entities
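
# For example (hypothetical word-piece tokenization), with a mention "New York" spanning
# characters 0-8 of "New York is large", the text is tokenized as
#   ["new", "york"] + ["is", "large"]
# and the mention becomes _Entity(entity_id=<vocab index>, start=0, end=2),
# i.e. it covers token positions 0 and 1.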


class KPRBertTokenizer(BertTokenizer):
    vocab_files_names = {
        **BertTokenizer.vocab_files_names,
        "entity_linker_data_file": "entity_linker/data.npy",
        "entity_linker_offsets_file": "entity_linker/offsets.npy",
        "entity_linker_name_trie_file": "entity_linker/name.trie",
        "entity_linker_kb_id_trie_file": "entity_linker/kb_id.trie",
        "entity_linker_config_file": "entity_linker/config.json",
        "entity_vocab_file": "entity_vocab.tsv",
        "entity_embeddings_file": "entity_embeddings.npy",
    }
    model_input_names = [
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "entity_ids",
        "entity_position_ids",
    ]

    def __init__(
        self,
        vocab_file,
        entity_linker_data_file: str,
        entity_vocab_file: str,
        entity_embeddings_file: str | None = None,
        *args,
        **kwargs,
    ):
        super().__init__(vocab_file=vocab_file, *args, **kwargs)
        entity_linker_dir = str(Path(entity_linker_data_file).parent)
        self.entity_linker = DictionaryEntityLinker.load(entity_linker_dir)
        self.entity_to_id = load_tsv_entity_vocab(entity_vocab_file)
        self.id_to_entity = {v: k for k, v in self.entity_to_id.items()}

        self.entity_embeddings = None
        if entity_embeddings_file:
            # Memory-map the embeddings so the full matrix is not read into memory up front.
            self.entity_embeddings = np.load(entity_embeddings_file, mmap_mode="r")
            if self.entity_embeddings.shape[0] != len(self.entity_to_id):
                raise ValueError(
                    f"Entity embeddings shape {self.entity_embeddings.shape[0]} does not match "
                    f"the number of entities {len(self.entity_to_id)}. "
                    "Make sure the entity embeddings file and `entity_vocab.tsv` are consistent."
                )

    def _preprocess_text(self, text: str, **kwargs) -> dict[str, list[int | list[int]]]:
        mentions = self.entity_linker.detect_mentions(text)
        model_inputs = preprocess_text(
            text=text,
            mentions=mentions,
            title=None,
            title_mentions=None,
            tokenizer=self,
            entity_vocab=self.entity_to_id,
        )

        # Run the standard BertTokenizer post-processing (special tokens, truncation,
        # token_type_ids, attention_mask) on the word-piece IDs. Tensor conversion is
        # deferred until batching, so "return_tensors" is not forwarded here.
        prepared_inputs = self.prepare_for_model(
            model_inputs["input_ids"],
            **{k: v for k, v in kwargs.items() if k != "return_tensors"},
        )
        model_inputs.update(prepared_inputs)

        # With add_special_tokens=True, prepare_for_model prepends [CLS]; shift all entity
        # token positions by one so they keep pointing at the right tokens.
        if kwargs.get("add_special_tokens", True):
            if prepared_inputs["input_ids"][0] != self.cls_token_id:
                raise ValueError(
                    "We assume that the input IDs start with the [CLS] token with add_special_tokens = True."
                )

            model_inputs["entity_position_ids"] = [
                [pos + 1 for pos in positions] for positions in model_inputs["entity_position_ids"]
            ]

        # If no entity was detected, add a dummy entity (index 0 at position 0) so that
        # downstream tensors are never empty.
        if not model_inputs["entity_ids"]:
            model_inputs["entity_ids"] = [0]
            model_inputs["entity_position_ids"] = [[0]]

        # Count the special tokens ([SEP]/[PAD]/[CLS]) at the end of the sequence to find
        # the last position a detected entity may legally occupy.
        num_special_tokens_at_end = 0
        input_ids = prepared_inputs["input_ids"]
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        for input_id in input_ids[::-1]:
            if int(input_id) not in {
                self.sep_token_id,
                self.pad_token_id,
                self.cls_token_id,
            }:
                break
            num_special_tokens_at_end += 1

        # Drop entities whose token positions fall outside the effective sequence,
        # e.g. mentions that were cut off by truncation.
        max_effective_pos = len(model_inputs["input_ids"]) - num_special_tokens_at_end
        entity_indices_to_keep = list()
        for i, position_ids in enumerate(model_inputs["entity_position_ids"]):
            if len(position_ids) > 0 and max(position_ids) < max_effective_pos:
                entity_indices_to_keep.append(i)
        model_inputs["entity_ids"] = [model_inputs["entity_ids"][i] for i in entity_indices_to_keep]
        model_inputs["entity_position_ids"] = [model_inputs["entity_position_ids"][i] for i in entity_indices_to_keep]

        if self.entity_embeddings is not None:
            model_inputs["entity_embeds"] = self.entity_embeddings[model_inputs["entity_ids"]].astype(np.float32)
        return model_inputs

    def __call__(self, text: str | list[str], **kwargs) -> BatchEncoding:
        for unsupported_arg in ["text_pair", "text_target", "text_pair_target"]:
            if unsupported_arg in kwargs:
                raise ValueError(
                    f"Argument '{unsupported_arg}' is not supported by {self.__class__.__name__}. "
                    "This tokenizer only supports single text inputs."
                )

        if isinstance(text, str):
            processed_inputs = self._preprocess_text(text, **kwargs)
            return BatchEncoding(
                processed_inputs,
                tensor_type=kwargs.get("return_tensors", None),
                prepend_batch_axis=True,
            )

        processed_inputs_list: list[dict[str, list[int]]] = [self._preprocess_text(t, **kwargs) for t in text]
        collated_inputs = {
            key: [item[key] for item in processed_inputs_list] for key in processed_inputs_list[0].keys()
        }
        if kwargs.get("padding"):
            # Pad the standard word-piece inputs with the base tokenizer.
            collated_inputs = self.pad(
                collated_inputs,
                padding=kwargs["padding"],
                max_length=kwargs.get("max_length"),
                pad_to_multiple_of=kwargs.get("pad_to_multiple_of"),
                return_attention_mask=kwargs.get("return_attention_mask"),
                verbose=kwargs.get("verbose", True),
            )

            # Pad entity_ids with the dummy index 0 up to the largest entity count in the batch.
            max_num_entities = max(len(ids) for ids in collated_inputs["entity_ids"])
            for entity_ids in collated_inputs["entity_ids"]:
                entity_ids += [0] * (max_num_entities - len(entity_ids))

            # Pad entity_position_ids along both axes.
            flattened_entity_length = [
                len(ids) for ids_list in collated_inputs["entity_position_ids"] for ids in ids_list
            ]
            max_entity_token_length = max(flattened_entity_length) if flattened_entity_length else 0
            for entity_position_ids_list in collated_inputs["entity_position_ids"]:
                # Pad each entity's token positions to the longest entity in the batch.
                for entity_position_ids in entity_position_ids_list:
                    entity_position_ids += [0] * (max_entity_token_length - len(entity_position_ids))

                # Then pad the per-example list of entities up to max_num_entities.
                entity_position_ids_list += [[0 for _ in range(max_entity_token_length)]] * (
                    max_num_entities - len(entity_position_ids_list)
                )

            # Pad precomputed entity embeddings with zero vectors up to max_num_entities.
            if "entity_embeds" in collated_inputs:
                for i in range(len(collated_inputs["entity_embeds"])):
                    collated_inputs["entity_embeds"][i] = np.pad(
                        collated_inputs["entity_embeds"][i],
                        pad_width=(
                            (
                                0,
                                max_num_entities - len(collated_inputs["entity_embeds"][i]),
                            ),
                            (0, 0),
                        ),
                        mode="constant",
                        constant_values=0,
                    )
        return BatchEncoding(collated_inputs, tensor_type=kwargs.get("return_tensors", None))
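
    # Typical usage (sketch; the checkpoint path is a placeholder and must contain the
    # entity_linker/ files, entity_vocab.tsv, and optionally entity_embeddings.npy listed
    # in vocab_files_names above):
    #   tokenizer = KPRBertTokenizer.from_pretrained("path/to/kpr-checkpoint")
    #   batch = tokenizer(["Tokyo is the capital of Japan"], padding=True, return_tensors="pt")
    #   batch["entity_ids"], batch["entity_position_ids"]  # extra inputs for the KPR model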

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
        os.makedirs(save_directory, exist_ok=True)
        saved_files = list(super().save_vocabulary(save_directory, filename_prefix))

        # Save the entity linker files under the entity_linker/ subdirectory.
        entity_linker_save_dir = str(
            Path(save_directory) / Path(self.vocab_files_names["entity_linker_data_file"]).parent
        )
        self.entity_linker.save(entity_linker_save_dir)
        for file_name in self.vocab_files_names.values():
            if file_name.startswith("entity_linker/"):
                saved_files.append(file_name)

        # Save the entity vocabulary as a TSV file.
        entity_vocab_path = str(Path(save_directory) / self.vocab_files_names["entity_vocab_file"])
        save_tsv_entity_vocab(entity_vocab_path, self.entity_to_id)
        saved_files.append(self.vocab_files_names["entity_vocab_file"])

        # Save the precomputed entity embeddings, if any were loaded.
        if self.entity_embeddings is not None:
            entity_embeddings_path = str(Path(save_directory) / self.vocab_files_names["entity_embeddings_file"])
            np.save(entity_embeddings_path, self.entity_embeddings)
            saved_files.append(self.vocab_files_names["entity_embeddings_file"])
        return tuple(saved_files)