TymaaHammouda commited on
Commit
fac752d
·
verified ·
1 Parent(s): 309ea92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -40
app.py CHANGED
@@ -1,10 +1,18 @@
1
  from fastapi import FastAPI
2
- from huggingface_hub import hf_hub_download
 
3
  import os
4
  from pydantic import BaseModel
5
  from fastapi.responses import JSONResponse
6
  from transformers import AutoTokenizer, AutoModel
7
  import json
 
 
 
 
 
 
 
8
 
9
  print("Version ---- 2")
10
 
@@ -21,30 +29,167 @@ from fastapi.responses import JSONResponse
21
  print("Version ---- 2")
22
  app = FastAPI()
23
 
24
- def download_file_from_hf(repo_id, filename):
25
- target_dir = os.path.expanduser("~/.sinatools/Wj27012000.tar")
26
- os.makedirs(target_dir, exist_ok=True)
27
-
28
- file_path = hf_hub_download(
29
- repo_id=repo_id,
30
- filename=filename,
31
- local_dir=target_dir,
32
- local_dir_use_symlinks=False
33
- )
34
 
35
- return file_path
36
-
37
- download_file_from_hf("SinaLab/Nested-v1","args.json")
38
- download_file_from_hf("SinaLab/Nested-v1","tag_vocab.pkl")
39
-
40
- snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- BASE_DIR = os.path.expanduser("~/.sinatools")
44
 
45
  # Paths expected by sinatools
46
  RELATION_MODEL_DIR = os.path.join(BASE_DIR, "relation_model")
47
- # NER_DIR = os.path.join(BASE_DIR, "Wj27012000.tar")
48
 
49
  os.makedirs(BASE_DIR, exist_ok=True)
50
 
@@ -58,27 +203,6 @@ if not os.path.exists(RELATION_MODEL_DIR) or not os.listdir(RELATION_MODEL_DIR):
58
  local_dir_use_symlinks=False
59
  )
60
 
61
- # -------------------------
62
- # 2. Download NER resources
63
- # -------------------------
64
- # if not os.path.exists(NER_DIR):
65
- # os.makedirs(NER_DIR, exist_ok=True)
66
-
67
- # nested_repo_path = snapshot_download(
68
- # repo_id="SinaLab/Nested"
69
- # )
70
-
71
- # # Copy tag_vocab.pkl to expected location
72
- # src_vocab = os.path.join(nested_repo_path, "Nested", "utils", "tag_vocab.pkl")
73
- # dst_vocab = os.path.join(NER_DIR, "tag_vocab.pkl")
74
-
75
- # if os.path.exists(src_vocab):
76
- # shutil.copy(src_vocab, dst_vocab)
77
-
78
- # Optional debug
79
- print("sinatools dir:", os.listdir(BASE_DIR))
80
- # print("NER dir:", os.listdir(NER_DIR))
81
-
82
 
83
  from sinatools.relations.relation_extractor import relation_extraction
84
  from sinatools.relations.event_relation_extractor import event_argument_relation_extraction
 
1
  from fastapi import FastAPI
2
+ from huggingface_hub import hf_hub_download, snapshot_download
3
+ from Nested.nn.BertSeqTagger import BertSeqTagger
4
  import os
5
  from pydantic import BaseModel
6
  from fastapi.responses import JSONResponse
7
  from transformers import AutoTokenizer, AutoModel
8
  import json
9
+ from IBO_to_XML import IBO_to_XML
10
+ from XML_to_HTML import NER_XML_to_HTML
11
+ from NER_Distiller import distill_entities
12
+ from collections import namedtuple
13
+ from Nested.utils.helpers import load_checkpoint
14
+ from Nested.utils.data import get_dataloaders, text2segments
15
+ import pickle
16
 
17
  print("Version ---- 2")
18
 
 
29
  print("Version ---- 2")
30
  app = FastAPI()
31
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ app = FastAPI()
 
 
 
 
 
34
 
35
+ pretrained_path = "aubmindlab/bert-base-arabertv2" # must match training
36
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
37
+ encoder = AutoModel.from_pretrained(pretrained_path).eval()
38
+
39
+
40
+ checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
41
+
42
+ args_path = hf_hub_download(
43
+ repo_id="SinaLab/Nested",
44
+ filename="args.json"
45
+ )
46
+
47
+ with open(args_path, 'r') as f:
48
+ args_data = json.load(f)
49
+
50
+ # Load model
51
+ with open("Nested/utils/tag_vocab.pkl", "rb") as f:
52
+ label_vocab = pickle.load(f)
53
+
54
+ label_vocab = label_vocab[0] # the list loaded from pickle
55
+ id2label = {i: s for i, s in enumerate(label_vocab.itos)}
56
+
57
+ def split_text_into_groups_of_Ns(sentence, max_words_per_sentence):
58
+ # Split the text into words
59
+ words = sentence.split()
60
+
61
+ # Initialize variables
62
+ groups = []
63
+ current_group = ""
64
+ group_size = 0
65
+
66
+ # Iterate through the words
67
+ for word in words:
68
+ if group_size < max_words_per_sentence - 1:
69
+ if len(current_group) == 0:
70
+ current_group = word
71
+ else:
72
+ current_group += " " + word
73
+ group_size += 1
74
+ else:
75
+ current_group += " " + word
76
+ groups.append(current_group)
77
+ current_group = ""
78
+ group_size = 0
79
+
80
+ # Add the last group if it contains less than n words
81
+ if current_group:
82
+ groups.append(current_group)
83
+
84
+ return groups
85
+
86
+
87
+
88
+ def remove_empty_values(sentences):
89
+ return [value for value in sentences if value != '']
90
+
91
+
92
+ def sentence_tokenizer(text, dot=True, new_line=True, question_mark=True, exclamation_mark=True):
93
+ separators = []
94
+ split_text = [text]
95
+ if new_line==True:
96
+ separators.append('\n')
97
+ if dot==True:
98
+ separators.append('.')
99
+ if question_mark==True:
100
+ separators.append('?')
101
+ separators.append('؟')
102
+ if exclamation_mark==True:
103
+ separators.append('!')
104
+
105
+ for sep in separators:
106
+ new_split_text = []
107
+ for part in split_text:
108
+ tokens = part.split(sep)
109
+ tokens_with_separator = [token + sep for token in tokens[:-1]]
110
+ tokens_with_separator.append(tokens[-1].strip())
111
+ new_split_text.extend(tokens_with_separator)
112
+ split_text = new_split_text
113
+
114
+ split_text = remove_empty_values(split_text)
115
+ return split_text
116
+
117
+ def jsons_to_list_of_lists(json_list):
118
+ return [[d['token'], d['tags']] for d in json_list]
119
+
120
+ tagger, tag_vocab, train_config = load_checkpoint(checkpoint_path)
121
+
122
+ def extract(sentence):
123
+ dataset, token_vocab = text2segments(sentence)
124
+
125
+ vocabs = namedtuple("Vocab", ["tags", "tokens"])
126
+ vocab = vocabs(tokens=token_vocab, tags=tag_vocab)
127
+
128
+ dataloader = get_dataloaders(
129
+ (dataset,),
130
+ vocab,
131
+ args_data,
132
+ batch_size=32,
133
+ shuffle=(False,),
134
+ )[0]
135
+
136
+ segments = tagger.infer(dataloader)
137
+
138
+ lists = []
139
+
140
+ for segment in segments:
141
+ for token in segment:
142
+ item = {}
143
+ item["token"] = token.text
144
+
145
+ list_of_tags = [t["tag"] for t in token.pred_tag]
146
+ list_of_tags = [i for i in list_of_tags if i not in ("O", " ", "")]
147
+
148
+ if not list_of_tags:
149
+ item["tags"] = "O"
150
+ else:
151
+ item["tags"] = " ".join(list_of_tags)
152
+ lists.append(item)
153
+ return lists
154
+
155
+
156
+ def NER(sentence, mode):
157
+ output_list = []
158
+ xml = ""
159
+ if mode.strip() == "1":
160
+ output_list = jsons_to_list_of_lists(extract(sentence))
161
+ return output_list
162
+ elif mode.strip() == "2":
163
+ if output_list != []:
164
+ xml = IBO_to_XML(output_list)
165
+ return xml
166
+ else:
167
+ output_list = jsons_to_list_of_lists(extract(sentence))
168
+ xml = IBO_to_XML(output_list)
169
+ return xml
170
+
171
+ elif mode.strip() == "3":
172
+ if xml != "":
173
+ html = NER_XML_to_HTML(xml)
174
+ return html
175
+ else:
176
+ output_list = jsons_to_list_of_lists(extract(sentence))
177
+ xml = IBO_to_XML(output_list)
178
+ html = NER_XML_to_HTML(xml)
179
+ return html
180
+
181
+ elif mode.strip() == "4": # json short
182
+ if output_list != []:
183
+ json_short = distill_entities(output_list)
184
+ return json_short
185
+ else:
186
+ output_list = jsons_to_list_of_lists(extract(sentence))
187
+ json_short = distill_entities(output_list)
188
+ return json_short
189
 
 
190
 
191
  # Paths expected by sinatools
192
  RELATION_MODEL_DIR = os.path.join(BASE_DIR, "relation_model")
 
193
 
194
  os.makedirs(BASE_DIR, exist_ok=True)
195
 
 
203
  local_dir_use_symlinks=False
204
  )
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  from sinatools.relations.relation_extractor import relation_extraction
208
  from sinatools.relations.event_relation_extractor import event_argument_relation_extraction