| | |
| | |
| | import pandas as pd |
| | import torch |
| | from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast |
| | from typing import Literal |
| | import os |
| | from datetime import datetime |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_model_and_tokenizer(model_dir="./past_ref_classifier/updated_model"):
    """Load the DistilBERT tokenizer and sequence classifier saved in ``model_dir``.

    The classifier is put into evaluation mode before it is handed back.

    Args:
        model_dir: directory passed to both ``from_pretrained`` calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir)
    # nn.Module.eval() returns the module itself, so it can be chained here.
    classifier = DistilBertForSequenceClassification.from_pretrained(model_dir).eval()
    return tokenizer, classifier
| |
|
@torch.no_grad()
def classify_prompts(df, tokenizer, model, max_length=128,
                     device="cuda" if torch.cuda.is_available() else "cpu",
                     batch_size=32, threshold=0.5):
    """
    Run the classifier over every row of ``df["text"]`` and attach predictions.

    Two columns are written onto ``df`` in place, and ``df`` is returned:
      - pred_label: 1 if prob_past >= threshold, else 0
      - prob_past: softmax probability of label=1

    Args:
        df: DataFrame with a 'text' column. Missing values (NaN/None) are
            classified as the empty string instead of crashing the tokenizer.
        tokenizer: called with a list of strings plus ``truncation``,
            ``padding``, ``max_length`` and ``return_tensors`` keyword
            arguments; must return a mapping of tensors.
        model: classifier whose output exposes ``.logits`` shaped
            (batch, num_labels) with num_labels >= 2.
        max_length: tokens per prompt; longer prompts are truncated.
        device: device the model and the inputs are moved to.
        batch_size: number of prompts per forward pass.
        threshold: probability cutoff for pred_label=1.

    Raises:
        ValueError: if batch_size < 1.
        KeyError: if df has no 'text' column.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model.to(device)
    # read_csv yields NaN for empty cells; the tokenizer only accepts strings.
    texts = ["" if pd.isna(txt) else str(txt) for txt in df["text"]]
    total = len(texts)

    pred_labels = []
    prob_pasts = []
    for start in range(0, total, batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        logits = model(**inputs).logits
        # Column 1 is label=1, reported as prob_past.
        batch_probs = torch.softmax(logits, dim=-1)[:, 1].tolist()
        prob_pasts.extend(batch_probs)
        pred_labels.extend(int(p >= threshold) for p in batch_probs)

        # Report progress at every multiple of 50 prompts completed.
        done = start + len(batch)
        for milestone in range((start // 50 + 1) * 50, done + 1, 50):
            print(f"Classified {milestone}/{total} prompts")

    df["pred_label"] = pred_labels
    df["prob_past"] = prob_pasts
    return df
| |
|
| |
|
def read_txt_as_dataframe(txt_input):
    """
    Build a one-column ('text') DataFrame from a text file or a raw string.

    If ``txt_input`` names an existing file, that file is read as UTF-8.
    Otherwise the string itself is used as the content. Each non-empty line,
    stripped of surrounding whitespace, becomes one row.

    Two lines are dropped if present:
      - a lone "[" first line, but only when more than one line is present;
      - a lone "]" last line.
    This is presumably to accept a bracketed, one-item-per-line listing.

    NOTE(review): quotes and trailing commas on item lines are kept verbatim,
    so a pasted Python/JSON list yields rows like '"foo",' -- confirm intended.
    """
    if not os.path.isfile(txt_input):
        content = txt_input
    else:
        with open(txt_input, 'r', encoding='utf-8') as handle:
            content = handle.read()

    # Strip each line; filter(None, ...) then discards the ones left empty.
    rows = list(filter(None, (piece.strip() for piece in content.splitlines())))

    if len(rows) > 1 and rows[0] == "[":
        rows = rows[1:]
    if rows and rows[-1] == "]":
        rows = rows[:-1]

    return pd.DataFrame(rows, columns=['text'])
| |
|
# Accepted values for run_tagging(mode=...). The "*_file" and "*_file_path"
# variant of each format are handled identically by run_tagging.
AllowedMode = Literal["txt_file_path", "txt_file", "csv_file_path", "csv_file"]
# Accepted values for run_tagging(out_as_a_df_variable=...).
AllowedOut = Literal[True, False]


def run_tagging(mode: AllowedMode, data_or_path="", out_dir=".", prefix="data",
                out_as_a_df_variable: AllowedOut = False,
                model_dir="./past_ref_classifier/updated_model_3"):
    """
    Load prompts, classify them, print a preview and save the results as CSV.

    Args:
        mode: 'csv_file'/'csv_file_path' read ``data_or_path`` with
            ``pd.read_csv``; 'txt_file'/'txt_file_path' pass it to
            ``read_txt_as_dataframe``.
        data_or_path: input consumed according to ``mode``.
        out_dir: directory for the output CSV; created if it does not exist.
            An empty string means the current working directory.
        prefix: output file name prefix; the file is named
            ``<prefix>_<YYYYmmdd_HHMMSS>.csv``.
        out_as_a_df_variable: if truthy, also return the results DataFrame.
        model_dir: directory holding the tokenizer and model to load.

    Returns:
        The results DataFrame when ``out_as_a_df_variable`` is truthy,
        ``0`` for an unrecognised ``mode``, otherwise None.

    Raises:
        KeyError: if the loaded data has no 'text' column.
    """
    if mode in ("csv_file", "csv_file_path"):
        df = pd.read_csv(data_or_path)
    elif mode in ("txt_file_path", "txt_file"):
        df = read_txt_as_dataframe(data_or_path)
    else:
        # The 0 return code is kept for existing callers, but is no longer silent.
        print(f"Unsupported mode {mode!r}; expected one of "
              "'txt_file_path', 'txt_file', 'csv_file_path', 'csv_file'.")
        return 0

    # Fail fast, before loading the model, if there is nothing to classify.
    if "text" not in df.columns:
        raise KeyError(
            f"input data has no 'text' column (found: {list(df.columns)})"
        )

    tokenizer, model = load_model_and_tokenizer(model_dir=model_dir)
    df_results = classify_prompts(df, tokenizer, model)

    print("\nFirst 20 inference results:\n")
    print(df_results.head(20).to_string(index=False))

    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    # os.path.join avoids writing to "/" when out_dir is "" and handles any
    # platform separator.
    full_path = os.path.join(out_dir, f"{prefix}_{ts}.csv")
    df_results.to_csv(full_path, index=False)
    # Report the full path: the file is not in the CWD when out_dir != ".".
    print(f"\nSaved full results (with pred_label and prob_past) to {full_path}")

    if out_as_a_df_variable:
        return df_results
    return None
| |
|
| |
|
if __name__ == "__main__":
    menu = "Please select a running mode:\n\n1. Txt file path\n2. Csv file path\n\n"
    try:
        run_mode = int(input(menu))
    except ValueError:
        run_mode = None  # non-numeric input falls through to the error below

    if run_mode == 1:
        path_to_txt = input("Please provide path to the txt file\n").strip()
        # read_txt_as_dataframe treats a non-existent path as literal text, so
        # check here rather than silently classifying a mistyped path string.
        if os.path.isfile(path_to_txt):
            run_tagging(mode="txt_file_path", data_or_path=path_to_txt)
        else:
            print(f"File not found: {path_to_txt}")
    elif run_mode == 2:
        path_to_csv = input("Please provide path to the csv file\n").strip()
        run_tagging(mode="csv_file_path", data_or_path=path_to_csv)
    else:
        print("Invalid selection: please enter 1 or 2.")