| | import os.path |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import RobertaTokenizerFast, RobertaForMaskedLM |
| | import streamlit as st |
| |
|
| |
|
class SimpleClassifier(nn.Module):
    """Small MLP head mapping a feature vector to sigmoid probabilities.

    Architecture: BatchNorm1d -> Linear -> activation -> Linear ->
    activation -> Linear -> sigmoid.

    Args:
        in_features: size of the input feature vector.
        hidden_features: width of the two hidden layers.
        out_features: number of output units (each squashed to [0, 1]).
        activation: activation module applied after each hidden layer
            (stateless, so sharing one default instance is safe).
    """

    def __init__(self, in_features: int, hidden_features: int,
                 out_features: int, activation=nn.ReLU()):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_features)
        self.in2hid = nn.Linear(in_features, hidden_features)
        self.activation = activation
        self.hid2hid = nn.Linear(hidden_features, hidden_features)
        self.hid2out = nn.Linear(hidden_features, out_features)
        # NOTE: bn2 is never used in forward(), but it is deliberately kept so
        # that strict load_state_dict() of checkpoints saved with this exact
        # module layout (e.g. twitter_model_91_5-.pth) still succeeds.
        self.bn2 = nn.BatchNorm1d(hidden_features)

    def forward(self, X):
        """Run the MLP on a (batch, in_features) tensor; returns values in [0, 1]."""
        X = self.bn(X)
        X = self.activation(self.in2hid(X))
        # The original wrapped X in torch.concat((X,), 1) before each linear
        # layer — concatenating a single tensor is a no-op, so it was removed.
        X = self.activation(self.hid2hid(X))
        X = self.hid2out(X)
        # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
        return torch.sigmoid(X)
| |
|
| |
|
@st.cache_resource()
def load_models():
    """Load and cache the RoBERTa encoder, its tokenizer, and the trained head.

    Cached with st.cache_resource (not st.cache_data): models are unhashable
    global resources that must be shared, not serialized and copied per call.

    Returns:
        dict with keys "tokenizer" (RobertaTokenizerFast), "model"
        (RobertaForMaskedLM with its lm_head replaced by Identity, so it
        emits encoder hidden states), and "classifier" (SimpleClassifier
        with weights loaded from twitter_model_91_5-.pth, in eval mode).
    """
    model = RobertaForMaskedLM.from_pretrained("roberta-base")
    # Strip the vocabulary projection: downstream code wants hidden states.
    model.lm_head = nn.Identity()
    model.eval()  # inference only — disable dropout in the encoder
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    my_classifier = SimpleClassifier(768, 768, 1)
    # The original os.path.join(__file__, "..", name) built ".../script.py/../name",
    # which fails POSIX path resolution (a regular file cannot be traversed as
    # a directory). Resolve relative to the script's directory instead.
    weights_path = os.path.join(os.path.dirname(__file__), "twitter_model_91_5-.pth")
    my_classifier.load_state_dict(torch.load(weights_path, map_location=device))
    my_classifier.eval()
    return {
        "tokenizer": tokenizer,
        "model": model,
        "classifier": my_classifier,
    }
| |
|
| |
|
def classify_text(text: str) -> float:
    """Score `text` with the cached RoBERTa encoder + classifier head.

    The text is tokenized (truncated to 128 tokens), encoded, the token
    embeddings are summed into a single 768-d vector, and that vector is
    passed through the sigmoid classifier.

    Returns:
        The classifier's sigmoid output in [0, 1] as a plain Python float.
    """
    models = load_models()
    tokenizer, model, classifier = models["tokenizer"], models["model"], models["classifier"]

    X = tokenizer(
        text,
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )["input_ids"]

    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        # Call model(X), not model.forward(X), so module hooks run.
        # NOTE(review): with lm_head = Identity, output[-1] is assumed to be
        # the encoder hidden states of shape (1, seq, 768) — confirm against
        # the installed transformers version's MaskedLM output ordering.
        X = model(X)[-1][0].sum(dim=0)[None, :]
        # .item() so the declared `-> float` annotation actually holds
        # (the original returned a (1, 1) tensor).
        return classifier(X).item()
| |
|
| |
|
# Run on the GPU when one is present; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| |
|