Mobile MoE Architecture Search
Collection
32 MoE models from 41 experiments exploring expert count, routing, and learning rates for mobile deployment.
•
33 items
•
Updated
Mobile-optimized MoE model configuration from architecture search.
"""Load the mobile-optimized MoE checkpoint for inference.

Requires the project-local model code (``architecture.model``) plus the
``config.json`` and ``model.safetensors`` files exported by the
architecture-search run.
"""
# Imports grouped at the top per PEP 8 (stdlib / third-party / local).
import json

from safetensors.torch import load_file
from transformers import AutoTokenizer

from architecture.model import Qwen3Model  # project-local custom model class

# Read the architecture hyperparameters exported alongside the weights.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

# Build the model skeleton from the config, then fill in trained weights.
model = Qwen3Model(config)
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout / training-only behavior

# Tokenizer is pulled from the published Hub repo for this experiment.
tokenizer = AutoTokenizer.from_pretrained(
    "kshitijthakkar/moe-202m-104m-12x2-10L-medium-300m-12exp"
)
{
"vocab_size": 151936,
"emb_dim": 512,
"n_heads": 8,
"n_layers": 10,
"n_kv_groups": 2,
"num_experts": 12,
"num_experts_per_tok": 2,
"moe_hidden_dim": 640,
"head_dim": 64,
"max_position_embeddings": 4096,
"rope_base": 1000000.0,
"qk_norm": true
}
{
"model_config": {
"vocab_size": 151936,
"emb_dim": 512,
"n_heads": 8,
"n_layers": 10,
"n_kv_groups": 2,
"num_experts": 12,
"num_experts_per_tok": 2,
"moe_hidden_dim": 640,
"head_dim": 64,
"max_position_embeddings": 4096,
"rope_base": 1000000.0,
"qk_norm": true
},
"learning_rate": 0.0001,
"batch_size": 4,
"context_length": 1024,
"warmup_ratio": 0.1,
"warmup_steps": null,
"weight_decay": 0.1,
"gradient_clip": 1.0,
"gradient_accumulation_steps": 1,
"scheduler_type": "cosine",
"wsd_decay_ratio": 0.1,
"max_steps": 2000,
"eval_steps": 500,
"eval_batches": 20,
"log_steps": 100,
"early_stopping": true,
"early_stopping_patience": 500,
"early_stopping_min_delta": 0.01,
"early_stopping_min_steps": 200,
"track_expert_balance": true,
"expert_balance_log_steps": 100,
"use_wandb": true,
"wandb_project": "moe-architecture-search",
"wandb_entity": null,
"wandb_tags": [
"medium_300m_12exp",
"architecture-search"
],
"train_data_path": null,
"val_data_path": null,
"output_dir": null,
"experiment_name": "medium_300m_12exp",
"device": "cuda",
"dtype": "bfloat16",
"gradient_checkpointing": true,
"architecture_name": "medium_300m_12exp",
"mobile_estimate": {
"tok_per_sec_fp16": 44.38417320458442,
"tok_per_sec_q8": 73.9736220076407,
"tok_per_sec_q4": 103.56307081069697,
"ttft_ms_fp16": 73.59339054545455,
"ttft_ms_q8": 49.06226036363637,
"ttft_ms_q4": 39.2498082909091,
"memory_mb_fp16": 443.9814453125,
"memory_mb_q8": 266.291357421875,
"memory_mb_q4": 167.79599609374998,
"total_params": 202381824,
"active_params": 104077824,
"meets_ttft_target": true,
"meets_throughput_target": true,
"meets_memory_target": true
}
}