Update app.py
Browse files
app.py
CHANGED
|
@@ -10,6 +10,7 @@ from numpy import savetxt, loadtxt
|
|
| 10 |
import numpy as np
|
| 11 |
import json
|
| 12 |
from copy import deepcopy
|
|
|
|
| 13 |
import re
|
| 14 |
from tqdm import tqdm
|
| 15 |
import gradio as gr
|
|
@@ -17,8 +18,8 @@ import gradio as gr
|
|
| 17 |
os.system("python -m spacy download en_core_web_sm")
|
| 18 |
nlp = spacy.load("en_core_web_sm")
|
| 19 |
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
|
| 20 |
-
clause_model = AutoModelForTokenClassification.from_pretrained("
|
| 21 |
-
classification_model = RobertaForSequenceClassification.from_pretrained("
|
| 22 |
|
| 23 |
|
| 24 |
labels2attrs = {
|
|
@@ -103,41 +104,43 @@ def seg_clause(text):
|
|
| 103 |
segmented_clauses.append(deepcopy(current_clause))
|
| 104 |
return [" ".join(clause) for clause in segmented_clauses if clause is not None]
|
| 105 |
|
| 106 |
-
def pretty_print_segmented_clause(segmented_clauses):
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
|
| 120 |
|
| 121 |
def get_pred_classification_labels(clauses, batch_size=32):
|
| 122 |
clause2labels = []
|
| 123 |
-
for i in range(0, len(clauses)
|
| 124 |
batch_examples = clauses[i : i + batch_size]
|
| 125 |
model_inputs = tokenizer(batch_examples, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
|
| 126 |
logits = classification_model(**model_inputs)[0]
|
| 127 |
pred_labels = logits.argmax(-1).numpy()
|
| 128 |
pred_labels = [index2label[l] for l in pred_labels]
|
| 129 |
-
|
| 130 |
-
clause2labels
|
| 131 |
return clause2labels
|
| 132 |
|
| 133 |
|
| 134 |
|
| 135 |
def run_pipeline(text):
|
| 136 |
snippets = auto_split(text)
|
|
|
|
| 137 |
all_clauses = []
|
| 138 |
for s in snippets:
|
| 139 |
segmented_clauses = seg_clause(s)
|
| 140 |
all_clauses.extend(segmented_clauses)
|
|
|
|
| 141 |
clause2labels = get_pred_classification_labels(all_clauses)
|
| 142 |
output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
|
| 143 |
return output_clauses, clause2labels
|
|
@@ -160,7 +163,7 @@ def run_pipeline(text):
|
|
| 160 |
|
| 161 |
color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon", "pink", "cyan", "gold", "aqua", "violet"]
|
| 162 |
index_colormap = {str(i) : color_panel_1[i % len(color_panel_1)] for i in np.arange(1, 100000)}
|
| 163 |
-
color_panel_2 = ["Violet", "
|
| 164 |
str_attrs = [str(v) for v in set(labels2attrs.values())]
|
| 165 |
print(str_attrs, len(str_attrs), len(color_panel_2))
|
| 166 |
assert len(str_attrs) == len(color_panel_2)
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
import json
|
| 12 |
from copy import deepcopy
|
| 13 |
+
# from sty import fg, bg, ef, rs, RgbBg, Style
|
| 14 |
import re
|
| 15 |
from tqdm import tqdm
|
| 16 |
import gradio as gr
|
|
|
|
| 18 |
# Ensure the spaCy English model is available, then load the NLP pipeline
# and the two fine-tuned RoBERTa models used by the app.
# NOTE(review): spacy.cli.download replaces the previous os.system shell call,
# and the download now only runs when the model is actually missing.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
# Local fine-tuned checkpoints — paths are machine-specific; TODO: make configurable.
clause_model = AutoModelForTokenClassification.from_pretrained("C:\\Users\\pixin\\Desktop\\Reddit ML\\Trained Models\\clause_model_512", num_labels=3)
classification_model = RobertaForSequenceClassification.from_pretrained("C:\\Users\\pixin\\Desktop\\Reddit ML\\Trained Models\\classfication_model", num_labels=18)
|
| 23 |
|
| 24 |
|
| 25 |
labels2attrs = {
|
|
|
|
| 104 |
segmented_clauses.append(deepcopy(current_clause))
|
| 105 |
return [" ".join(clause) for clause in segmented_clauses if clause is not None]
|
| 106 |
|
| 107 |
+
# def pretty_print_segmented_clause(segmented_clauses):
|
| 108 |
+
# np.random.seed(42)
|
| 109 |
+
# bg.orange = Style(RgbBg(255, 150, 50))
|
| 110 |
+
# bg.purple = Style(RgbBg(180, 130, 225))
|
| 111 |
+
# colors = [bg.red, bg.orange, bg.yellow, bg.green, bg.blue, bg.purple]
|
| 112 |
+
# prev_color = 0
|
| 113 |
+
# to_print = []
|
| 114 |
+
# for cl in segmented_clauses:
|
| 115 |
+
# color_choice = np.random.choice(np.delete(np.arange(len(colors)), prev_color))
|
| 116 |
+
# prev_color = color_choice
|
| 117 |
+
# colored_cl = colors[color_choice] + cl + bg.rs
|
| 118 |
+
# to_print.append(colored_cl)
|
| 119 |
+
# print(*to_print, sep=" ")
|
| 120 |
|
| 121 |
|
| 122 |
def get_pred_classification_labels(clauses, batch_size=32):
    """Classify each clause with the fine-tuned RoBERTa sequence classifier.

    Args:
        clauses: list of clause strings to classify.
        batch_size: number of clauses per forward pass.

    Returns:
        List of ``(clause, attribute_string)`` pairs in input order, where
        the attribute string is ``str(labels2attrs[predicted_label])``.
    """
    clause2labels = []
    # Process in batches so arbitrarily many clauses fit in memory.
    for start in range(0, len(clauses), batch_size):
        batch_examples = clauses[start : start + batch_size]
        model_inputs = tokenizer(
            batch_examples,
            padding='max_length',
            max_length=128,
            truncation=True,
            return_tensors='pt',
        )
        logits = classification_model(**model_inputs)[0]
        pred_indices = logits.argmax(-1).numpy()
        pred_labels = [index2label[i] for i in pred_indices]
        # Removed stray debug print of the accumulator that ran every batch.
        clause2labels.extend(
            (clause, str(labels2attrs[label]))
            for clause, label in zip(batch_examples, pred_labels)
        )
    return clause2labels
|
| 133 |
|
| 134 |
|
| 135 |
|
| 136 |
def run_pipeline(text):
    """Run the full pipeline: split text, segment clauses, classify them.

    Args:
        text: raw input text.

    Returns:
        Tuple ``(output_clauses, clause2labels)`` where ``output_clauses``
        is a list of ``(clause, 1-based index string)`` pairs and
        ``clause2labels`` is the output of get_pred_classification_labels.
    """
    # Removed stray debug print of the snippet list.
    snippets = auto_split(text)
    all_clauses = []
    for snippet in snippets:
        all_clauses.extend(seg_clause(snippet))
    clause2labels = get_pred_classification_labels(all_clauses)
    output_clauses = [(clause, str(idx + 1)) for idx, clause in enumerate(all_clauses)]
    return output_clauses, clause2labels
|
|
|
|
| 163 |
|
| 164 |
# Colors cycled through to number clauses in the UI.
color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon", "pink", "cyan", "gold", "aqua", "violet"]
# Map a clause index (as a string key) to a color, cycling through the panel.
# NOTE(review): plain range() replaces np.arange — str(i) keys are identical,
# without boxing 100k numpy ints.
index_colormap = {str(i): color_panel_1[i % len(color_panel_1)] for i in range(1, 100000)}
# One distinct color per attribute class; must stay in sync with labels2attrs.
color_panel_2 = ["Violet", "Gray", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "DodgerBlue"]
str_attrs = [str(v) for v in set(labels2attrs.values())]
# Fail loudly at import time if the palette and attribute set drift apart;
# an explicit raise (unlike assert) survives `python -O`.
if len(str_attrs) != len(color_panel_2):
    raise ValueError(
        f"color_panel_2 has {len(color_panel_2)} colors but there are {len(str_attrs)} attribute classes"
    )
|