from transformers import PretrainedConfig


class MiniMambaConfig(PretrainedConfig):
    """
    Minimal or extended config class for MiniMamba.
    Inherits from HF's PretrainedConfig so we can do:
        model = MiniMamba.from_pretrained(...)
    and it will load this config automatically.

    This config includes all fields from the provided config.json.
    """

    model_type = "minimamba"
|
    def __init__(
        self,
        # HF metadata
        model_type="minimamba",
        _name_or_path="Mamba_500M",
        architectures=["MiniMamba"],
        # Model architecture
        dim=1024,
        num_layers=54,
        num_heads=32,
        state_dim=128,
        num_groups=1,
        conv_size=4,
        use_mem_eff_path=True,
        dt_bias=True,
        D_has_head_dim=True,
        learnable_init_states=False,
        ssm_chunk_size=256,
        vocab_size=200064,
        ffn_dim_multiplier=2.0,
        multiple_of=256,
        norm_eps=1e-5,
        init_use_depth=False,
        init_base_std=None,
        init_std_factor="disabled",
        hidden_act="silu",
        bias=False,
        # Runtime
        torch_dtype="bfloat16",
        seed=1337,
        # Extra initialization arguments
        init_args=None,
        # Training / distributed settings
        seq_len=8192,
        weight_tying=True,
        dropout=0.0,
        num_epochs=1,
        global_bsz=524288,
        bsz=1,
        warmup_steps=1907,
        eval_period=50,
        save_period=500,
        max_lr=0.0003,
        min_lr=3e-5,
        max_norm=1.0,
        dilation=1,
        fsdp=True,
        ddp=False,
        mixed_precision=True,
        cpu_offload=False,
        sharding_strategy="full_shard",
        state_dict_type="full",
        auto_wrap_policy="partial",
        backward_prefetch="backward_pre",
        forward_prefetch=False,
        sync_module_states=True,
        use_orig_params=True,
        device_id=None,
        precision=None,
        fsdp_modules=None,
        use_activation_checkpointing=True,
        use_attn=False,
        softcap=50.0,
        torch_compile=True,
        **kwargs
    ):
        super().__init__(
            model_type=model_type,
            _name_or_path=_name_or_path,
            architectures=architectures,
            **kwargs
        )
|
        # Model architecture
        self.dim = dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.state_dim = state_dim
        self.num_groups = num_groups
        self.conv_size = conv_size
        self.use_mem_eff_path = use_mem_eff_path
        self.dt_bias = dt_bias
        self.D_has_head_dim = D_has_head_dim
        self.learnable_init_states = learnable_init_states
        self.ssm_chunk_size = ssm_chunk_size
        self.vocab_size = vocab_size
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.init_use_depth = init_use_depth
        self.init_base_std = init_base_std
        self.init_std_factor = init_std_factor
        self.hidden_act = hidden_act
        self.bias = bias

        # Runtime
        self.torch_dtype = torch_dtype
        self.seed = seed

        # Extra initialization arguments (default to an empty dict)
        self.init_args = init_args or {}

        # Training / distributed settings
        self.seq_len = seq_len
        self.weight_tying = weight_tying
        self.dropout = dropout
        self.num_epochs = num_epochs
        self.global_bsz = global_bsz
        self.bsz = bsz
        self.warmup_steps = warmup_steps
        self.eval_period = eval_period
        self.save_period = save_period
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.max_norm = max_norm
        self.dilation = dilation
        self.fsdp = fsdp
        self.ddp = ddp
        self.mixed_precision = mixed_precision
        self.cpu_offload = cpu_offload
        self.sharding_strategy = sharding_strategy
        self.state_dict_type = state_dict_type
        self.auto_wrap_policy = auto_wrap_policy
        self.backward_prefetch = backward_prefetch
        self.forward_prefetch = forward_prefetch
        self.sync_module_states = sync_module_states
        self.use_orig_params = use_orig_params
        self.device_id = device_id
        self.precision = precision
        self.fsdp_modules = fsdp_modules
        self.use_activation_checkpointing = use_activation_checkpointing
        self.use_attn = use_attn
        self.softcap = softcap
        self.torch_compile = torch_compile

        # Keep any remaining keyword arguments around for inspection.
        self.extra_args = kwargs
|
|
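# Minimal usage sketch: it assumes only MiniMambaConfig above plus the standard
# PretrainedConfig / AutoConfig APIs from transformers. The directory name and
# the override values below are illustrative, not part of the original config.
if __name__ == "__main__":
    from transformers import AutoConfig

    # Registering the config lets AutoConfig.from_pretrained resolve
    # model_type="minimamba" back to MiniMambaConfig.
    AutoConfig.register("minimamba", MiniMambaConfig)

    cfg = MiniMambaConfig(dim=512, num_layers=12)  # hypothetical overrides
    cfg.save_pretrained("minimamba_demo")          # writes config.json

    reloaded = AutoConfig.from_pretrained("minimamba_demo")
    print(type(reloaded).__name__, reloaded.dim, reloaded.num_layers)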