# TrueOrFalseBot / app.py
# Last commit 1915ca8 by RomeCaesar: "Got rid of ()"
import streamlit as st
from model import QuestAnsweringDistilBERT, CustomDataset
import torch
import transformers
import pandas as pd
import torch.utils.data as data_utils
from transformers import DistilBertConfig, DistilBertTokenizer, DistilBertModel
# --- Page header -------------------------------------------------------------
st.markdown("## 🤖Check the fact!🤖")
st.markdown("-----------------------------------------------------------")
# These intro lines are plain Markdown (no HTML tags), so HTML rendering is
# not enabled.
st.markdown("### Challenge the bot's awareness with contextual statements!")
st.markdown("Provide a passage, suggest a statement (with verdict) about it and wait for the bot's opinion!")
@st.cache_resource
def load_utils():
    """Build the fact-checking model and its tokenizer.

    Both the tokenizer and the DistilBERT backbone come from the
    ``distilbert-base-uncased`` checkpoint. The backbone is wrapped in
    ``QuestAnsweringDistilBERT``, whose trained weights are read from
    ``deploy_model_weights.pt`` onto the CPU. The model is switched to eval
    mode before being returned.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    objects instead of rebuilding them.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    device = 'cpu'
    checkpoint = 'distilbert-base-uncased'

    tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
    backbone = DistilBertModel.from_pretrained(checkpoint)

    # NOTE(review): the meaning of the 'part' mode and the size arguments is
    # defined in model.py; they must match how the saved weights were trained.
    classifier = QuestAnsweringDistilBERT(
        backbone,
        DistilBertConfig(),
        'part',
        n_layers=1,
        hidden_dim=384,
        dim=768,
    )
    classifier = classifier.to(device)

    state = torch.load("deploy_model_weights.pt", map_location=device)
    classifier.load_state_dict(state)
    classifier.eval()
    return classifier, tokenizer


# Module-level handles used by process() and the inference code below.
model, tokenizer = load_utils()
@st.cache_data
def process(passage, question, answer):
    """Convert one (passage, statement, verdict) triple into a model input.

    The passage and statement are lower-cased. The verdict becomes a float
    label: 1.0 when it equals "true" case-insensitively, otherwise 0.0. The
    single row goes through ``CustomDataset`` (with the module-level
    ``tokenizer``) and a batch-size-1 ``DataLoader``. Element 0 of the first
    batch is returned; presumably that is the model input tensor(s), which
    should be confirmed against ``CustomDataset.__getitem__``.

    Results are memoised by ``st.cache_data`` on the three arguments.
    """
    label = 1.0 if answer.lower() == "true" else 0.0
    frame = pd.DataFrame(
        {
            'passage': [passage.lower()],
            'question': [question.lower()],
            'answer': [label],
        }
    )
    loader = data_utils.DataLoader(
        dataset=CustomDataset(frame, tokenizer),
        batch_size=1,
    )
    first_batch = next(iter(loader))
    return first_batch[0]
# --- User inputs -------------------------------------------------------------
passage = st.text_input("Enter an interesting passage", value="A long long time ago ...")
question = st.text_input("Enter a challenging statement about passage", value="Was Darth Vader the father?")
answer = st.text_input("Enter a 'true' or 'false' verdict about the statement", value="true")
submitted = st.button("Compare Results!")

# --- Validation and inference ------------------------------------------------
if submitted:
    # Normalise the verdict once, so validation, the label handed to
    # process() and the final comparison all use the same value. Previously
    # "True" passed validation but was compared raw, so a correct model was
    # reported as WRONG.
    verdict = answer.strip().lower()

    ok = True
    if verdict not in ("true", "false"):
        st.error("Wrong answers. Only 'true' and 'false' are acceptable!", icon="🚨")
        ok = False
    # Whitespace-only input counts as empty.
    if not question.strip():
        st.error('We need a question!', icon="🔥")
        ok = False
    if not passage.strip():
        st.error('We need something exciting to investigate!', icon="🤖")
        ok = False

    if ok:
        batch = process(passage, question, verdict)
        # Pure inference: skip autograd graph construction.
        with torch.no_grad():
            logit = model(batch)
        # A sigmoid probability above 0.5 is read as a "true" verdict.
        model_answer = "true" if torch.sigmoid(logit).item() > 0.5 else "false"
        if model_answer == verdict:
            st.markdown("The model thinks that the answer is '{}'. The model guessed RIGHT!".format(model_answer))
        else:
            st.markdown("The model thinks that the answer is '{}'. The model is WRONG. We are sorry for the mistake!".format(model_answer))