Update app.py
Browse files
app.py
CHANGED
|
@@ -127,10 +127,34 @@ def get_pred_classification_labels(clauses, batch_size=32):
|
|
| 127 |
logits = classification_model(**model_inputs)[0]
|
| 128 |
pred_labels = logits.argmax(-1).numpy()
|
| 129 |
pred_labels = [index2label[l] for l in pred_labels]
|
| 130 |
-
clause2labels.extend([(s,
|
| 131 |
-
print(clause2labels)
|
| 132 |
return clause2labels
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
|
| 136 |
def run_pipeline(text):
|
|
@@ -143,7 +167,10 @@ def run_pipeline(text):
|
|
| 143 |
|
| 144 |
clause2labels = get_pred_classification_labels(all_clauses)
|
| 145 |
output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
# with open("pipeline_outputs.jsonl", "w") as fw:
|
| 149 |
# with open("all_text.txt", "r") as f:
|
|
@@ -165,7 +192,7 @@ color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon
|
|
| 165 |
index_colormap = {str(i) : color_panel_1[i % len(color_panel_1)] for i in np.arange(1, 100000)}
|
| 166 |
color_panel_2 = ["Gray", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Violet"]
|
| 167 |
str_attrs = sorted([str(v) for v in set(labels2attrs.values())])
|
| 168 |
-
print(str_attrs, len(str_attrs), len(color_panel_2))
|
| 169 |
assert len(str_attrs) == len(color_panel_2)
|
| 170 |
attr_colormap = {a:c for a, c in zip(str_attrs, color_panel_2)}
|
| 171 |
# attr_colormap = {
|
|
@@ -203,6 +230,8 @@ demo = gr.Interface(
|
|
| 203 |
show_legend=True,
|
| 204 |
combine_adjacent=False,
|
| 205 |
).style(color_map=attr_colormap),
|
|
|
|
|
|
|
| 206 |
]
|
| 207 |
)
|
| 208 |
|
|
|
|
| 127 |
logits = classification_model(**model_inputs)[0]
|
| 128 |
pred_labels = logits.argmax(-1).numpy()
|
| 129 |
pred_labels = [index2label[l] for l in pred_labels]
|
| 130 |
+
clause2labels.extend([(s, labels2attrs[l],) for s,l in zip(batch_examples, pred_labels)])
|
|
|
|
| 131 |
return clause2labels
|
| 132 |
|
| 133 |
+
def label_visualization(clause2labels):
|
| 134 |
+
total_clauses = len(clause2labels)
|
| 135 |
+
aspect_labels, genericity_labels, boundedness_labels = [], [], []
|
| 136 |
+
for _, labels in clause2labels:
|
| 137 |
+
labels = tuple(labels)
|
| 138 |
+
print(labels)
|
| 139 |
+
|
| 140 |
+
genericity_label = labels[0]
|
| 141 |
+
aspect_label = labels[1]
|
| 142 |
+
boundedness_label = labels[2]
|
| 143 |
+
aspect_labels.append(aspect_label)
|
| 144 |
+
genericity_labels.append(genericity_label)
|
| 145 |
+
boundedness_labels.append(boundedness_label)
|
| 146 |
+
aspect_dict = {"Dynamic": aspect_labels.count("dynamic"), "Stative": aspect_labels.count("stative")}
|
| 147 |
+
genericity_dict = {"Generic": genericity_labels.count("generic"), "Specific": genericity_labels.count("specific")}
|
| 148 |
+
boundedness_dict = {"Static": boundedness_labels.count("static"), "Episodic": boundedness_labels.count("episodic"), "Habitual": aspect_labels.count("habitual")}
|
| 149 |
+
print(aspect_dict, genericity_dict, boundedness_dict)
|
| 150 |
+
fig, axs = plt.subplots(1, 3, figsize=(10, 6,))
|
| 151 |
+
axs[0].pie([float(v / total_clauses) for v in aspect_dict.values()], colors = sns.color_palette('pastel')[0:3], labels=aspect_dict.keys(), autopct='%.0f%%', normalize=True )
|
| 152 |
+
axs[0].set_title("Aspect")
|
| 153 |
+
axs[1].pie([float(v / total_clauses) for v in genericity_dict.values()], colors = sns.color_palette('pastel')[3:6], labels=genericity_dict.keys(), autopct='%.0f%%', normalize=True)
|
| 154 |
+
axs[1].set_title("Genericity")
|
| 155 |
+
axs[2].pie([float(v / total_clauses) for v in boundedness_dict.values()], colors = sns.color_palette('pastel')[8:10], labels=boundedness_dict.keys(), autopct='%.0f%%', normalize=True)
|
| 156 |
+
axs[2].set_title("Boundedness")
|
| 157 |
+
return fig
|
| 158 |
|
| 159 |
|
| 160 |
def run_pipeline(text):
|
|
|
|
| 167 |
|
| 168 |
clause2labels = get_pred_classification_labels(all_clauses)
|
| 169 |
output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
|
| 170 |
+
figure = label_visualization(clause2labels)
|
| 171 |
+
clause2labels = [(k,str(v),) for k, v in clause2labels]
|
| 172 |
+
return output_clauses, clause2labels, figure
|
| 173 |
+
|
| 174 |
|
| 175 |
# with open("pipeline_outputs.jsonl", "w") as fw:
|
| 176 |
# with open("all_text.txt", "r") as f:
|
|
|
|
| 192 |
index_colormap = {str(i) : color_panel_1[i % len(color_panel_1)] for i in np.arange(1, 100000)}
|
| 193 |
color_panel_2 = ["Gray", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Violet"]
|
| 194 |
str_attrs = sorted([str(v) for v in set(labels2attrs.values())])
|
| 195 |
+
# print(str_attrs, len(str_attrs), len(color_panel_2))
|
| 196 |
assert len(str_attrs) == len(color_panel_2)
|
| 197 |
attr_colormap = {a:c for a, c in zip(str_attrs, color_panel_2)}
|
| 198 |
# attr_colormap = {
|
|
|
|
| 230 |
show_legend=True,
|
| 231 |
combine_adjacent=False,
|
| 232 |
).style(color_map=attr_colormap),
|
| 233 |
+
|
| 234 |
+
gr.Plot(),
|
| 235 |
]
|
| 236 |
)
|
| 237 |
|