| | """ |
| | SHIVIK-Code Model Implementation |
| | |
| | This is a modified version of SHIVIK-M4 with: |
| | - Extended context length (32K) via YaRN RoPE scaling |
| | - Tool calling capabilities |
| | - Fill-in-the-Middle support |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import LlamaForCausalLM, LlamaConfig |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| | from typing import Optional, Tuple, List, Union |
| |
|
| |
|
class ShivikCodeConfig(LlamaConfig):
    """Configuration for the SHIVIK-Code model.

    Extends ``LlamaConfig`` with defaults for a 32K context window via
    YaRN RoPE scaling, plus the special token IDs used to delimit tool
    calls and tool results.

    Args:
        vocab_size: Tokenizer vocabulary size.
        hidden_size: Transformer hidden dimension.
        intermediate_size: MLP intermediate dimension.
        num_hidden_layers: Number of decoder layers.
        num_attention_heads: Number of attention heads.
        num_key_value_heads: Number of KV heads (GQA).
        max_position_embeddings: Maximum context length (32768).
        rope_theta: RoPE base frequency.
        rope_scaling: RoPE scaling dict; defaults to YaRN with factor 8
            over 4096 original positions when None.
        tool_call_start_id / tool_call_end_id: Token IDs delimiting a
            tool call emitted by the model (None when unset).
        tool_result_start_id / tool_result_end_id: Token IDs delimiting
            a tool result fed back to the model (None when unset).
    """

    model_type = "shivik_code"

    def __init__(
        self,
        vocab_size=128279,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=16,
        num_attention_heads=32,
        num_key_value_heads=8,
        max_position_embeddings=32768,
        rope_theta=500000.0,
        rope_scaling=None,
        tool_call_start_id=None,
        tool_call_end_id=None,
        tool_result_start_id=None,
        tool_result_end_id=None,
        **kwargs,
    ):
        # Default to YaRN scaling: 4096 original positions * factor 8.0
        # = 32768, matching the max_position_embeddings default above.
        if rope_scaling is None:
            rope_scaling = {
                "type": "yarn",
                "factor": 8.0,
                "original_max_position_embeddings": 4096,
            }

        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            **kwargs,
        )

        # Bug fix: these were previously hard-coded to None AFTER
        # super().__init__(), which clobbered any values supplied via
        # kwargs (e.g. when a saved config was reloaded through
        # from_pretrained), so tool-token IDs never survived a save/load
        # round trip.  They are now real constructor parameters.
        self.tool_call_start_id = tool_call_start_id
        self.tool_call_end_id = tool_call_end_id
        self.tool_result_start_id = tool_result_start_id
        self.tool_result_end_id = tool_result_end_id
| |
|
| |
|
class ShivikCodeForCausalLM(LlamaForCausalLM):
    """
    SHIVIK-Code: An agentic coding model.

    Extends LlamaForCausalLM with:
    - Tool calling support (a generate loop that detects tool calls,
      executes them, and splices the results back into the context)
    - Extended context via YaRN (configured in ``ShivikCodeConfig``)
    - FIM capability

    NOTE: ``_extract_tool_call`` and ``_format_tool_result`` require a
    tokenizer and must be provided by a tokenizer-aware subclass or
    wrapper before ``generate_with_tools`` can actually execute tools.
    """

    config_class = ShivikCodeConfig

    def __init__(self, config: ShivikCodeConfig):
        super().__init__(config)

        # Cache the tool-delimiter token IDs from the config.  Any of
        # these may be None when the corresponding special token has not
        # been configured.
        self.tool_tokens = {
            "call_start": config.tool_call_start_id,
            "call_end": config.tool_call_end_id,
            "result_start": config.tool_result_start_id,
            "result_end": config.tool_result_end_id,
        }

    def is_tool_call(self, token_id: int) -> bool:
        """Return True if ``token_id`` is a tool-call delimiter token."""
        # Guard added: when the delimiter IDs are unconfigured (None), a
        # None token_id would previously have matched spuriously.
        if token_id is None:
            return False
        return token_id in (
            self.tool_tokens["call_start"],
            self.tool_tokens["call_end"],
        )

    def generate_with_tools(
        self,
        input_ids: torch.Tensor,
        tool_executor,
        max_new_tokens: int = 512,
        max_tool_calls: int = 10,
        **generate_kwargs,
    ):
        """
        Generate with automatic tool execution.

        Args:
            input_ids: Input token IDs
            tool_executor: Function that takes tool call JSON and returns result
            max_new_tokens: Max tokens per generation step
            max_tool_calls: Max number of tool calls allowed

        Returns:
            Full generated sequence including tool results

        Raises:
            NotImplementedError: if a tool call is detected but the
                tokenizer-dependent helpers have not been implemented.
        """
        current_ids = input_ids
        tool_call_count = 0

        while tool_call_count < max_tool_calls:
            # Generate until the model closes a tool call or the token
            # budget for this step runs out.  NOTE(review): recent
            # transformers versions require a ``tokenizer=`` kwarg for
            # ``stop_strings`` to work — callers must supply it through
            # generate_kwargs; confirm against the pinned version.
            outputs = self.generate(
                current_ids,
                max_new_tokens=max_new_tokens,
                stop_strings=["</tool_call>"],
                **generate_kwargs,
            )

            generated = outputs[0]

            if self._contains_tool_call(generated):
                # Execute the tool and splice its result back into the
                # running context, then keep generating.
                tool_call = self._extract_tool_call(generated)
                tool_result = tool_executor(tool_call)

                result_tokens = self._format_tool_result(tool_result)
                current_ids = torch.cat([generated, result_tokens], dim=-1)
                tool_call_count += 1
            else:
                # No (further) tool call: generation is complete.
                return generated

        # Tool-call budget exhausted; return the sequence accumulated so
        # far without a final untruncated generation pass.
        return current_ids

    def _contains_tool_call(self, token_ids: torch.Tensor) -> bool:
        """Check if the sequence contains a completed tool call.

        Implemented (was a silent ``pass`` stub that returned None,
        making ``generate_with_tools`` degrade to plain generation):
        the presence of the call-end delimiter token marks a complete
        call, since generation stops on "</tool_call>".
        """
        end_id = self.tool_tokens["call_end"]
        if end_id is None:
            # Tool tokens not configured: behave as plain generation.
            return False
        return bool((token_ids == end_id).any().item())

    def _extract_tool_call(self, token_ids: torch.Tensor) -> dict:
        """Extract the tool-call JSON payload from the sequence.

        Requires a tokenizer to decode the span between the call-start
        and call-end tokens.  The original stub silently returned None,
        which would have handed None to ``tool_executor``; failing
        loudly is safer than mis-executing.
        """
        raise NotImplementedError(
            "_extract_tool_call requires a tokenizer-aware implementation"
        )

    def _format_tool_result(self, result: str) -> torch.Tensor:
        """Format a tool result string as tokens.

        Requires a tokenizer to encode the result wrapped in the
        result-start/result-end delimiters.  The original stub silently
        returned None, which would have crashed torch.cat downstream.
        """
        raise NotImplementedError(
            "_format_tool_result requires a tokenizer-aware implementation"
        )
| |
|
| |
|
| | |
# Register the custom classes with the Auto* factories so that
# AutoConfig / AutoModelForCausalLM resolve the "shivik_code"
# model_type to the definitions above.
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("shivik_code", ShivikCodeConfig)
AutoModelForCausalLM.register(ShivikCodeConfig, ShivikCodeForCausalLM)
| |
|