| | import os |
| |
|
| | import gdown as gdown |
| | import nltk |
| | import streamlit as st |
| | from nltk.tokenize import sent_tokenize |
| |
|
| | from source.pipeline import MultiLabelPipeline, inputs_to_dataset |
| |
|
| |
|
def download_models(ids):
    """
    Download the NLTK sentence tokenizer data and any missing model checkpoints.

    Each checkpoint is fetched from Google Drive into ``model/<name>.pt``.
    Checkpoints whose file already exists are skipped.

    :param ids: mapping of model name -> Google Drive file id
    :return: None
    :raises RuntimeError: if gdown reports a failed checkpoint download
    """
    # sent_tokenize needs the punkt data. This function is called from the
    # top level of the script (uncached), so only download when it is missing
    # instead of contacting the NLTK index on every run.
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt')

    # Make sure the target directory exists before writing checkpoints into it;
    # gdown may not create missing parent directories for a file output path.
    os.makedirs("model", exist_ok=True)

    for name, file_id in ids.items():
        model_path = f"model/{name}.pt"
        if os.path.isfile(model_path):
            continue

        url = f"https://drive.google.com/uc?id={file_id}"
        # Download under a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated file that the isfile()
        # check above would later mistake for a complete checkpoint.
        part_path = f"{model_path}.part"
        if gdown.download(url=url, output=part_path) is None:
            raise RuntimeError(f"Failed to download model '{name}' from {url}")
        os.replace(part_path, model_path)
| |
|
| |
|
@st.cache
def load_labels():
    """
    Return the emotion label names shown to the user.

    The 28 names are listed in a fixed order. Presumably this matches the order
    of the model's output classes; verify against the training label mapping.

    :return: list of label name strings
    """
    # Whitespace-separated in the original order; split() yields a fresh list.
    label_block = """
        admiration amusement anger annoyance approval caring confusion
        curiosity desire disappointment disapproval disgust embarrassment
        excitement fear gratitude grief joy love nervousness optimism pride
        realization relief remorse sadness surprise neutral
    """
    return label_block.split()
| |
|
| |
|
@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """
    Build a MultiLabelPipeline for the given checkpoint, memoized by st.cache.

    ``allow_output_mutation=True`` makes st.cache skip its check that the cached
    return value has not been mutated between calls.

    :param model_path: filesystem path of the model checkpoint
    :return: the MultiLabelPipeline instance for ``model_path``
    """
    return MultiLabelPipeline(model_path=model_path)
| |
|
| |
|
| | |
st.set_page_config(layout="centered")
st.title("Multiclass Emotion Classification")
st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")

# Set to True to show a notice instead of the app, e.g. when the model file's
# download limit has been hit.
maintenance = False
if maintenance:
    st.write("Unavailable for now (file downloads limit). ")
else:
    # Model name -> Google Drive file id, read from Streamlit secrets.
    ids = {'perceiver-go-emotions': st.secrets['model']}
    labels = load_labels()

    # Fetch tokenizer data and any model checkpoints not yet on disk.
    download_models(ids)

    st.markdown(f"__Labels:__ {', '.join(labels)}")

    # Layout: wide text input on the left, model options on the right.
    left, right = st.columns([4, 2])
    inputs = left.text_area('', max_chars=4096, value='This is a space about multiclass emotion classification. Write '
                                                      'something here to see what happens!')
    # The select box yields a model *name*; the checkpoint path is derived from it.
    model_name = right.selectbox('', options=list(ids), index=0, help='Model to use. ')
    split = right.checkbox('Split into sentences', value=True)
    model = load_model(model_path=f"model/{model_name}.pt")
    right.write(model.device)

    # Skip inference for empty or whitespace-only input; str.strip() covers
    # both cases that the original isspace()/"" checks handled separately.
    if inputs.strip():
        with st.spinner('Processing text... This may take a while.'):
            # Classify each sentence separately, or the whole text as one item.
            texts = sent_tokenize(inputs) if split else [inputs]
            left.write(model(inputs_to_dataset(texts), batch_size=1))
| |
|