Frederick committed on
Commit
053f51f
·
1 Parent(s): 6d96856

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -10,6 +10,7 @@ from numpy import savetxt, loadtxt
10
  import numpy as np
11
  import json
12
  from copy import deepcopy
 
13
  import re
14
  from tqdm import tqdm
15
  import gradio as gr
@@ -17,8 +18,8 @@ import gradio as gr
17
  os.system("python -m spacy download en_core_web_sm")
18
  nlp = spacy.load("en_core_web_sm")
19
  tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
20
- clause_model = AutoModelForTokenClassification.from_pretrained("./clause_model_512", num_labels=3)
21
- classification_model = RobertaForSequenceClassification.from_pretrained("./classfication_model", num_labels=18)
22
 
23
 
24
  labels2attrs = {
@@ -103,41 +104,43 @@ def seg_clause(text):
103
  segmented_clauses.append(deepcopy(current_clause))
104
  return [" ".join(clause) for clause in segmented_clauses if clause is not None]
105
 
106
- def pretty_print_segmented_clause(segmented_clauses):
107
- np.random.seed(42)
108
- bg.orange = Style(RgbBg(255, 150, 50))
109
- bg.purple = Style(RgbBg(180, 130, 225))
110
- colors = [bg.red, bg.orange, bg.yellow, bg.green, bg.blue, bg.purple]
111
- prev_color = 0
112
- to_print = []
113
- for cl in segmented_clauses:
114
- color_choice = np.random.choice(np.delete(np.arange(len(colors)), prev_color))
115
- prev_color = color_choice
116
- colored_cl = colors[color_choice] + cl + bg.rs
117
- to_print.append(colored_cl)
118
- print(*to_print, sep=" ")
119
 
120
 
121
  def get_pred_classification_labels(clauses, batch_size=32):
122
  clause2labels = []
123
- for i in range(0, len(clauses) + 1, batch_size):
124
  batch_examples = clauses[i : i + batch_size]
125
  model_inputs = tokenizer(batch_examples, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
126
  logits = classification_model(**model_inputs)[0]
127
  pred_labels = logits.argmax(-1).numpy()
128
  pred_labels = [index2label[l] for l in pred_labels]
129
-
130
- clause2labels.extend([(s, str(l),) for s,l in zip(batch_examples, pred_labels)])
131
  return clause2labels
132
 
133
 
134
 
135
  def run_pipeline(text):
136
  snippets = auto_split(text)
 
137
  all_clauses = []
138
  for s in snippets:
139
  segmented_clauses = seg_clause(s)
140
  all_clauses.extend(segmented_clauses)
 
141
  clause2labels = get_pred_classification_labels(all_clauses)
142
  output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
143
  return output_clauses, clause2labels
@@ -160,7 +163,7 @@ def run_pipeline(text):
160
 
161
  color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon", "pink", "cyan", "gold", "aqua", "violet"]
162
  index_colormap = {str(i) : color_panel_1[i % len(color_panel_1)] for i in np.arange(1, 100000)}
163
- color_panel_2 = ["Violet", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Gray"]
164
  str_attrs = [str(v) for v in set(labels2attrs.values())]
165
  print(str_attrs, len(str_attrs), len(color_panel_2))
166
  assert len(str_attrs) == len(color_panel_2)
 
10
  import numpy as np
11
  import json
12
  from copy import deepcopy
13
+ # from sty import fg, bg, ef, rs, RgbBg, Style
14
  import re
15
  from tqdm import tqdm
16
  import gradio as gr
 
18
  os.system("python -m spacy download en_core_web_sm")
19
  nlp = spacy.load("en_core_web_sm")
20
  tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
21
+ clause_model = AutoModelForTokenClassification.from_pretrained("C:\\Users\\pixin\\Desktop\\Reddit ML\\Trained Models\\clause_model_512", num_labels=3)
22
+ classification_model = RobertaForSequenceClassification.from_pretrained("C:\\Users\\pixin\\Desktop\\Reddit ML\\Trained Models\\classfication_model", num_labels=18)
23
 
24
 
25
  labels2attrs = {
 
104
  segmented_clauses.append(deepcopy(current_clause))
105
  return [" ".join(clause) for clause in segmented_clauses if clause is not None]
106
 
107
+ # def pretty_print_segmented_clause(segmented_clauses):
108
+ # np.random.seed(42)
109
+ # bg.orange = Style(RgbBg(255, 150, 50))
110
+ # bg.purple = Style(RgbBg(180, 130, 225))
111
+ # colors = [bg.red, bg.orange, bg.yellow, bg.green, bg.blue, bg.purple]
112
+ # prev_color = 0
113
+ # to_print = []
114
+ # for cl in segmented_clauses:
115
+ # color_choice = np.random.choice(np.delete(np.arange(len(colors)), prev_color))
116
+ # prev_color = color_choice
117
+ # colored_cl = colors[color_choice] + cl + bg.rs
118
+ # to_print.append(colored_cl)
119
+ # print(*to_print, sep=" ")
120
 
121
 
122
  def get_pred_classification_labels(clauses, batch_size=32):
123
  clause2labels = []
124
+ for i in range(0, len(clauses), batch_size):
125
  batch_examples = clauses[i : i + batch_size]
126
  model_inputs = tokenizer(batch_examples, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
127
  logits = classification_model(**model_inputs)[0]
128
  pred_labels = logits.argmax(-1).numpy()
129
  pred_labels = [index2label[l] for l in pred_labels]
130
+ clause2labels.extend([(s, str(labels2attrs[l]),) for s,l in zip(batch_examples, pred_labels)])
131
+ print(clause2labels)
132
  return clause2labels
133
 
134
 
135
 
136
  def run_pipeline(text):
137
  snippets = auto_split(text)
138
+ print(snippets)
139
  all_clauses = []
140
  for s in snippets:
141
  segmented_clauses = seg_clause(s)
142
  all_clauses.extend(segmented_clauses)
143
+
144
  clause2labels = get_pred_classification_labels(all_clauses)
145
  output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
146
  return output_clauses, clause2labels
 
163
 
164
  color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon", "pink", "cyan", "gold", "aqua", "violet"]
165
  index_colormap = {str(i) : color_panel_1[i % len(color_panel_1)] for i in np.arange(1, 100000)}
166
+ color_panel_2 = ["Violet", "Gray", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "DodgerBlue"]
167
  str_attrs = [str(v) for v in set(labels2attrs.values())]
168
  print(str_attrs, len(str_attrs), len(color_panel_2))
169
  assert len(str_attrs) == len(color_panel_2)