"""Gradio demo: segments English text into clauses and classifies each clause's
situation-entity attributes (genericity, aspect, boundedness)."""
import os
from copy import deepcopy

import numpy as np
import torch
import spacy
import tokenizations
import gradio as gr
from matplotlib import pyplot as plt
import seaborn as sns
from transformers import RobertaTokenizer, AutoModelForTokenClassification, RobertaForSequenceClassification

os.system("python -m spacy download en_core_web_sm")  # ensure the spaCy English model is available (handy on hosted demos)
nlp = spacy.load("en_core_web_sm")
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
clause_model = AutoModelForTokenClassification.from_pretrained("./clause_model_512", num_labels=3)
classification_model = RobertaForSequenceClassification.from_pretrained("./classfication_model", num_labels=18)

# Each classifier label maps to a (genericity, aspect, boundedness) triple;
# "NA" marks clause types with no situation-entity attributes.
labels2attrs = {
    "##BOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##BOUNDED EVENT (GENERIC)": ("generic", "dynamic", "episodic"),
    "##UNBOUNDED EVENT (SPECIFIC)": ("specific", "dynamic", "static"),
    "##UNBOUNDED EVENT (GENERIC)": ("generic", "dynamic", "static"),
    "##BASIC STATE": ("specific", "stative", "static"),
    "##COERCED STATE (SPECIFIC)": ("specific", "dynamic", "static"),
    "##COERCED STATE (GENERIC)": ("generic", "dynamic", "static"),
    "##PERFECT COERCED STATE (SPECIFIC)": ("specific", "dynamic", "episodic"),
    "##PERFECT COERCED STATE (GENERIC)": ("generic", "dynamic", "episodic"),
    "##GENERIC SENTENCE (DYNAMIC)": ("generic", "dynamic", "habitual"),
    "##GENERIC SENTENCE (STATIC)": ("generic", "stative", "static"),
    "##GENERIC SENTENCE (HABITUAL)": ("generic", "stative", "habitual"),
    "##GENERALIZING SENTENCE (DYNAMIC)": ("specific", "dynamic", "habitual"),
    "##GENERALIZING SENTENCE (STATIVE)": ("specific", "stative", "habitual"),
    "##QUESTION": ("NA", "NA", "NA"),
    "##IMPERATIVE": ("NA", "NA", "NA"),
    "##NONSENSE": ("NA", "NA", "NA"),
    "##OTHER": ("NA", "NA", "NA"),
}

# index <-> label lookups for the 18-way classification head
label2index = {l: i for i, l in enumerate(labels2attrs)}
index2label = {i: l for l, i in label2index.items()}
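
# e.g. label2index["##BASIC STATE"] == 4 and index2label[4] == "##BASIC STATE"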

def auto_split(text):
    # Split the input into snippets of at most ~200 words, breaking only at
    # sentence boundaries so each snippet stays within the clause model's
    # 512-token window.
    doc = nlp(text)
    current_len = 0
    snippets = []
    current_snippet = ""
    for sent in doc.sents:
        sent_text = sent.text
        words = sent_text.split()
        if current_len + len(words) > 200:
            snippets.append(current_snippet)
            current_snippet = sent_text
            current_len = len(words)
        else:
            current_snippet = (current_snippet + " " + sent_text).strip()
            current_len += len(words)
    snippets.append(current_snippet)
    return snippets
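
# Example (illustrative; exact splits depend on spaCy's sentence segmentation):
#   >>> auto_split("Short text. Another sentence.")
#   ['Short text. Another sentence.']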

def majority_vote(array):
    # Return the most frequent element; np.unique yields sorted values with counts.
    unique, counts = np.unique(np.array(array), return_counts=True)
    return unique[np.argmax(counts)]
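
# Example: majority_vote([1, 2, 2, 0]) -> 2. Ties resolve to the smallest value,
# since np.unique returns sorted values and np.argmax keeps the first maximum.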

def get_pred_clause_labels(text, words):
    # Tag every whitespace-separated word with a clause label by running the
    # token classifier over RoBERTa subtokens, then mapping tags back to words.
    model_inputs = tokenizer(text, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
    roberta_tokens = tokenizer.convert_ids_to_tokens(model_inputs['input_ids'][0])
    a2b, b2a = tokenizations.get_alignments(words, roberta_tokens)
    with torch.no_grad():  # inference only; skip gradient tracking
        logits = clause_model(**model_inputs)[0]
    tagging = logits.argmax(-1)[0].numpy()
    pred_labels = []
    for alignment in a2b:
        if len(alignment) == 0:  # word has no aligned subtoken (e.g. truncated): default to 1
            pred_labels.append(1)
        elif len(alignment) == 1:
            pred_labels.append(tagging[alignment[0]])
        else:  # word spans several subtokens: take the majority tag
            pred_labels.append(majority_vote([tagging[a] for a in alignment]))
    assert len(pred_labels) == len(words)
    return pred_labels
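
# Note on the tag set: inferred from the segmentation logic below,
# 0/1 appear to mark words inside a clause and 2 a clause-final word
# (an assumption; the checkpoint itself does not document the scheme).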

def seg_clause(text):
    # Group words into clauses using the predicted per-word tags:
    # a tag of 2 closes the current clause, 0/1 keep it open.
    words = text.strip().split()
    labels = get_pred_clause_labels(text, words)
    segmented_clauses = []
    prev_label = 2
    current_clause = None
    for cur_token, cur_label in zip(words, labels):
        if prev_label == 2:  # the previous word closed a clause; start a new one
            current_clause = []
        if current_clause is not None:
            current_clause.append(cur_token)

        if cur_label == 2:
            if prev_label in [0, 1]:
                segmented_clauses.append(deepcopy(current_clause))
            current_clause = None
        prev_label = cur_label

    # flush a clause left open at the end of the text
    if current_clause is not None and len(current_clause) != 0:
        segmented_clauses.append(deepcopy(current_clause))
    return [" ".join(clause) for clause in segmented_clauses if clause is not None]
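
# Illustrative only (real segmentations depend on the trained clause model):
#   >>> seg_clause("I went home because I was tired")
#   ['I went home', 'because I was tired']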

def get_pred_classification_labels(clauses, batch_size=32):
    # Classify each clause into one of the 18 situation-entity labels, in batches,
    # and attach the corresponding (genericity, aspect, boundedness) triple.
    clause2labels = []
    for i in range(0, len(clauses), batch_size):
        batch_examples = clauses[i : i + batch_size]
        model_inputs = tokenizer(batch_examples, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
        with torch.no_grad():  # inference only
            logits = classification_model(**model_inputs)[0]
        pred_labels = logits.argmax(-1).numpy()
        pred_labels = [index2label[int(l)] for l in pred_labels]
        clause2labels.extend([(s, labels2attrs[l]) for s, l in zip(batch_examples, pred_labels)])
    return clause2labels
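
# Illustrative output shape (the predicted label depends on the classifier):
#   >>> get_pred_classification_labels(["She runs every day"])
#   [('She runs every day', ('specific', 'dynamic', 'habitual'))]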

def label_visualization(clause2labels):
    # Aggregate per-clause attributes and render three pie charts
    # (aspect, genericity, boundedness) in one matplotlib figure.
    total_clauses = len(clause2labels)
    aspect_labels, genericity_labels, boundedness_labels = [], [], []
    for _, labels in clause2labels:
        genericity_label, aspect_label, boundedness_label = labels
        aspect_labels.append(aspect_label)
        genericity_labels.append(genericity_label)
        boundedness_labels.append(boundedness_label)
    aspect_dict = {"Dynamic": aspect_labels.count("dynamic"), "Stative": aspect_labels.count("stative"), "NA": aspect_labels.count("NA")}
    genericity_dict = {"Generic": genericity_labels.count("generic"), "Specific": genericity_labels.count("specific"), "NA": genericity_labels.count("NA")}
    boundedness_dict = {"Static": boundedness_labels.count("static"), "Episodic": boundedness_labels.count("episodic"), "Habitual": boundedness_labels.count("habitual"), "NA": boundedness_labels.count("NA")}
    fig, axs = plt.subplots(1, 3, figsize=(10, 6))
    fig.tight_layout(pad=5.0)
    # drop zero-count categories so empty slices do not clutter the charts
    dict_aspect = {k: v / total_clauses for k, v in aspect_dict.items() if v != 0}
    dict_genericity = {k: v / total_clauses for k, v in genericity_dict.items() if v != 0}
    dict_boundedness = {k: v / total_clauses for k, v in boundedness_dict.items() if v != 0}
    axs[0].pie(list(dict_aspect.values()), colors=sns.color_palette('pastel')[0:len(dict_aspect)],
               labels=dict_aspect.keys(), autopct='%.0f%%', normalize=True)
    axs[0].set_title("Aspect")
    axs[1].pie(list(dict_genericity.values()), colors=sns.color_palette('pastel')[3:3 + len(dict_genericity)],
               labels=dict_genericity.keys(), autopct='%.0f%%', normalize=True)
    axs[1].set_title("Genericity")
    axs[2].pie(list(dict_boundedness.values()), colors=sns.color_palette('pastel')[6:6 + len(dict_boundedness)],
               labels=dict_boundedness.keys(), autopct='%.0f%%', normalize=True)
    axs[2].set_title("Boundedness")
    return fig

def run_pipeline(text):
    # Full pipeline: split text -> segment into clauses -> classify -> visualize.
    snippets = auto_split(text)
    all_clauses = []
    for s in snippets:
        all_clauses.extend(seg_clause(s))

    clause2labels = get_pred_classification_labels(all_clauses)
    # number the clauses so the HighlightedText widget colors them distinctly
    output_clauses = [(c, str(i + 1)) for i, c in enumerate(all_clauses)]
    figure = label_visualization(clause2labels)
    clause2labels = [(k, str(v)) for k, v in clause2labels]
    return output_clauses, clause2labels, figure
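
# Minimal sketch of using the pipeline outside Gradio (hypothetical input text):
#   clauses, labeled_clauses, fig = run_pipeline("I went home. I was tired.")
#   fig.savefig("attribute_distribution.png")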

# distinct colors for numbered clauses in the first HighlightedText widget;
# cycle through the panel if there are more clauses than colors
color_panel_1 = ["red", "green", "yellow", "DodgerBlue", "orange", "DarkSalmon", "pink", "cyan", "gold", "aqua", "violet"]
index_colormap = {str(i): color_panel_1[i % len(color_panel_1)] for i in range(1, 100000)}

# one fixed color per distinct (genericity, aspect, boundedness) triple
color_panel_2 = ["Gray", "DodgerBlue", "Wheat", "OliveDrab", "DarkKhaki", "DarkSalmon", "Orange", "Gold", "Aqua", "Tomato", "Violet"]
str_attrs = sorted(str(v) for v in set(labels2attrs.values()))
assert len(str_attrs) == len(color_panel_2)  # 11 distinct triples, 11 colors
attr_colormap = {a: c for a, c in zip(str_attrs, color_panel_2)}
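
# e.g. str(("specific", "dynamic", "habitual")) maps to one fixed color here,
# matching the stringified attribute triples emitted by run_pipeline above.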

demo = gr.Interface(
    fn=run_pipeline,
    inputs=["text"],
    outputs=[
        gr.HighlightedText(
            label="Clause Segmentation",
            show_label=True,
            combine_adjacent=False,
        ).style(color_map=index_colormap),
        gr.HighlightedText(
            label="Attribute Classification",
            show_label=True,
            show_legend=True,
            combine_adjacent=False,
        ).style(color_map=attr_colormap),
        gr.Plot(),
    ],
)

demo.launch()