# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "diffusers>=0.25.0",
#     "transformers>=4.35.0",
#     "accelerate>=0.24.0",
#     "peft>=0.7.0",
#     "bitsandbytes>=0.41.0",
#     "huggingface-hub>=0.20.0",
#     "safetensors>=0.4.0",
#     "omegaconf>=2.3.0",
#     "Pillow>=10.0.0",
#     "numpy>=1.24.0",
#     "tqdm>=4.66.0",
# ]
# ///
"""Resume FLUX.2-klein-4B LoRA training from the step-500 checkpoint.

Output: Limbicnation/pixel-art-lora
"""

import os
import sys
from pathlib import Path

import torch
from huggingface_hub import create_repo, hf_hub_download, snapshot_download, upload_folder

CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500"
DATASET_REPO = "Limbicnation/sprite-lora-training-data"
OUTPUT_REPO = "Limbicnation/pixel-art-lora"


def main():
    print("=" * 70)
    print("šŸš€ FLUX.2-klein-4B LoRA Training (Resuming from Step 500)")
    print("=" * 70)

    # Download the step-500 LoRA checkpoint
    print("\nšŸ“„ Downloading checkpoint...")
    checkpoint_path = hf_hub_download(
        repo_id=CHECKPOINT_REPO,
        filename="pytorch_lora_weights.safetensors",
        repo_type="model",
        local_dir="./checkpoint_step500",
    )
    print(f"   āœ… Checkpoint: {checkpoint_path}")

    # Download the training dataset (images plus .txt captions)
    print("\nšŸ“„ Downloading dataset...")
    dataset_path = snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir="./training_data",
    )
    image_files = list(Path(dataset_path).rglob("*.png"))
    print(f"   āœ… Dataset: {len(image_files)} images")

    # Clone the trainer repo (no-op if it already exists) and make it importable
    print("\nšŸ“„ Setting up trainer...")
    os.system("git clone https://github.com/Limbicnation/klein-lora-trainer.git 2>/dev/null || true")
    sys.path.insert(0, "./klein-lora-trainer")

    from flux2_klein_trainer.config import TrainingConfig, ModelConfig, LoRAConfig, DatasetConfig
    from flux2_klein_trainer.trainer import KleinLoRATrainer

    # Training configuration: resume from step 500 and continue to step 1000
    config = TrainingConfig(
        model=ModelConfig(
            pretrained_model_name="black-forest-labs/FLUX.2-klein-4B",
            dtype="bfloat16",
            enable_cpu_offload=True,
        ),
        lora=LoRAConfig(rank=64, alpha=128),
        dataset=DatasetConfig(
            data_dir="./training_data/images",
            caption_ext="txt",
            resolution=512,
        ),
        output_dir="./output",
        resume_from_checkpoint="./checkpoint_step500",
        num_train_steps=1000,
        batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,
        optimizer="adamw_8bit",
        save_every=500,
        sample_every=500,
        trigger_word="pixel art sprite",
        push_to_hub=True,
        hub_model_id=OUTPUT_REPO,
    )

    # Make sure the output repo exists before training pushes to it
    print(f"\nšŸ“¤ Output: {OUTPUT_REPO}")
    create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model")

    # Train
    print("\nšŸ‹ļø Starting Training...")
    trainer = KleinLoRATrainer(config)
    trainer.train()

    print("\nāœ… Complete!")
    print(f"šŸ“¤ Model saved to: {OUTPUT_REPO}")


if __name__ == "__main__":
    main()
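
# Usage sketch (assumption: the file is saved as resume_training.py; the name is
# hypothetical). The inline script metadata above lets `uv` resolve the dependencies
# automatically, and huggingface_hub reads the access token from HF_TOKEN:
#
#   HF_TOKEN=hf_xxx uv run resume_training.py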