import os
import shutil
import subprocess
import sys

import gradio as gr
|
|
# Scratch directory that receives a copy of every uploaded training image;
# train_lora wipes and recreates it at the start of each run.
UPLOAD_DIR: str = "training_images"
# Directory passed to the training script as --output_dir; train_lora looks
# for the finished weights there under the name "lora.safetensors".
OUTPUT_DIR: str = "lora_output"
|
|
def _reset_dir(path):
    """Remove *path* if it exists, then recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _store_image(item, dest_dir, idx):
    """Place one uploaded image into *dest_dir* as ``image_{idx}<ext>``.

    Accepted shapes:

    * a path (``str`` / ``os.PathLike``), as yielded by ``gr.File`` with
      ``type="filepath"``. The file is copied and keeps its own (lower-cased)
      extension, so JPEG data is not stored under a ``.png`` name.
    * an object with a callable ``save`` (e.g. a PIL image). It is saved as
      ``image_{idx}.png``, which was the original behaviour.
    * any other object exposing its path as ``.name`` (tempfile-like upload
      wrappers). It is copied like a path.
    """
    if isinstance(item, (str, os.PathLike)):
        src = os.fspath(item)
    elif callable(getattr(item, "save", None)):
        item.save(os.path.join(dest_dir, f"image_{idx}.png"))
        return
    else:
        src = item.name
    ext = os.path.splitext(src)[1].lower() or ".png"
    shutil.copyfile(src, os.path.join(dest_dir, f"image_{idx}{ext}"))


def train_lora(images, learning_rate, num_epochs, rank):
    """Train a LoRA from uploaded images by running ``train_lora.py``.

    The images are staged into ``UPLOAD_DIR``. The script is then run with the
    same Python interpreter as this app, and its result is expected at
    ``OUTPUT_DIR/lora.safetensors``.

    Args:
        images: Uploaded images from ``gr.File(file_count="multiple")``.
            These may be paths or file wrappers; objects with ``save`` are
            accepted too.
        learning_rate: Forwarded as ``--learning_rate``.
        num_epochs: Forwarded as ``--num_epochs``. ``gr.Number`` may deliver a
            float such as ``10.0``, so it is converted to ``int`` first.
        rank: LoRA rank, forwarded as ``--rank`` after ``int`` conversion.

    Returns:
        A user-facing status string (in Portuguese, like the UI). On success it
        holds the model path and the script's stdout. Otherwise it holds an
        error description.
    """
    # Validate before touching the filesystem, so a bad request does not wipe
    # the previous run's images/output.
    if not images:
        return "❌ Nenhuma imagem enviada."
    try:
        lr = float(learning_rate)
        epochs = int(num_epochs)
        lora_rank = int(rank)
    except (TypeError, ValueError, OverflowError):
        return "❌ Parâmetros inválidos: informe learning rate, epochs e rank numéricos."

    # Start from empty directories so files from a previous run cannot leak
    # into this one, and a stale lora.safetensors cannot look like a result.
    _reset_dir(UPLOAD_DIR)
    _reset_dir(OUTPUT_DIR)

    for idx, item in enumerate(images):
        _store_image(item, UPLOAD_DIR, idx)

    cmd = [
        # Use the interpreter running this app so the script sees the same
        # environment; a bare "python" may resolve elsewhere or not exist.
        sys.executable, "train_lora.py",
        "--images_dir", UPLOAD_DIR,
        "--output_dir", OUTPUT_DIR,
        "--learning_rate", str(lr),
        "--num_epochs", str(epochs),
        "--rank", str(lora_rank),
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)

    output_file = os.path.join(OUTPUT_DIR, "lora.safetensors")
    # Also require a clean exit: a crash after a partial write is a failure.
    if result.returncode == 0 and os.path.exists(output_file):
        return f"✅ Treinamento finalizado!\nModelo salvo em: {output_file}\n\nLogs:\n{result.stdout}"
    return (
        f"❌ Erro no treinamento (código {result.returncode}):\n{result.stderr}"
        f"\n\nLogs:\n{result.stdout}"
    )
|
|
# Gradio UI: upload images, choose hyper-parameters, run train_lora and show
# the status message it returns.
with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ Criador & Treinador de LoRA")
    with gr.Row():
        image_input = gr.File(
            file_types=[".png", ".jpg", ".jpeg"],
            file_count="multiple",
            label="Envie suas imagens (10–50)",
        )
    with gr.Row():
        learning_rate = gr.Number(value=1e-4, label="Learning Rate")
        # precision=0 makes Gradio round the value and deliver an int, so the
        # training script gets "10" rather than "10.0" for integer arguments.
        num_epochs = gr.Number(value=10, precision=0, label="Número de Epochs")
        rank = gr.Number(value=4, precision=0, label="Rank do LoRA")
    with gr.Row():
        train_button = gr.Button("🚀 Treinar LoRA")
    output_text = gr.Textbox(label="Saída", lines=15)

    # Register the listener inside the Blocks context (required by newer
    # Gradio versions).
    train_button.click(
        fn=train_lora,
        inputs=[image_input, learning_rate, num_epochs, rank],
        outputs=output_text,
    )


# Serve the app when run as a script or notebook cell; importing the module
# only builds `demo` without starting a server.
if __name__ == "__main__":
    demo.launch()