| """ |
| Sheikh-2.5-Coder Model Implementation |
| ==================================== |
| |
| This module implements the Sheikh-2.5-Coder model architecture, a 3B parameter |
| transformer model optimized for code generation and on-device deployment. |
| """ |
|
|
import json
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    TrainingArguments,
)
|
|
@dataclass
class SheikhConfig:
    """Hyperparameters for the Sheikh-2.5-Coder model.

    A plain dataclass, so ``vars(config)`` is a flat, JSON-serializable dict.
    ``__post_init__`` validates the attention geometry. An inconsistent
    config therefore fails at construction time instead of deep inside a
    forward pass.

    Raises:
        ValueError: if ``num_attention_heads`` or ``num_key_value_heads`` is
            not positive, if ``hidden_size`` is not divisible by
            ``num_attention_heads``, or if ``num_attention_heads`` is not
            divisible by ``num_key_value_heads``.
    """

    # Transformer geometry.
    num_attention_heads: int = 16
    # Grouped-query attention: each key/value head serves
    # num_attention_heads // num_key_value_heads query heads.
    num_key_value_heads: int = 2
    hidden_size: int = 3072
    intermediate_size: int = 8192
    num_hidden_layers: int = 36
    vocab_size: int = 50257

    # Rotary positional embedding settings.
    max_position_embeddings: int = 32768
    rope_theta: float = 10000.0

    # Dropout probabilities.
    attention_dropout: float = 0.1
    hidden_dropout: float = 0.1

    # Normalization epsilons.
    layer_norm_epsilon: float = 1e-6
    rms_norm_eps: float = 1e-6

    activation_function: str = "swiglu"

    # Dtype name as a string (e.g. "bfloat16"); keeps the config JSON-serializable.
    torch_dtype: str = "bfloat16"

    use_cache: bool = True

    tie_word_embeddings: bool = True

    def __post_init__(self) -> None:
        """Reject head counts that cannot produce a valid attention layout."""
        if self.num_attention_heads <= 0:
            raise ValueError(
                f"num_attention_heads must be positive, got {self.num_attention_heads}"
            )
        if self.num_key_value_heads <= 0:
            raise ValueError(
                f"num_key_value_heads must be positive, got {self.num_key_value_heads}"
            )
        # head_dim = hidden_size // num_attention_heads must tile hidden_size exactly.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )
        # Every key/value head must serve the same number of query heads.
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
|
|
class SheikhRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension, with a learned scale.

    The statistics are computed in float32 and the result is cast back to the
    dtype of the input.

    Args:
        hidden_size: size of the last dimension (length of the learned scale).
        eps: added to the mean square before taking the reciprocal square root.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learned per-channel gain, initialised to identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_dtype = x.dtype
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normalized = x32 * torch.rsqrt(mean_square + self.eps)
        scaled = self.weight * normalized
        return scaled.to(original_dtype)
|
|
class SheikhRotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) cos/sin tables.

    Precomputes cos/sin tables for positions ``0 .. max_position_embeddings - 1``
    and rebuilds them when a longer sequence is requested.

    Args:
        dim: rotary dimension; one inverse frequency is created per pair of
            channels and duplicated, so the tables have ``dim`` columns for even ``dim``.
        max_position_embeddings: number of positions cached at construction.
        base: RoPE frequency base (theta). Floats such as ``rope_theta`` are accepted.
    """

    def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """(Re)build ``cos_cached`` / ``sin_cached`` with shape (1, 1, seq_len, dim) in ``dtype``.

        Positions and frequencies are combined in float32 even if the module
        was cast to a lower precision. Integer positions above 256 are not
        exactly representable in bfloat16.
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.to(device=device, dtype=torch.float32))
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
        """Return ``(cos, sin)`` for positions ``0 .. seq_len - 1``.

        Args:
            x: reference tensor supplying the output dtype (and device when the
                cache grows). When ``seq_len`` is None, ``x.shape[-2]`` is used as
                the sequence length (assumes x is laid out (..., seq_len, dim) —
                TODO confirm for new callers).
            seq_len: number of positions to return. The cache grows if needed.

        Returns:
            Two tensors of shape (1, 1, seq_len, dim) in ``x.dtype``.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            # Keep the grown table in float32; it is cast to x.dtype on return.
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
|
|
class SheikhAttention(nn.Module):
    """Causal self-attention with grouped-query attention (GQA) and rotary embeddings.

    ``num_attention_heads`` query heads share ``num_key_value_heads`` key/value
    heads. Each key/value head is repeated ``num_key_value_groups`` times before
    the attention product.

    Raises:
        ValueError: at construction, if ``hidden_size`` is not divisible by
            ``num_attention_heads`` or ``num_attention_heads`` is not divisible by
            ``num_key_value_heads``.
    """

    def __init__(self, config: SheikhConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_heads})"
            )
        if self.num_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary_emb = SheikhRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    @staticmethod
    def _causal_mask(
        attention_mask: Optional[torch.Tensor], q_len: int, query: torch.Tensor
    ) -> torch.Tensor:
        """Combine a lower-triangular causal constraint with an optional caller mask.

        A bool mask follows the SDPA convention (True = may attend) and is AND-ed
        with the causal mask. Any other mask is treated as an additive bias:
        it is cast to the query dtype and ``-inf`` is added above the diagonal.
        """
        causal = torch.ones(q_len, q_len, dtype=torch.bool, device=query.device).tril()
        if attention_mask is None:
            return causal
        if attention_mask.dtype == torch.bool:
            return attention_mask & causal
        bias = torch.zeros(q_len, q_len, dtype=query.dtype, device=query.device)
        bias = bias.masked_fill(~causal, float("-inf"))
        return attention_mask.to(query.dtype) + bias

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        """Apply causal self-attention to ``hidden_states`` of shape (bsz, q_len, hidden_size).

        Args:
            hidden_states: input activations (bsz, q_len, hidden_size).
            attention_mask: optional mask broadcastable to (bsz, heads, q_len, q_len).
                Bool means True = attend; any other dtype is an additive bias.
                It is always combined with the causal mask.
            position_ids: optional indices into the rotary table. None means 0..q_len-1.
            past_key_value: KV caching is not implemented; must be None.
            output_attentions: also return the attention probabilities.
            use_cache: accepted for interface compatibility; no cache is returned.

        Returns:
            ``(attn_output,)``, or ``(attn_output, attn_weights)`` when
            ``output_attentions`` is true. ``attn_output`` is (bsz, q_len, hidden_size);
            ``attn_weights`` is (bsz, heads, q_len, q_len).

        Raises:
            NotImplementedError: if ``past_key_value`` is provided.
        """
        if past_key_value is not None:
            raise NotImplementedError(
                "SheikhAttention does not support past_key_value (KV caching) yet"
            )

        bsz, q_len, _ = hidden_states.size()

        # Project and split heads -> (bsz, heads, q_len, head_dim).
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # The cos/sin table must cover every position that position_ids looks up,
        # which can exceed q_len (e.g. offset positions). .max() syncs with the device.
        rope_len = q_len
        if position_ids is not None and position_ids.numel() > 0:
            rope_len = max(rope_len, int(position_ids.max()) + 1)
        cos, sin = self.rotary_emb(v, seq_len=rope_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        # repeat_kv works on (batch, seq, kv_heads, head_dim). Move the head axis
        # back to position 2 around the call; passing (batch, kv_heads, seq, head_dim)
        # directly would stretch the sequence axis instead of the head axis.
        k = repeat_kv(k.transpose(1, 2), self.num_key_value_groups).transpose(1, 2)
        v = repeat_kv(v.transpose(1, 2), self.num_key_value_groups).transpose(1, 2)

        # NOTE(review): config.attention_dropout is not applied here (dropout is 0.0).
        attn_weights = None
        if output_attentions:
            # SDPA does not expose attention probabilities, so compute them explicitly.
            mask = self._causal_mask(attention_mask, q_len, q)
            scores = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)
            if mask.dtype == torch.bool:
                scores = scores.masked_fill(~mask, float("-inf"))
            else:
                scores = scores + mask
            attn_weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
            attn_output = torch.matmul(attn_weights, v)
        elif attention_mask is None:
            attn_output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
        else:
            # SDPA rejects attn_mask together with is_causal=True, so fold the
            # causal constraint into the caller's mask instead.
            attn_output = F.scaled_dot_product_attention(
                q, k, v, attn_mask=self._causal_mask(attention_mask, q_len, q), dropout_p=0.0
            )

        # Merge heads back -> (bsz, q_len, hidden_size) and project out.
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if output_attentions:
            return attn_output, attn_weights
        return (attn_output,)
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate grouped key/value heads so each one serves ``n_rep`` query heads.

    Works on tensors laid out as (batch, seq_len, num_key_value_heads, head_dim).
    Returns (batch, seq_len, num_key_value_heads * n_rep, head_dim), with the
    copies of head ``i`` occupying output slots ``i * n_rep`` .. ``i * n_rep + n_rep - 1``.
    When ``n_rep == 1`` the input tensor itself is returned.
    """
    # Unpacking the shape first also rejects non-4-D input, even when n_rep == 1.
    batch, seq_len, kv_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    tiled = hidden_states.unsqueeze(3).repeat(1, 1, 1, n_rep, 1)
    return tiled.reshape(batch, seq_len, kv_heads * n_rep, head_dim)
|
|
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor],
):
    """Rotate query and key tensors by their rotary position angles.

    Args:
        q, k: query/key tensors; assumed (batch, heads, seq_len, head_dim), with the
            cos/sin factors broadcast over the heads axis.
        cos, sin: tables shaped (1, 1, table_len, head_dim).
        position_ids: rows of the tables to use, either (batch, seq_len) or
            (seq_len,) shared by every batch row. None uses rows
            0..table_len-1 in order, which requires table_len == seq_len.

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """

    def rotate_half(x):
        # Map halves (x1, x2) of the last dimension to (-x2, x1).
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    cos = cos.squeeze(1).squeeze(0)  # -> (table_len, head_dim)
    sin = sin.squeeze(1).squeeze(0)

    if position_ids is None:
        # Use every table row in order -> (1, 1, table_len, head_dim).
        cos = cos[None, None]
        sin = sin[None, None]
    else:
        if position_ids.dim() == 1:
            # A shared 1-D position vector applies to every batch row.
            position_ids = position_ids.unsqueeze(0)
        cos = cos[position_ids].unsqueeze(1)  # (batch, 1, seq_len, head_dim)
        sin = sin[position_ids].unsqueeze(1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
|
|
class SheikhMLP(nn.Module):
    """Gated feed-forward network (SwiGLU): ``down(silu(gate(x)) * up(x))``.

    Two parallel bias-free projections expand ``hidden_size`` to
    ``intermediate_size``. One branch is gated with SiLU, and the product is
    projected back to ``hidden_size``.
    """

    def __init__(self, config: SheikhConfig):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        # Construction order (gate, up, down) is kept: it fixes the RNG draws
        # used for default initialisation and the state_dict key order.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.gate_proj(x))
        product = gate * self.up_proj(x)
        return self.down_proj(product)
|
|
class SheikhTransformerBlock(nn.Module):
    """Pre-norm decoder layer: ``x + attn(norm(x))`` followed by ``x + mlp(norm(x))``."""

    def __init__(self, config: SheikhConfig):
        super().__init__()
        self.self_attn = SheikhAttention(config)
        self.mlp = SheikhMLP(config)
        self.input_layernorm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        """Run one decoder layer.

        All keyword arguments are forwarded unchanged to ``self_attn``.

        Returns:
            The updated hidden states (a tensor, not a tuple). Anything
            ``self_attn`` returns beyond its first element (e.g. attention
            weights) is discarded.
        """
        # Attention sub-layer. self_attn returns a tuple whose length depends on
        # output_attentions, so take element 0 rather than unpacking a fixed arity.
        attn_outputs = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_outputs[0]

        # Feed-forward sub-layer.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        return hidden_states
|
|
class SheikhModel(PreTrainedModel):
    """Sheikh-2.5-Coder base decoder: token embeddings -> transformer blocks -> final RMSNorm.

    No language-modeling head is defined here; ``forward`` returns hidden states.
    """

    def __init__(self, config: SheikhConfig):
        # NOTE(review): transformers' PreTrainedModel.__init__ expects a PretrainedConfig
        # instance, while SheikhConfig is a plain dataclass. Confirm this constructs
        # with the installed transformers version.
        super().__init__(config)
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [SheikhTransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Re-initialize every Linear / Embedding submodule with the scheme below.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights: N(0, 0.02) for Linear/Embedding weights, zeros for Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the decoder stack and return the final normalized hidden states.

        Args:
            input_ids: token ids; mutually exclusive with ``inputs_embeds``.
            attention_mask: forwarded unchanged to every layer.
            position_ids: forwarded unchanged to every layer.
            past_key_values: KV caching is not implemented; must be None.
            inputs_embeds: precomputed input embeddings, used instead of ``input_ids``.
            use_cache, output_attentions, return_dict: accepted for interface
                compatibility and currently ignored (no cache, attention
                weights, or output dataclass is produced).
            output_hidden_states: also return the hidden states entering each
                layer plus the final normalized output.

        Returns:
            ``(last_hidden_state,)``, or ``(last_hidden_state, all_hidden_states)``
            when ``output_hidden_states`` is true.

        Raises:
            ValueError: if neither or both of ``input_ids`` / ``inputs_embeds`` are given.
            NotImplementedError: if ``past_key_values`` is provided.
        """
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")
        if past_key_values is not None:
            raise NotImplementedError(
                "SheikhModel does not support past_key_values (KV caching) yet"
            )

        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        all_hidden_states = () if output_hidden_states else None
        for layer in self.layers:
            if all_hidden_states is not None:
                all_hidden_states += (hidden_states,)
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

        hidden_states = self.norm(hidden_states)

        if all_hidden_states is not None:
            all_hidden_states += (hidden_states,)
            return hidden_states, all_hidden_states
        return (hidden_states,)
|
|
| |
def load_sheikh_model(
    model_name_or_path: str,
    device_map: Optional[str] = "auto",
    torch_dtype: torch.dtype = torch.bfloat16,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
    """Load a causal-LM checkpoint and its tokenizer, optionally quantized with bitsandbytes.

    Args:
        model_name_or_path: Hub id or local path of the checkpoint.
        device_map: passed through to ``from_pretrained`` (``"auto"`` by default).
        torch_dtype: dtype for non-quantized weights; also the compute dtype
            of 4-bit quantized layers.
        load_in_8bit: quantize weights to 8 bits.
        load_in_4bit: quantize weights to 4-bit NF4 with double quantization.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if both ``load_in_8bit`` and ``load_in_4bit`` are set.
    """
    if load_in_8bit and load_in_4bit:
        raise ValueError("load_in_8bit and load_in_4bit are mutually exclusive; pick one")

    quantization_config = None
    if load_in_8bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    elif load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            # Follow the caller's dtype rather than forcing bf16 (e.g. fp16-only GPUs).
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map=device_map,
        torch_dtype=torch_dtype,
        quantization_config=quantization_config,
    )

    return model, tokenizer
|
|
| |
def setup_training_args(output_dir: str, learning_rate: float = 1e-4) -> TrainingArguments:
    """Return the default ``TrainingArguments`` for Sheikh-2.5-Coder training.

    Args:
        output_dir: directory for checkpoints and logs.
        learning_rate: learning rate passed to ``TrainingArguments``.

    Mixed precision is bf16 only, matching the ``torch_dtype="bfloat16"`` default
    of ``SheikhConfig``. transformers rejects ``fp16=True`` together with ``bf16=True``.
    Effective batch size per device is 8 x 4 gradient-accumulation steps = 32.

    NOTE(review): per the transformers docs, a positive ``max_steps`` overrides
    ``num_train_epochs``. ``eval_steps`` only takes effect once an evaluation
    strategy of "steps" is configured, which is not done here; confirm intent.
    ``report_to="wandb"`` requires the wandb package.
    """
    return TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        max_steps=100000,
        logging_steps=100,
        save_steps=2000,
        eval_steps=1000,
        warmup_steps=2000,
        fp16=False,
        bf16=True,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        report_to="wandb",
        run_name="sheikh-2.5-coder",
    )
|
|
if __name__ == "__main__":
    # Build the default configuration and a model instance from it.
    # NOTE(review): SheikhModel subclasses transformers' PreTrainedModel while
    # SheikhConfig is a plain dataclass; confirm this construction succeeds.
    config = SheikhConfig()
    model = SheikhModel(config)

    # Persist the hyperparameters as config.json in the working directory.
    with open("config.json", "w") as config_file:
        json.dump(vars(config), config_file, indent=2)

    print("Sheikh-2.5-Coder model configuration created successfully!")
    total_params = sum(param.numel() for param in model.parameters())
    print(f"Model parameters: {total_params:,}")