| import os |
|
|
| import gdown as gdown |
| import nltk |
| import streamlit as st |
| from nltk.tokenize import sent_tokenize |
|
|
| from source.pipeline import MultiLabelPipeline, inputs_to_dataset |
|
|
|
|
def download_models(ids):
    """
    Download the NLTK sentence tokenizer and any model checkpoints not yet on disk.

    :param ids: mapping of model name -> Google Drive file id; each model is
        saved as ``model/<name>.pt`` and skipped if that file already exists
    :return: None
    """
    # Only fetch punkt when NLTK cannot already locate it, so a Streamlit
    # rerun does not trigger a network request every time.
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt')

    # Ensure the target directory exists before downloading into it.
    os.makedirs("model", exist_ok=True)

    for key, file_id in ids.items():
        output = f"model/{key}.pt"
        if not os.path.isfile(output):
            url = f"https://drive.google.com/uc?id={file_id}"
            gdown.download(url=url, output=output)
|
|
|
|
@st.cache
def load_labels():
    """
    Return the emotion label names, in the order used by the model.

    :return: list of 28 label strings
    """
    label_names = (
        "admiration",
        "amusement",
        "anger",
        "annoyance",
        "approval",
        "caring",
        "confusion",
        "curiosity",
        "desire",
        "disappointment",
        "disapproval",
        "disgust",
        "embarrassment",
        "excitement",
        "fear",
        "gratitude",
        "grief",
        "joy",
        "love",
        "nervousness",
        "optimism",
        "pride",
        "realization",
        "relief",
        "remorse",
        "sadness",
        "surprise",
        "neutral",
    )
    # Callers receive a fresh list, exactly as with the original literal.
    return list(label_names)
|
|
|
|
@st.cache(allow_output_mutation=True)
def load_model(model_path):
    """
    Build the multi-label pipeline for a checkpoint; Streamlit caches the result.

    :param model_path: filesystem path of the model checkpoint
    :return: the constructed MultiLabelPipeline
    """
    # allow_output_mutation=True skips Streamlit's check that the cached
    # return value has not been mutated, so the pipeline object is reused as-is.
    return MultiLabelPipeline(model_path=model_path)
|
|
|
|
| |
st.set_page_config(layout="centered")
st.title("Multiclass Emotion Classification")
st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")

# Set to True to take the app offline (e.g. when the file download limit is hit).
maintenance = False
if maintenance:
    st.write("Unavailable for now (file downloads limit). ")
else:
    # The checkpoint's Google Drive file id is read from Streamlit secrets.
    ids = {'perceiver-go-emotions': st.secrets['model']}
    labels = load_labels()

    # Make sure the tokenizer data and model checkpoints exist locally.
    download_models(ids)

    st.markdown(f"__Labels:__ {', '.join(labels)}")

    # Text input in the wider left column, options in the right column.
    left, right = st.columns([4, 2])
    inputs = left.text_area('', max_chars=4096, value='This is a space about multiclass emotion classification. Write '
                                                     'something here to see what happens!')
    model_path = right.selectbox('', options=list(ids), index=0, help='Model to use. ')
    split = right.checkbox('Split into sentences', value=True)
    model = load_model(model_path=f"model/{model_path}.pt")
    right.write(model.device)

    # Skip empty or whitespace-only input: str.strip() removes exactly the
    # characters str.isspace() treats as whitespace, matching the old check.
    if inputs.strip():
        with st.spinner('Processing text... This may take a while.'):
            # Classify each sentence separately, or the whole text as one item.
            texts = sent_tokenize(inputs) if split else [inputs]
            left.write(model(inputs_to_dataset(texts), batch_size=1))
|
|