| import os |
| from typing import Optional |
| from urllib.request import urlretrieve |
|
|
# Reference implementations used by ``download_original``: each key is the
# filename the download is saved under, each value is the raw gist URL to fetch.
files = {
    "original_model.py": (
        "https://gist.githubusercontent.com/lantiga/fd36849fb1c498da949a0af635318a7b/raw/"
        "7dd20f51c2a1ff2886387f0e25c1750a485a08e1/llama_model.py"
    ),
    "original_adapter.py": (
        "https://gist.githubusercontent.com/awaelchli/546f33fcdb84cc9f1b661ca1ca18418d/raw/"
        "e81d8f35fb1fec53af1099349b0c455fc8c9fb01/original_adapter.py"
    ),
}
|
|
|
|
def download_original(wd: str) -> None:
    """Download every file listed in the module-level ``files`` mapping into ``wd``.

    For each ``(name, url)`` entry, the target path is ``os.path.join(wd, name)``.
    If a regular file already exists there, the download is skipped; otherwise
    the URL is fetched and saved to that path. Progress messages are printed
    to stdout.

    Args:
        wd: Directory the files are saved into. It is created if it does not
            exist yet. An empty string means the current working directory.
    """
    if wd:
        # urlretrieve does not create parent directories.
        os.makedirs(wd, exist_ok=True)
    for file, url in files.items():
        filepath = os.path.join(wd, file)
        if not os.path.isfile(filepath):
            print(f"Downloading original implementation to {filepath!r}")
            # Save to the same path that was checked above. Passing the bare
            # filename would write into the current working directory instead
            # of ``wd``, so the existence check would never find it.
            urlretrieve(url=url, filename=filepath)
            print("Done")
        else:
            print("Original implementation found. Skipping download.")
|
|
|
|
def download_from_hub(repo_id: Optional[str] = None, local_dir: str = "checkpoints/hf-llama/7B") -> None:
    """Fetch a snapshot of a Hugging Face Hub repository into ``local_dir``.

    Args:
        repo_id: Id of the Hub repository to download. The default ``None`` is
            rejected with an error instead of triggering a download.
        local_dir: Destination directory passed to ``snapshot_download``.

    Raises:
        ValueError: If ``repo_id`` is ``None``. This check happens before
            ``huggingface_hub`` is imported.
    """
    if repo_id is not None:
        # Imported lazily, so the dependency is only loaded once a repo id
        # has actually been supplied.
        import huggingface_hub

        huggingface_hub.snapshot_download(repo_id, local_dir=local_dir)
        return

    raise ValueError("Please pass `--repo_id=...`. You can try googling 'huggingface hub llama' for options.")
|
|
|
|
if __name__ == "__main__":
    # Only download_from_hub is exposed on the command line; jsonargparse turns
    # its parameters (repo_id, local_dir) into CLI flags.
    import jsonargparse

    jsonargparse.CLI(download_from_hub)
|
|