import logging
import os

import torch
from dotenv import load_dotenv
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Configure logging once at import time; every log line below goes through
# this module-level logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Pull configuration from a local .env file into the process environment.
load_dotenv()


class FraudDetector:
    """Fraud-text classifier backed by a Hugging Face sequence-classification model.

    Configuration precedence is constructor argument first, then environment
    variable (loaded via python-dotenv):

        MODEL_NAME               model repo id (default "austinb/fraud_text_detection")
        HUGGINGFACEHUB_API_TOKEN auth token for private models (optional)
        LOW_THRESHOLD            scores below this are "Low Risk" (default 0.3)
        HIGH_THRESHOLD           scores at/above this are "High Risk" (default 0.7)
        MAX_LENGTH               tokenizer truncation length (default 512)
        FRAUD_LABEL_INDEX        manual override for the fraud logit index
    """

    def __init__(self, model_name=None, hf_token=None):
        """Resolve configuration and eagerly load the model.

        Args:
            model_name: Hugging Face model id; falls back to MODEL_NAME env var.
            hf_token: HF auth token; falls back to HUGGINGFACEHUB_API_TOKEN env var.

        Raises:
            ValueError: if no model name could be resolved at all.
            Exception: re-raised from _load_model on any download/load failure.
        """
        self.model_name = model_name or os.getenv(
            "MODEL_NAME", "austinb/fraud_text_detection"
        )
        self.hf_token = hf_token or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        self.low_threshold = float(os.getenv("LOW_THRESHOLD", 0.3))
        self.high_threshold = float(os.getenv("HIGH_THRESHOLD", 0.7))
        self.max_length = int(os.getenv("MAX_LENGTH", 512))
        self.tokenizer = None
        self.model = None
        # Index into the logits vector that corresponds to the "fraud" class;
        # detected from the model config in _load_model.
        self.fraud_index = None

        # NOTE(review): with the getenv default above, model_name can never be
        # falsy here — this guard only fires if the default is removed. Kept as
        # a belt-and-braces check.
        if not self.model_name:
            raise ValueError(
                "MODEL_NAME not provided and not found in environment variables"
            )

        self._load_model()

    def _load_model(self):
        """Download tokenizer + model, switch to eval mode, and detect the fraud label index.

        Detection order: a label containing "fraud" (case-insensitive) or named
        "LABEL_1" in the model's id2label map; otherwise fall back to index 1
        (the conventional positive class for binary classifiers); finally an
        explicit FRAUD_LABEL_INDEX env var wins over both.
        """
        try:
            logger.info(f"Loading model: {self.model_name}")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                token=self.hf_token
            )
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                token=self.hf_token
            )
            self.model.eval()

            # Detect fraud label index from model config
            id2label = self.model.config.id2label
            logger.info(f"Model labels: {id2label}")
            for idx, label in id2label.items():
                if "fraud" in label.lower() or label == "LABEL_1":
                    self.fraud_index = idx
                    break

            # Fallback: assume index 1 is fraud for binary classifiers
            if self.fraud_index is None:
                self.fraud_index = 1
                logger.warning(
                    f"Could not detect fraud label from {list(id2label.values())}. "
                    f"Defaulting to index 1. Set FRAUD_LABEL_INDEX in .env to override."
                )

            # Allow manual override via env
            env_override = os.getenv("FRAUD_LABEL_INDEX")
            if env_override is not None:
                self.fraud_index = int(env_override)
                logger.info(
                    f"Fraud label index overridden by env: {self.fraud_index}"
                )

            logger.info(
                f"Model loaded. Fraud index: {self.fraud_index} "
                f"(label: {id2label.get(self.fraud_index, 'unknown')})"
            )
        except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            raise

    def _tokenize(self, texts):
        """Shared tokenizer call with consistent settings."""
        return self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=self.max_length
        )

    def get_fraud_score(self, text: str) -> float:
        """Return the fraud-class probability (0..1) for a single text."""
        inputs = self._tokenize(text)
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1)
        return probs[0][self.fraud_index].item()

    def get_fraud_scores(self, texts: list) -> list:
        """Return fraud-class probabilities for a batch of texts, in input order."""
        inputs = self._tokenize(texts)
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.softmax(outputs.logits, dim=1)
        return probs[:, self.fraud_index].tolist()

    def risk_label(self, score: float) -> str:
        """Map a fraud score onto the configured low/high risk bands."""
        if score < self.low_threshold:
            return "Low Risk"
        elif score < self.high_threshold:
            return "Medium Risk"
        else:
            return "High Risk 🚨"

    def predict(self, text: str) -> dict:
        """Score one text and return {text, fraud_score, risk_level}."""
        score = self.get_fraud_score(text)
        preview = text[:50] + ("..." if len(text) > 50 else "")
        result = {
            "text": text,
            "fraud_score": round(score, 4),
            "risk_level": self.risk_label(score)
        }
        logger.info(
            f"Prediction for '{preview}': {result['risk_level']} ({result['fraud_score']})"
        )
        return result

    def analyze(self, text: str) -> dict:
        """Returns fraud score + risk level + binary detection in one call."""
        score = self.get_fraud_score(text)
        # Binary decision uses the high threshold only.
        is_fraud = score >= self.high_threshold
        preview = text[:50] + ("..." if len(text) > 50 else "")
        result = {
            "text": text,
            "fraud_score": round(score, 4),
            "risk_level": self.risk_label(score),
            "is_fraud": is_fraud,
            "detection": "Fraud Detected 🚨" if is_fraud else "No Fraud Detected ✅"
        }
        logger.info(
            f"Analyze for '{preview}': {result['detection']} | "
            f"{result['risk_level']} ({result['fraud_score']})"
        )
        return result

    def predict_batch(self, texts: list) -> list:
        """Batch predict with consistent logging."""
        scores = self.get_fraud_scores(texts)
        results = []
        for text, score in zip(texts, scores):
            preview = text[:50] + ("..." if len(text) > 50 else "")
            risk = self.risk_label(score)
            logger.info(
                f"Batch prediction for '{preview}': {risk} ({round(score, 4)})"
            )
            results.append({
                "text": text,
                "fraud_score": round(score, 4),
                "risk_level": risk
            })
        return results


# Example Usage
if __name__ == "__main__":
    try:
        detector = FraudDetector()
        sample_text = "User transferred ₹50,000 to an unknown account at midnight"
        result = detector.predict(sample_text)
        print("\nPrediction Result:")
        print(result)
    except Exception as e:
        print(f"Error: {e}")