| import streamlit as st |
| import numpy as np |
| import transformers |
| import re |
| import string |
| import preprocessor as pre |
|
|
| import torch |
| from transformers import BertTokenizer, BertForSequenceClassification |
|
|
# Read the app's custom stylesheet and inject it into the page as an inline
# <style> element (Streamlit only renders raw HTML when explicitly allowed).
with open("style.css") as css_file:
    css_rules = css_file.read()
st.markdown('<style>{}</style>'.format(css_rules), unsafe_allow_html=True)
|
|
| |
# Hugging Face Hub identifier of the fine-tuned checkpoint.
model_path = "ninahf1503/SA-BERTchatgptapp"


@st.cache_resource
def _load_tokenizer_and_model(path):
    """Load the tokenizer and classifier for *path* once per server process.

    Streamlit re-executes this script from the top on every widget
    interaction; without caching, the tokenizer and model would be
    re-instantiated on each rerun. ``st.cache_resource`` keeps the loaded
    objects alive and hands back the same instances on later reruns.

    Args:
        path: Model identifier or local directory accepted by
            ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = BertTokenizer.from_pretrained(path)
    loaded_model = BertForSequenceClassification.from_pretrained(
        path, ignore_mismatched_sizes=True
    )
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_tokenizer_and_model(model_path)

# Fixed token length: inputs are padded or truncated to this many tokens
# when they are encoded for the model.
seq_max_length = 55
|
|
| |
def tokenizing_text(sentence):
    """Clean *sentence* and encode it for the BERT classifier.

    The text is first passed through ``preprocess_text`` and then encoded
    with the module-level ``tokenizer``: special tokens are added and the
    sequence is truncated/padded to ``seq_max_length`` tokens.

    Args:
        sentence: Raw input text.

    Returns:
        A ``(input_ids, attention_mask)`` pair of PyTorch tensors
        (presumably shaped ``(1, seq_max_length)`` -- confirm against the
        tokenizer in use).
    """
    cleaned = preprocess_text(sentence)

    encoding_options = {
        "add_special_tokens": True,
        "max_length": seq_max_length,
        "truncation": True,
        "padding": 'max_length',
        "return_tensors": 'pt',
    }
    batch = tokenizer.encode_plus(cleaned, **encoding_options)

    return batch['input_ids'], batch['attention_mask']
|
|
| |
def preprocess_text(sentence):
    """Normalize raw text before it is handed to the tokenizer.

    The cleaning steps run exactly once, in this order:

    1. Interpret literal backslash escapes typed into the text.
    2. Replace newlines with spaces.
    3. Run ``pre.clean`` (the ``preprocessor`` package's cleaner).
    4. Replace every non-word, non-whitespace character with a space.
    5. Replace ASCII digits with spaces.
    6. Apply the legacy ``re_cleansing`` pattern (mentions, links, hashtags,
       a leading ``RT``, stand-alone numbers).
    7. Drop any remaining ``string.punctuation`` character. In practice this
       is only ``_``, which step 4 keeps because it counts as a word
       character.
    8. Collapse whitespace runs, trim, and lowercase.

    Previously the whole pipeline was repeated once per punctuation
    character, and step 1 used ``encode().decode('unicode_escape')``, which
    garbles non-ASCII characters and raises on a stray backslash.

    Args:
        sentence: Raw input text.

    Returns:
        The cleaned, lowercased text.
    """
    # Raw string: same pattern as before, without invalid-escape warnings.
    re_cleansing = r"@\S+|https?:\S+|http?:\S|#[A-Za-z0-9]+|^RT[\s]+|(^|\W)\d+"

    # latin-1 + backslashreplace lets unicode_escape round-trip non-ASCII
    # characters intact while still decoding literal escapes such as a
    # typed backslash-n.
    try:
        sentence = sentence.encode('latin-1', 'backslashreplace').decode('unicode_escape')
    except UnicodeDecodeError:
        # Malformed or trailing backslash escape: keep the text as typed
        # instead of crashing on user input.
        pass

    sentence = re.sub(r'\n', ' ', sentence)
    sentence = pre.clean(sentence)
    sentence = re.sub(r'[^\w\s]', ' ', sentence)
    sentence = re.sub(r'[0-9]', ' ', sentence)
    sentence = re.sub(re_cleansing, ' ', sentence)

    # One C-level pass replaces the former per-character replace() loop.
    sentence = sentence.translate(str.maketrans('', '', string.punctuation))

    # Removing characters above can leave stray or doubled spaces.
    return ' '.join(sentence.split()).lower()
|
|
| |
def predict_sentiment(input_text):
    """Classify *input_text* and return its sentiment label.

    The text is encoded with ``tokenizing_text``, run through the
    module-level ``model``, and the index of the largest logit is mapped
    to a label.

    Args:
        input_text: Raw text to classify.

    Returns:
        ``"Bad"``, ``"Good"`` or ``"Neutral"``.

    Raises:
        KeyError: If the predicted class index is not 0, 1 or 2.
    """
    label_sentiment = {0: "Bad", 1: "Good", 2: "Neutral"}
    token_ids, mask = tokenizing_text(input_text)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        logits = model(token_ids, mask).logits

    class_index = torch.argmax(logits, dim=1).item()
    return label_sentiment[class_index]
|
|
|
|
|
|
| |
def main():
    """Render the page: a text box, a SUBMIT button and the result banner.

    Nothing is shown below the button until it is pressed. An empty (or
    whitespace-only) input produces a hint; otherwise the text is classified
    with ``predict_sentiment`` and a colored banner is displayed. Any label
    other than "Good" or "Bad" is presented as neutral.
    """
    # Result banner for each label; anything not listed falls back to neutral.
    banners = {
        "Good": '<div style="background-color: #5d9c59; padding: 16px; border-radius: 5px; font-weight: bold; color:white;">This sentence contains a positive sentiment</div>',
        "Bad": '<div style="background-color: #df2e38; padding: 16px; border-radius: 5px; font-weight: bold; color:white;">This sentence contains a negative sentiment</div>',
    }
    neutral_banner = '<div style="background-color: #ffa500; padding: 16px; border-radius: 5px; font-weight: bold; color:white;">This sentence is neutral</div>'

    st.title("Sentimen Analysis", anchor=False)
    tweet_text = st.text_area(
        " ",
        placeholder="Enter the sentence you want to analyze",
        label_visibility="collapsed",
    )

    if not st.button("SUBMIT"):
        return

    if not tweet_text.strip():
        st.title("Text Input Still Empty", anchor=False)
        st.info("Please fill in the sentence you want to analyze")
        return

    sentiment = predict_sentiment(tweet_text)
    st.title("Sentiment Analysis Results", anchor=False)
    st.markdown(banners.get(sentiment, neutral_banner), unsafe_allow_html=True)


if __name__ == "__main__":
    main()
|
|