Spaces:
Runtime error
Runtime error
import functools
import os

import lightning.pytorch as pl
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from annotated_text import annotated_text
from transformers import AutoTokenizer
from transformers import DistilBertModel
class Classifier(pl.LightningModule):
    """Three-way linear classifier over frozen DistilBERT token embeddings.

    Text is tokenized and padded/truncated to ``SEQ_LEN`` tokens, encoded
    without gradients by the DistilBERT model returned from ``get_bert()``,
    flattened to a ``SEQ_LEN * HIDDEN_SIZE`` vector, L2-normalised per sample,
    and passed through a single linear layer (``ln1``) producing
    ``NUM_CLASSES`` logits. Only ``ln1`` is trained (it is the only
    registered parameter).
    """

    # Token length every input is padded/truncated to; part of ln1's input size.
    SEQ_LEN = 512
    # Per-token width of the last_hidden_state being flattened.
    HIDDEN_SIZE = 768
    NUM_CLASSES = 3

    def __init__(self):
        super().__init__()
        # Shape must stay (512*768 -> 3) so existing checkpoints still load.
        self.ln1 = torch.nn.Linear(self.SEQ_LEN * self.HIDDEN_SIZE, self.NUM_CLASSES)
        self.criterion = nn.CrossEntropyLoss()
        # Created on the first preprocess() call and reused afterwards.
        # Plain attribute, so it is not part of state_dict.
        self._tokenizer = None

    def _embed(self, **bert_inputs):
        """Encode with DistilBERT and return row-wise L2-normalised flat features.

        Args:
            **bert_inputs: keyword arguments forwarded to the DistilBERT model
                (e.g. ``input_ids`` and ``attention_mask``).

        Returns:
            Tensor of shape ``(batch, SEQ_LEN * HIDDEN_SIZE)`` with no grad history.
        """
        with torch.no_grad():
            hidden = get_bert()(**bert_inputs).last_hidden_state
            flat = hidden.reshape(-1, self.SEQ_LEN * self.HIDDEN_SIZE)
            # keepdim=True yields (batch, 1) norms that broadcast per row.
            # A (batch,) norm only broadcasts correctly when batch == 1.
            return flat / torch.linalg.norm(flat, 2, dim=1, keepdim=True)

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one ``(x, y)`` batch.

        Each row of ``x`` is sliced as input_ids in the first ``SEQ_LEN``
        columns and the attention mask in the remaining columns.
        """
        x, y = batch
        features = self._embed(
            input_ids=x[:, :self.SEQ_LEN],
            attention_mask=x[:, self.SEQ_LEN:],
        )
        logits = self.ln1(features)
        loss = self.criterion(logits, y)
        self.log("my_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        """Adam with lr=1e-3 over the module's parameters (i.e. ``ln1``)."""
        return optim.Adam(self.parameters(), lr=1e-3)

    def preprocess(self, x):
        """Tokenize text(s) into ``SEQ_LEN``-token PyTorch tensors for ``forward``.

        Truncation keeps long inputs at exactly ``SEQ_LEN`` tokens. Without
        it, the fixed-size reshape in ``_embed`` would fail on them.
        """
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)
        return self._tokenizer(
            x,
            padding="max_length",
            truncation=True,
            max_length=self.SEQ_LEN,
            return_tensors="pt",
        )

    def forward(self, x):
        """Return class logits for the tokenizer output ``x`` from ``preprocess``."""
        return self.ln1(self._embed(**x))
@functools.lru_cache(maxsize=1)
def get_bert():
    """Return the pretrained ``distilbert-base-uncased`` model, loading it at most once.

    Callers invoke this once per sentence or training step. Memoising
    avoids reloading the weights each time. The same instance is shared by
    every caller, so callers must not modify it.

    NOTE(review): Streamlit's ``st.cache_resource`` could keep this alive
    across script reruns as well; consider it if the installed version has it.
    """
    return DistilBertModel.from_pretrained("distilbert-base-uncased")
# Google Drive file id of the trained checkpoint, fetched with the gdown CLI.
CHECKPOINT_DRIVE_ID = "1GxhHvg3lwlGpA7So06v3l43U8pSASy9L"
# File name the checkpoint is loaded from in the current working directory.
CHECKPOINT_FILENAME = "model_params"


def get_classifier():
    """Download the checkpoint if it is missing, then load the Classifier on CPU.

    Returns:
        Classifier restored from ``./model_params``.

    Raises:
        subprocess.CalledProcessError: if the ``gdown`` download exits non-zero.
    """
    import subprocess

    checkpoint_path = os.path.join(os.getcwd(), CHECKPOINT_FILENAME)
    if not os.path.exists(checkpoint_path):
        # Assumes gdown saves this Drive file as "model_params" in the CWD,
        # as the original load path implied. TODO confirm.
        # A list argv avoids the shell. check=True surfaces download failures
        # here rather than as a confusing load error below.
        subprocess.run(["gdown", CHECKPOINT_DRIVE_ID], check=True)
    # get_bert() never moves its model off the default (CPU) device, so the
    # linear head must live on CPU too.
    return Classifier.load_from_checkpoint(checkpoint_path, map_location="cpu")
# Display label for each classifier output index (argmax 0, 1, 2).
CLASS_LABELS = ("Leadership", "Diversity", "Integrity")


def _split_sentences(text):
    """Split ``text`` on '.' into ``(fragment, followed_by_period)`` pairs, skipping blank fragments."""
    pieces = text.split(".")
    last = len(pieces) - 1
    # Every piece except the last was followed by a '.' in the original text.
    return [(piece, idx < last) for idx, piece in enumerate(pieces) if piece.strip()]


def get_annotated_text(text):
    """Classify each '.'-separated sentence of ``text`` for ``annotated_text``.

    Args:
        text: raw user input.

    Returns:
        Tuple mixing ``(sentence, label)`` pairs with literal ``"."`` strings
        where the input had a period after that sentence. Returns an empty
        tuple, without loading the model, when there is nothing to classify.
    """
    sentences = _split_sentences(text)
    if not sentences:
        return ()

    model = get_classifier()
    parts = []
    with torch.no_grad():
        for sentence, had_period in sentences:
            class_idx = int(model(model.preprocess([sentence])).argmax())
            parts.append((sentence, CLASS_LABELS[class_idx]))
            if had_period:
                parts.append(".")
    return tuple(parts)
# Streamlit page: read code-of-conduct text and render each sentence with its label.
st.title("Code of Conduct Classifier")
input_text = st.text_area("enter code of conduct text")
st.title("annotated text")
# Only download and run the model when there is text to classify
# (e.g. not before the user has typed anything).
if input_text.strip():
    annotated_text(*get_annotated_text(input_text))