Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from model import QuestAnsweringDistilBERT, CustomDataset | |
| import torch | |
| import transformers | |
| import pandas as pd | |
| import torch.utils.data as data_utils | |
| from transformers import DistilBertConfig, DistilBertTokenizer, DistilBertModel | |
# --- Page header ---------------------------------------------------------
st.markdown("## 🤖Check the fact!🤖")
st.markdown("-----------------------------------------------------------")
# Fixed user-facing typos: "awaraness" -> "awareness", "wait" -> "await".
st.markdown("### Challenge the bot awareness with contextual statements!", unsafe_allow_html=True)
st.markdown("Provide a passage, suggest a statement (with verdict) about it and await the bot's opinion!", unsafe_allow_html=True)
# Cache the heavy model/tokenizer load: Streamlit re-runs this whole script on
# every widget interaction, so without caching the DistilBERT weights would be
# reloaded on every button click.
# NOTE(review): @st.cache_resource requires streamlit >= 1.18; on older
# versions use @st.experimental_singleton instead — confirm the deployed version.
@st.cache_resource
def load_utils():
    """Load the tokenizer and the fine-tuned classifier once per process.

    Returns:
        (model, tokenizer): the QuestAnsweringDistilBERT model on CPU in
        eval mode, and the matching DistilBERT tokenizer.
    """
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    bert = DistilBertModel.from_pretrained('distilbert-base-uncased')
    model = QuestAnsweringDistilBERT(bert, DistilBertConfig(), 'part', n_layers=1,
                                     hidden_dim=384, dim=768).to('cpu')
    model.load_state_dict(torch.load("deploy_model_weights.pt", map_location='cpu'))
    model.eval()  # inference only: disable dropout etc.
    return model, tokenizer

model, tokenizer = load_utils()
def process(passage, question, answer):
    """Turn raw user strings into a single-example model batch.

    Text is lower-cased and the 'true'/'false' verdict becomes 1.0/0.0; the
    example is pushed through the project's CustomDataset/DataLoader pipeline
    so it is tokenized exactly like the training data.
    """
    label = 1.0 if answer.lower() == "true" else 0.0
    frame = pd.DataFrame({'passage': [passage.lower()],
                          'question': [question.lower()],
                          'answer': [label]})
    loader = data_utils.DataLoader(dataset=CustomDataset(frame, tokenizer),
                                   batch_size=1)
    first_batch = next(iter(loader))
    return first_batch[0]
# --- Interactive form ----------------------------------------------------
passage = st.text_input("Enter an interesting passage", value="A long long time ago ...")
question = st.text_input("Enter a challenging statement about passage", value="Was Darth Vader the father?")
answer = st.text_input("Enter a 'true' or 'false' verdict about the statement", value="true")
submitted = st.button("Compare Results!")

if submitted:
    # Validate every field, reporting all problems rather than only the first.
    valid = True
    if answer.lower() not in ("true", "false"):
        st.error("Wrong answers. Only 'true' and 'false' are acceptable!", icon="🚨")
        valid = False
    if not question:
        st.error('We need a question!', icon="🔥")
        valid = False
    if not passage:
        st.error('We need something exciting to investigate!', icon="🤖")
        valid = False
    if valid:
        batch = process(passage, question, answer)
        # Inference only — skip building the autograd graph.
        with torch.no_grad():
            prob = torch.sigmoid(model(batch)).item()
        model_answer = "true" if prob > 0.5 else "false"
        # BUG FIX: validation accepts any casing ("True", "FALSE", ...) via
        # answer.lower(), but the original compared model_answer against the
        # RAW string, so e.g. "True" was always reported as a model mistake.
        if model_answer == answer.lower():
            st.markdown(f"The model thinks that the answer is '{model_answer}'. The model guessed RIGHT!")
        else:
            st.markdown(f"The model thinks that the answer is '{model_answer}'. The model is WRONG. We are sorry for the mistake!")