from sentence_transformers import SentenceTransformer
import os
from huggingface_hub import hf_hub_download
from joblib import load  # joblib.load restores the pickled classifier

# Hugging Face access token, read from the environment (None when unset).
hf_token = os.getenv('HF_TOKEN')

# Hugging Face model ID and the local locations of both models.
hf_model_id = 'Alibaba-NLP/gte-base-en-v1.5'
model_dir = '/tmp/sentence_transformer'  # Use /tmp for write permissions
clf_model_path = 'models/logistic_regression_model.pkl'

# Make sure the cache directory for the embedding model exists.
os.makedirs(model_dir, exist_ok=True)

# A saved model directory contains config.json; its absence means the model
# has not been downloaded and saved locally yet.
if not os.path.exists(os.path.join(model_dir, 'config.json')):
    print(f"Downloading model '{hf_model_id}' from Hugging Face...")
    # `token` replaces the deprecated `use_auth_token` keyword.
    # trust_remote_code=True allows code shipped with the model repository
    # to be executed while loading.
    model = SentenceTransformer(hf_model_id, token=hf_token, trust_remote_code=True)
    model.save(model_dir)
else:
    print(f"Loading model from local directory: {model_dir}")
    model = SentenceTransformer(model_dir, trust_remote_code=True)

# Module-level classifier used by predict_label(); stays None when the
# pickle file is missing so that predict_label() can report the problem.
clf = None
if os.path.exists(clf_model_path):
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    # only load classifier files that come from a trusted source.
    clf = load(clf_model_path)
    print("Logistic Regression model loaded successfully.")
else:
    print(f"Logistic Regression model not found. Ensure it is saved at '{clf_model_path}'.")


def predict_label(text):
    """Classify *text* with the module-level embedder and classifier.

    Args:
        text: A single input or a list of inputs. A non-list value is
            wrapped in a one-element list before encoding. When a list with
            several items is given, only the result for the first item is
            returned.

    Returns:
        A ``(label, probability)`` tuple where ``label`` is the predicted
        class converted to ``str`` and ``probability`` is the highest class
        probability for that same item as a ``float``. On any failure the
        error is printed and ``("Error", 0.0)`` is returned instead of
        raising.
    """
    try:
        if clf is None:
            raise ValueError("Logistic Regression model is not loaded.")

        # The encoder is always called with a list of inputs.
        if not isinstance(text, list):
            text = [text]

        embeddings = model.encode(text)

        # len() rather than truthiness: an array-valued result has no
        # unambiguous truth value.
        if len(embeddings) == 0:
            raise ValueError("No embeddings generated.")

        prediction = clf.predict(embeddings)
        # Take the probabilities of row 0 so the reported confidence belongs
        # to the same item as the reported label (a global .max() could pick
        # a value from another row when several texts are passed).
        probability = clf.predict_proba(embeddings)[0].max()

        label = str(prediction[0])
        return label, float(probability)
    except Exception as e:
        # Deliberate best-effort boundary: callers receive a sentinel result
        # instead of an exception; the error is printed for debugging.
        print(f"Error in predict_label: {e}")
        return "Error", 0.0