| import os |
| from pathlib import Path |
| from huggingface_hub import hf_hub_download |
| import torch |
|
|
def download_smpl_model(
    model_path: str = "models/smpl",
    *,
    filenames=("SMPL_NEUTRAL.pkl", "SMPL_MALE.pkl", "SMPL_FEMALE.pkl"),
):
    """Best-effort download of SMPL model files from the Hugging Face Hub.

    Each name in *filenames* is fetched from the repository named by the
    ``SMPL_MODEL_REPO`` environment variable (default ``"nghorbani/smpl"``)
    into *model_path*. A failure on any single file is printed and skipped
    so the remaining files are still attempted.

    Args:
        model_path: Target directory; created (with parents) if missing.
        filenames: Keyword-only. Iterable of file names to fetch from the
            repository. Defaults to the neutral, male and female SMPL pickles.

    Returns:
        True if at least one of *filenames* exists in *model_path* after the
        attempt (files that were already present count), otherwise False.

    Raises:
        OSError: If *model_path* cannot be created (``mkdir`` runs before the
            error-reporting ``try`` block).
    """
    model_dir = Path(model_path)
    model_dir.mkdir(parents=True, exist_ok=True)

    repo_id = os.getenv("SMPL_MODEL_REPO", "nghorbani/smpl")
    # Materialize once: the names are iterated twice (download pass, then the
    # existence check), so a one-shot iterator would be exhausted by the first.
    filenames = tuple(filenames)

    try:
        print(f"Attempting to download SMPL model from {repo_id}...")

        for filename in filenames:
            try:
                # NOTE(review): `local_dir_use_symlinks` is deprecated and
                # ignored by recent huggingface_hub releases (they always write
                # real files into `local_dir`); kept for older releases.
                # Confirm against the pinned huggingface_hub version.
                hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    local_dir=str(model_dir),
                    local_dir_use_symlinks=False,
                )
                print(f"Downloaded {filename}")
            except Exception as e:
                # Best effort: the repo may not host every file (or may be
                # unreachable); report it and keep trying the others.
                print(f"Could not download {filename}: {e}")

        missing = [name for name in filenames if not (model_dir / name).exists()]
        if len(missing) == len(filenames):
            print("Warning: No SMPL model files found. You may need to download them manually.")
            print("Visit https://smpl.is.tue.mpg.de/ to register and download SMPL models.")
            return False
        if missing:
            # Partial success still returns True, but tell the user what is absent.
            print(f"Warning: missing SMPL model files: {', '.join(missing)}")

        return True
    except Exception as e:
        # Top-level boundary: this helper reports failure via its return value
        # and printed guidance rather than crashing the caller.
        print(f"Error downloading SMPL model: {e}")
        print("You may need to download SMPL models manually from https://smpl.is.tue.mpg.de/")
        return False
|
|
if __name__ == "__main__":
    # Turn the boolean result into the process exit status so shell scripts
    # and CI can detect a failed download (0 = success, 1 = nothing found).
    raise SystemExit(0 if download_smpl_model() else 1)
|
|
|
|