| import os |
| import dataclasses |
| from typing import Optional, Tuple |
|
|
@dataclasses.dataclass
class ModelConfig:
    """Architecture settings for the audio-language model.

    Attributes:
        audio_model_id: Identifier of the audio-side checkpoint
            (defaults to ``"openai/whisper-medium"``).
        text_model_id: Identifier of the text-side checkpoint
            (defaults to ``"sarvamai/sarvam-m"``).
        hidden_size: Presumably the hidden width of the component that
            bridges audio and text features -- confirm against the model code.
        projector_act: Name of the activation used by the projector
            (``"gelu"`` by default); how the name is resolved is not shown here.
        stack_factor: Presumably the number of consecutive audio frames
            stacked together before projection -- TODO confirm.
    """

    audio_model_id: str = "openai/whisper-medium"
    text_model_id: str = "sarvamai/sarvam-m"
    hidden_size: int = 2048
    projector_act: str = "gelu"
    stack_factor: int = 8

    def to_dict(self) -> dict:
        """Return every field as a plain ``{name: value}`` dict (via ``dataclasses.asdict``)."""
        as_plain_dict = dataclasses.asdict(self)
        return as_plain_dict
|
|
def _env_field(name: str, default: Optional[str] = None, **field_kwargs):
    """Build a dataclass field whose default is read from env var *name*.

    The lookup runs each time an instance is created (``default_factory``),
    not once when the class body executes at import time. Environment changes
    made after this module is imported are therefore still honored.
    Extra keyword arguments are forwarded to ``dataclasses.field``.
    """
    return dataclasses.field(
        default_factory=lambda: os.getenv(name, default), **field_kwargs
    )


@dataclasses.dataclass
class TrainConfig:
    """Hyperparameters and run settings for training.

    Fields backed by environment variables (``HUB_MODEL_ID``, ``HUB_TOKEN``,
    ``WANDB_PROJECT``, ``WANDB_ENTITY``) are resolved when an instance is
    created. Values passed explicitly to the constructor take precedence.

    Raises:
        ValueError: from ``__post_init__`` when a field holds a value that
            cannot be valid (see ``__post_init__``).
    """

    # --- Batching, precision and data loading ---
    batch_size: int = 8
    accum_steps: int = 2  # presumably gradient-accumulation steps -- confirm with the trainer
    use_bf16: bool = True
    gradient_checkpointing: bool = False
    dataloader_num_workers: int = 8
    dataloader_pin_memory: bool = True

    # --- Optimization schedule ---
    learning_rate: float = 1e-4
    lr_scheduler_type: str = "cosine"
    num_epochs: int = 1
    # Not validated below: some trainers use non-positive values as a sentinel.
    max_steps: int = 10000

    # --- Output ---
    output_dir: str = "./checkpoints"

    # --- Dataset selection ---
    dataset_name: str = "fixie-ai/common_voice_17_0"
    dataset_subset: str = "hi"
    dataset_split: str = "train"
    val_dataset_split: str = "validation"

    # --- LoRA adapter settings ---
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05

    # --- Hub publishing ---
    push_to_hub: bool = False
    hub_model_id: Optional[str] = _env_field("HUB_MODEL_ID")
    # repr=False keeps the credential out of printed/logged configs.
    hub_token: Optional[str] = _env_field("HUB_TOKEN", repr=False)
    hub_private_repo: bool = True

    # --- Weights & Biases ---
    wandb_project: str = _env_field("WANDB_PROJECT", "audio-language-model")
    wandb_entity: Optional[str] = _env_field("WANDB_ENTITY")
    wandb_run_name: Optional[str] = None
    # Kept as the strings "false"/..., presumably because they are forwarded
    # as W&B environment-style settings -- confirm with the training script.
    wandb_watch: str = "false"
    wandb_log_model: str = "false"

    # --- Reproducibility, logging and checkpoint cadence (in steps) ---
    seed: int = 42
    log_steps: int = 10
    eval_steps: int = 250
    save_steps: int = 500
    save_total_limit: int = 1
    sample_pred_every_steps: int = 250

    def __post_init__(self) -> None:
        """Reject values that can never form a valid training run."""
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.accum_steps < 1:
            raise ValueError(f"accum_steps must be >= 1, got {self.accum_steps}")
        if self.dataloader_num_workers < 0:
            raise ValueError(
                f"dataloader_num_workers must be >= 0, got {self.dataloader_num_workers}"
            )
        if not 0.0 <= self.lora_dropout <= 1.0:
            raise ValueError(
                f"lora_dropout must be within [0, 1], got {self.lora_dropout}"
            )

    def to_dict(self) -> dict:
        """Return the config as a plain dict, mirroring ``ModelConfig.to_dict``.

        ``hub_token`` is omitted so the result is safe to log or upload as run
        metadata. Passing the dict back to ``TrainConfig(**d)`` re-reads the
        token from ``HUB_TOKEN``.
        """
        data = dataclasses.asdict(self)
        data.pop("hub_token", None)
        return data
|
|