Spaces:
Sleeping
Sleeping
File size: 2,662 Bytes
6d3b8ba 4c1a1ff 9ae60b4 bfef3bd 6452527 6d3b8ba 9d4693d 6d3b8ba 9d4693d 6d3b8ba bfef3bd 6d3b8ba 1915ca8 6d3b8ba 9d4693d 6d3b8ba 9d4693d 87451ee 9d4693d 87451ee 9d4693d 87451ee 6452527 9d4693d 6452527 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | import streamlit as st
from model import QuestAnsweringDistilBERT, CustomDataset
import torch
import transformers
import pandas as pd
import torch.utils.data as data_utils
from transformers import DistilBertConfig, DistilBertTokenizer, DistilBertModel
# Page title, separator and short instructions rendered above the input form.
st.markdown("## 🤖Check the fact!🤖")
st.markdown("-----------------------------------------------------------")
# Plain Markdown only (no raw HTML), so unsafe_allow_html stays at its
# default of False.
st.markdown("### Challenge the bot's awareness with contextual statements!")
st.markdown("Provide a passage, suggest a statement (with verdict) about it and wait for the bot's opinion!")
@st.cache_resource
def load_utils():
    """Build the fine-tuned fact-checking model and its tokenizer.

    Decorated with ``st.cache_resource``, so Streamlit runs the body once and
    hands the same objects back on every rerun of the script.

    Returns:
        A ``(model, tokenizer)`` pair. The model is a ``QuestAnsweringDistilBERT``
        wrapped around a pretrained DistilBERT backbone. It is moved to the CPU,
        loaded with the weights from ``deploy_model_weights.pt`` and switched
        to eval mode.
    """
    pretrained_name = 'distilbert-base-uncased'
    device = 'cpu'

    tok = DistilBertTokenizer.from_pretrained(pretrained_name)
    backbone = DistilBertModel.from_pretrained(pretrained_name)

    # Head hyper-parameters must match the ones used when the checkpoint
    # was trained, or load_state_dict will fail on shape mismatches.
    classifier = QuestAnsweringDistilBERT(
        backbone,
        DistilBertConfig(),
        'part',
        n_layers=1,
        hidden_dim=384,
        dim=768,
    ).to(device)

    # NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
    weights = torch.load("deploy_model_weights.pt", map_location=device)
    classifier.load_state_dict(weights)
    classifier.eval()  # disable dropout etc. for inference
    return classifier, tok


# Shared handles used by process() and the inference code further down.
model, tokenizer = load_utils()
@st.cache_data
def process(passage, question, answer):
    """Convert one (passage, statement, verdict) triple into a model input.

    Passage and statement are lower-cased, and the verdict becomes a float
    label: 1.0 when it equals "true" case-insensitively, else 0.0. The
    resulting one-row DataFrame goes through ``CustomDataset`` (with the
    module-level ``tokenizer``) and a batch-size-1 ``DataLoader``. The function
    returns element ``[0]`` of the first collated batch. Presumably these are
    the model inputs; confirm against ``CustomDataset.__getitem__``.

    Results are memoised by ``st.cache_data`` on the argument values.
    """
    label = 1.0 if answer.lower() == "true" else 0.0
    frame = pd.DataFrame(
        {
            'passage': [passage.lower()],
            'question': [question.lower()],
            'answer': [label],
        }
    )
    single_item = CustomDataset(frame, tokenizer)
    loader = data_utils.DataLoader(dataset=single_item, batch_size=1)
    first_batch = next(iter(loader))
    return first_batch[0]
# Input form: the passage, a statement about it, and the user's verdict.
passage = st.text_input("Enter an interesting passage", value="A long long time ago ...")
question = st.text_input("Enter a challenging statement about passage", value="Was Darth Vader the father?")
answer = st.text_input("Enter a 'true' or 'false' verdict about the statement", value="true")
submitted = st.button("Compare Results!")

if submitted:
    # Normalise the verdict once so that validation, preprocessing and the
    # final comparison all see the same value. Comparing the raw input made
    # e.g. "True" pass validation but never match the model's "true".
    verdict = answer.strip().lower()

    ok = True
    if verdict not in ("true", "false"):
        st.error("Wrong answers. Only 'true' and 'false' are acceptable!", icon="🚨")
        ok = False
    # Whitespace-only input counts as empty.
    if not question.strip():
        st.error('We need a question!', icon="🔥")
        ok = False
    if not passage.strip():
        st.error('We need something exciting to investigate!', icon="🤖")
        ok = False

    if ok:
        batch = process(passage, question, verdict)
        # Pure inference: no_grad avoids building an autograd graph.
        with torch.no_grad():
            probability = torch.sigmoid(model(batch)).item()
        model_answer = "true" if probability > 0.5 else "false"
        if model_answer == verdict:
            st.markdown(f"The model thinks that the answer is '{model_answer}'. The model guessed RIGHT!")
        else:
            st.markdown(f"The model thinks that the answer is '{model_answer}'. The model is WRONG. We are sorry for the mistake!")