File size: 2,662 Bytes
6d3b8ba
 
 
 
 
 
 
 
4c1a1ff
 
9ae60b4
bfef3bd
6452527
6d3b8ba
 
 
 
 
9d4693d
6d3b8ba
 
 
 
 
 
 
 
 
 
 
 
 
 
9d4693d
6d3b8ba
 
 
bfef3bd
6d3b8ba
1915ca8
6d3b8ba
9d4693d
6d3b8ba
9d4693d
 
 
 
 
 
 
 
 
87451ee
9d4693d
 
 
 
87451ee
9d4693d
 
 
 
 
 
 
 
 
87451ee
6452527
9d4693d
6452527
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import streamlit as st
from model import QuestAnsweringDistilBERT, CustomDataset
import torch
import transformers
import pandas as pd
import torch.utils.data as data_utils
from transformers import DistilBertConfig, DistilBertTokenizer, DistilBertModel

# --- Page header ---------------------------------------------------------
st.markdown("## 🤖Check the fact!🤖")
st.markdown("-----------------------------------------------------------")
# Fixed user-facing typos: "awaraness" -> "awareness", "wait" -> "await".
st.markdown("### Challenge the bot awareness with contextual statements!", unsafe_allow_html=True)
st.markdown("Provide a passage, suggest a statement (with verdict) about it and await the bot's opinion!", unsafe_allow_html=True)


@st.cache_resource
def load_utils():
    """Build the fact-checking model and its tokenizer (cached across reruns).

    Loads the pretrained DistilBERT backbone and tokenizer, wraps the
    backbone in the project's QuestAnsweringDistilBERT head, restores the
    fine-tuned weights from ``deploy_model_weights.pt``, and switches the
    model to eval mode on CPU.

    Returns:
        tuple: (model, tokenizer) ready for inference.
    """
    pretrained_name = 'distilbert-base-uncased'
    distil_tokenizer = DistilBertTokenizer.from_pretrained(pretrained_name)
    backbone = DistilBertModel.from_pretrained(pretrained_name)
    qa_model = QuestAnsweringDistilBERT(
        backbone, DistilBertConfig(), 'part',
        n_layers=1, hidden_dim=384, dim=768,
    ).to('cpu')
    state = torch.load("deploy_model_weights.pt", map_location='cpu')
    qa_model.load_state_dict(state)
    qa_model.eval()
    return qa_model, distil_tokenizer

# Created once per session; @st.cache_resource reuses them across Streamlit reruns.
model, tokenizer = load_utils()

@st.cache_data
def process(passage, question, answer):
    """Pack one (passage, question, verdict) triple into a model-ready batch.

    Both texts are lowercased and the verdict is encoded as 1.0 for "true"
    and 0.0 for anything else. The single row is wrapped in the project's
    CustomDataset, fed through a batch-size-1 DataLoader, and the first
    element of the only batch is returned.
    """
    verdict = 1.0 if answer.lower() == "true" else 0.0
    frame = pd.DataFrame({
        'passage': [passage.lower()],
        'question': [question.lower()],
        'answer': [verdict],
    })
    loader = data_utils.DataLoader(dataset=CustomDataset(frame, tokenizer), batch_size=1)
    batch = next(iter(loader))
    return batch[0]

# --- User inputs ---------------------------------------------------------
passage = st.text_input("Enter an interesting passage", value="A long long time ago ...")

question = st.text_input("Enter a challenging statement about passage", value="Was Darth Vader the father?")

answer = st.text_input("Enter a 'true' or 'false' verdict about the statement", value="true")

submitted = st.button("Compare Results!")

if submitted:

    # Normalize the verdict once so validation and the final comparison agree.
    # BUG FIX: validation was case-insensitive (answer.lower()) but the final
    # comparison used the raw input, so typing "True" passed validation yet
    # was always reported as a wrong guess.
    verdict = answer.strip().lower()

    valid = True

    if verdict not in ("true", "false"):
        st.error("Wrong answers. Only 'true' and 'false' are acceptable!", icon="🚨")
        valid = False

    # .strip() rejects whitespace-only input, not just the empty string.
    if not question.strip():
        st.error('We need a question!', icon="🔥")
        valid = False

    if not passage.strip():
        st.error('We need something exciting to investigate!', icon="🤖")
        valid = False

    if valid:

        batch = process(passage, question, answer)

        # Sigmoid maps the single logit to P(statement is true); 0.5 threshold.
        model_answer = "true" if torch.sigmoid(model(batch)).item() > 0.5 else "false"

        if model_answer == verdict:
            st.markdown("The model thinks that the answer is '{}'. The model guessed RIGHT!".format(model_answer))
        else:
            st.markdown("The model thinks that the answer is '{}'. The model is WRONG. We are sorry for the mistake!".format(model_answer))