| | |
| | """ |
| | Download SAM 2 Model Script |
| | |
| | This script downloads the SAM 2 model checkpoints and sets up the environment |
| | for few-shot and zero-shot segmentation experiments. |
| | """ |
| |
|
| | import os |
| | import sys |
| | import requests |
| | import zipfile |
| | from pathlib import Path |
| | import argparse |
| | from tqdm import tqdm |
| |
|
| |
|
def download_file(url: str, destination: str, chunk_size: int = 8192, timeout: float = 60.0):
    """Stream ``url`` into ``destination`` while showing a tqdm progress bar.

    The body is first written to ``<destination>.part`` and renamed onto
    ``destination`` only after the whole stream has been consumed. If anything
    fails midway, the partial file is deleted. As a result, ``destination``
    either holds a complete download or does not exist. This matters because
    callers treat an existing file as "already downloaded".

    Args:
        url: HTTP(S) address to fetch.
        destination: Path of the file to create (replaced if it already exists).
        chunk_size: Bytes requested per iteration of the response stream.
        timeout: Seconds passed to ``requests.get`` as its connect/read
            timeout, so a stalled server cannot hang the script indefinitely.

    Raises:
        requests.HTTPError: The server replied with a 4xx/5xx status.
        requests.RequestException: Connection problems, timeouts, etc.
        OSError: The file could not be written or renamed.
    """
    part_path = f"{destination}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an HTML error page (e.g. a 404) would be saved
            # to disk and reported as a successful checkpoint download.
            response.raise_for_status()
            # 0 when the server omits content-length.
            total_size = int(response.headers.get('content-length', 0))

            with open(part_path, 'wb') as file, tqdm(
                desc=os.path.basename(destination),
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for data in response.iter_content(chunk_size=chunk_size):
                    pbar.update(file.write(data))

        # Only reached once the file is fully written and closed.
        os.replace(part_path, destination)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a truncated file behind.
        if os.path.exists(part_path):
            os.remove(part_path)
        raise
| |
|
| |
|
def _link_default_checkpoint(target, link_name):
    """Create the symlink ``link_name`` -> ``target`` when it is safe and useful to do so."""
    # lexists (not exists) also detects a dangling symlink. Otherwise os.symlink
    # would fail with FileExistsError and be misreported as a Windows limitation.
    if os.path.lexists(link_name):
        return
    if not os.path.exists(target):
        # E.g. the download failed: don't create a link that points nowhere.
        print(f"Skipping symbolic link: {target} not found")
        return
    try:
        os.symlink(target, link_name)
        print(f"Created symbolic link: {link_name} -> {target}")
    except (OSError, NotImplementedError) as e:
        print(f"Could not create symbolic link (this is normal on Windows): {e}")


def setup_sam2_environment():
    """Create working directories, fetch missing SAM 2 checkpoints and link the default one.

    All paths are relative to the current working directory. The function:

    1. creates ``models/checkpoints``, ``data`` and ``results``;
    2. downloads every checkpoint in ``sam2_urls`` that is not already on disk;
    3. creates a ``sam2_checkpoint`` symlink to the ``vit_h`` checkpoint, if
       that file exists and the link name is not already taken.

    Download failures are printed and skipped rather than raised, so one bad
    URL does not abort the rest of the setup.
    """
    print("Setting up SAM 2 environment...")

    os.makedirs("models/checkpoints", exist_ok=True)
    os.makedirs("data", exist_ok=True)
    os.makedirs("results", exist_ok=True)

    # NOTE(review): the vit_h/vit_l/vit_b names follow SAM (v1) conventions, while
    # SAM 2 releases are published as Hiera checkpoints (e.g. sam2_hiera_large.pt).
    # Confirm these URLs actually resolve.
    sam2_urls = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_h.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_l.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything_2/sam2_b.pth",
    }

    for model_name, url in sam2_urls.items():
        checkpoint_path = f"models/checkpoints/sam2_{model_name}.pth"

        if os.path.exists(checkpoint_path):
            print(f"SAM 2 {model_name} checkpoint already exists")
            continue

        print(f"Downloading SAM 2 {model_name} checkpoint...")
        try:
            download_file(url, checkpoint_path)
            print(f"Successfully downloaded {model_name} checkpoint")
        except Exception as e:  # best-effort: report and move on to the next model
            print(f"Failed to download {model_name} checkpoint: {e}")
            print("Please download manually from the SAM 2 repository")

    _link_default_checkpoint("models/checkpoints/sam2_vit_h.pth", "sam2_checkpoint")
| |
|
| |
|
def _pip_install(*pip_args):
    """Run ``<current python> -m pip install *pip_args`` and return True on exit status 0."""
    import subprocess  # local import keeps the module-level import block unchanged

    # sys.executable ensures packages go into the interpreter running this script.
    # A bare `pip` on PATH may belong to a different Python installation.
    cmd = [sys.executable, "-m", "pip", "install", *pip_args]
    result = subprocess.run(cmd, check=False)
    if result.returncode != 0:
        print(f"Warning: '{' '.join(cmd)}' failed with exit code {result.returncode}")
        return False
    return True


def install_dependencies():
    """Install the project's Python dependencies with pip.

    Steps, in order: ``requirements.txt`` (relative to the current working
    directory), SAM 2 from its GitHub repository, then OpenAI CLIP from its
    GitHub repository. A failing step is reported and the remaining steps
    still run.

    Returns:
        True if every pip invocation succeeded, False otherwise.
    """
    print("Installing dependencies...")
    ok = _pip_install("-r", "requirements.txt")

    print("Installing SAM 2...")
    ok &= _pip_install("git+https://github.com/facebookresearch/segment-anything-2.git")

    print("Installing CLIP...")
    ok &= _pip_install("git+https://github.com/openai/CLIP.git")

    return ok
| |
|
| |
|
def create_demo_data():
    """Lay out empty ``images``/``masks`` folders for each demo dataset under ``data/``.

    Only the directory skeleton is created; no files are written. Directories
    that already exist are left untouched.
    """
    print("Creating demo data...")

    for domain in ("satellite", "fashion", "robotics"):
        dataset_root = "data/" + domain + "_demo"
        for subfolder in ("images", "masks"):
            os.makedirs(dataset_root + "/" + subfolder, exist_ok=True)

    print("Demo data directories created. Run experiments to generate dummy data.")
| |
|
| |
|
def main(argv=None):
    """Command-line entry point: install dependencies, fetch checkpoints, create demo dirs.

    Args:
        argv: Argument list to parse instead of the process command line.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``, so
            running the script is unchanged. Passing a list lets the setup be
            driven programmatically or from tests.
    """
    parser = argparse.ArgumentParser(description="Set up SAM 2 environment")
    parser.add_argument("--skip-download", action="store_true",
                        help="Skip downloading SAM 2 checkpoints")
    parser.add_argument("--skip-install", action="store_true",
                        help="Skip installing dependencies")
    parser.add_argument("--demo-only", action="store_true",
                        help="Only create demo data directories")

    args = parser.parse_args(argv)

    if args.demo_only:
        create_demo_data()
        return

    if not args.skip_install:
        install_dependencies()

    if not args.skip_download:
        setup_sam2_environment()

    create_demo_data()

    print("\nSetup complete!")
    print("\nNext steps:")
    print("1. Run few-shot satellite experiment:")
    print(" python experiments/few_shot_satellite.py --sam2_checkpoint sam2_checkpoint --data_dir data/satellite_demo")
    print("\n2. Run zero-shot fashion experiment:")
    print(" python experiments/zero_shot_fashion.py --sam2_checkpoint sam2_checkpoint --data_dir data/fashion_demo")
    print("\n3. Check the results/ directory for experiment outputs")
| |
|
| |
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing the module has no side effects.
    # main() returns None, so the process exits with status 0 exactly as before.
    raise SystemExit(main())