# psypredict-backend / download_models.py
# Page-header residue: therandomuser03, "Add application file", commit f0f84fb
import os
import gdown
from huggingface_hub import hf_hub_download
# --- Remote sources ---
# Google Drive file IDs, fetched one file at a time.
MODEL_ID = "10GWSogJNKlPlTeWtJkDq_zc4roB1Vmnu"  # Keras Face Emotion
CSV_ID = "1bJ8C1BY0rvPNKuWcBgqiUtiSzHziZokH"  # Medication CSV

# HF Transformers repos (downloaded via snapshot_download as full directories).
CRISIS_MODEL_REPO = "cross-encoder/nli-MiniLM2-L6-H768"
DISTILBERT_MODEL_REPO = "bhadresh-savani/distilbert-base-uncased-emotion"

# --- Local destinations ---
# Every asset lives under this single directory.
ML_ASSETS = "app/ml_assets"

# Single-file assets coming from Google Drive.
FACE_MODEL_PATH = os.path.join(ML_ASSETS, "emotion_model_trained.h5")
MEDS_CSV_PATH = os.path.join(ML_ASSETS, "MEDICATION.csv")

# Directory targets for the Hugging Face repos above.
CRISIS_MODEL_PATH = os.path.join(ML_ASSETS, "crisis_model")
DISTILBERT_MODEL_PATH = os.path.join(ML_ASSETS, "distilbert_model")
def download_drive_file(file_id, output_path):
    """Fetch a Google Drive file to ``output_path`` unless it already exists.

    Args:
        file_id: Google Drive file ID, interpolated into the ``uc?id=`` URL.
        output_path: Local destination path. Its parent directory is created
            when missing; a bare filename (no directory part) is accepted.
    """
    if os.path.exists(output_path):
        print(f"✅ Found {output_path}, skipping.")
        return

    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    url = f'https://drive.google.com/uc?id={file_id}'
    print(f"⬇️ Downloading Drive file to {output_path}...")
    result = gdown.download(url, output_path, quiet=False)

    # NOTE(review): some gdown releases appear to signal failure by returning
    # None instead of raising -- confirm against the pinned version. Either
    # way, a None result must not pass silently as a successful download.
    if result is None:
        print(f"⚠️ Drive download failed for {output_path}")
def download_hf_model(repo_id, filename, output_path):
    """Download one file from a Hugging Face repo and store it at ``output_path``.

    The file is downloaded into the directory part of ``output_path`` under
    its repo filename and then moved to ``output_path`` if the two differ.
    Nothing is downloaded when ``output_path`` already exists.

    Args:
        repo_id: Hugging Face repository ID.
        filename: Name of the file inside the repository.
        output_path: Final local path for the file. A bare filename (no
            directory part) targets the current directory.
    """
    if os.path.exists(output_path):
        print(f"✅ Found {output_path}, skipping.")
        return

    # dirname is "" for a bare filename; os.makedirs("") raises and an empty
    # local_dir is not a usable target, so fall back to the current directory.
    target_dir = os.path.dirname(output_path) or "."
    os.makedirs(target_dir, exist_ok=True)

    print(f"⬇️ Downloading HF model: {filename} from {repo_id}...")
    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
    )

    # Rename to match our config expectation. Compare normalised paths so
    # "./x" and "x" are recognised as the same file and not moved onto itself.
    downloaded_path = os.path.join(target_dir, filename)
    if os.path.normpath(downloaded_path) != os.path.normpath(output_path):
        os.replace(downloaded_path, output_path)
def download_hf_directory(repo_id, output_dir):
    """Mirror a Hugging Face repo into ``output_dir`` unless it is populated.

    The download is skipped when ``output_dir`` exists and contains at least
    one entry; a missing or empty directory triggers ``snapshot_download``.

    Args:
        repo_id: Hugging Face repository ID to download.
        output_dir: Local directory that receives the repository files.
    """
    from huggingface_hub import snapshot_download

    already_populated = os.path.exists(output_dir) and bool(os.listdir(output_dir))
    if already_populated:
        print(f"✅ Found {output_dir}, skipping.")
        return

    print(f"⬇️ Downloading HF repo: {repo_id} to {output_dir}...")
    # save space, only PyTorch/Safetensors needed
    skipped_weight_formats = ["*.msgpack", "*.h5", "*.ot", "rust_model.ot"]
    snapshot_download(
        repo_id=repo_id,
        local_dir=output_dir,
        local_dir_use_symlinks=False,
        ignore_patterns=skipped_weight_formats,
    )
def main():
    """Download every asset the backend needs and report the real outcome.

    Drive files are fetched first, then each Hugging Face repo. A failure on
    one Hugging Face repo is reported and does not prevent the next repo from
    being attempted. The exit status is left unchanged (best-effort sync).
    """
    print("🚀 Starting Production Model Sync...")

    # 1. Drive Files
    download_drive_file(MODEL_ID, FACE_MODEL_PATH)
    download_drive_file(CSV_ID, MEDS_CSV_PATH)

    # 2. HF Transformers Pipeline Models
    hf_targets = (
        (CRISIS_MODEL_REPO, CRISIS_MODEL_PATH),
        (DISTILBERT_MODEL_REPO, DISTILBERT_MODEL_PATH),
    )
    failed_repos = []
    for repo_id, target_dir in hf_targets:
        # One try per repo: a shared try would skip the remaining repos as
        # soon as the first download raised.
        try:
            download_hf_directory(repo_id, target_dir)
        except Exception as e:
            print(f"⚠️ HF Transformers Download failed: {e}")
            failed_repos.append(repo_id)

    # Only claim success when every Hugging Face download went through.
    if failed_repos:
        print(f"⚠️ Model sync incomplete; failed: {', '.join(failed_repos)}")
    else:
        print("✅ All models synchronized!")


if __name__ == "__main__":
    main()