Coverage for app\engines\starcoder.py: 78%
88 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-30 09:36 +0100
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-30 09:36 +0100
1"""
2StarCoder engine implementation
4Uses StarCoder2-3B (quantized) for code generation, fixing, and refactoring.
5Loaded via llama-cpp-python (GGUF format).
6"""
7import logging
8import re
9from typing import Dict, Any, Optional
10from pathlib import Path
12from app.engines.base import BaseEngine
13from app.models.schemas import TaskType, Language
14from app.config import settings
15from app.utils.localization import get_string
17logger = logging.getLogger(__name__)
class StarCoderEngine(BaseEngine):
    """StarCoder2-3B engine for code tasks.

    The GGUF model is loaded through llama-cpp-python the first time it is
    needed. Each request is turned into a chat-style prompt, sent to the
    model, and the completion is post-processed into the structure returned
    by ``self._format_result`` (not defined in this class; presumably
    inherited from ``BaseEngine``).
    """

    # Completion budget per task; tasks not listed use _DEFAULT_MAX_TOKENS.
    _TASK_MAX_TOKENS = {
        TaskType.FIX: 512,
        TaskType.EXPLAIN: 512,
        TaskType.REFACTOR: 1024,
        TaskType.TEST: 1024,
        TaskType.TRANSLATE: 1024,
        TaskType.BOILERPLATE: 2048,
    }
    _DEFAULT_MAX_TOKENS = 512

    # Tasks whose completion is expected to contain rewritten code.
    _CODE_TASKS = (
        TaskType.FIX,
        TaskType.REFACTOR,
        TaskType.TRANSLATE,
        TaskType.BOILERPLATE,
    )

    # Generation is cut as soon as one of these strings would be produced.
    # The closing code fence is among them, so a completion usually ends
    # *before* its closing fence (see _extract_code_from_response).
    _STOP_SEQUENCES = (
        "\n```\n",
        "```\n\n",
        "\nExample",
        "\nNow fix",
        "\nBuggy code",
        "Exercise",
    )

    def __init__(self):
        """Register the engine under the name "starcoder"; the model is not loaded yet."""
        super().__init__(
            name="starcoder",
            model_path=str(settings.starcoder_path)
        )
        # llama_cpp.Llama instance once initialize() has succeeded, else None.
        self.llm = None

    async def initialize(self):
        """Load the StarCoder model from ``self.model_path``.

        Does nothing when the engine is already initialized.

        Raises:
            FileNotFoundError: the model file does not exist.
            Exception: any error raised while importing or constructing
                ``llama_cpp.Llama`` is logged and re-raised unchanged.
        """
        if self.initialized:
            logger.info("StarCoder already initialized")
            return

        logger.info(f"Loading StarCoder from {self.model_path}")

        try:
            if not Path(self.model_path).exists():
                raise FileNotFoundError(
                    f"Model file not found: {self.model_path}\n"
                    f"Please run: python scripts/download_models.py"
                )
            # Imported lazily so the module can be imported without
            # llama-cpp-python being importable until a model is loaded.
            from llama_cpp import Llama
            self.llm = Llama(
                model_path=self.model_path,
                n_ctx=settings.n_ctx,
                n_threads=settings.n_threads,
                n_batch=512,
                verbose=False
            )
            self.initialized = True
            logger.info("StarCoder loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load StarCoder: {e}")
            raise

    async def process(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str] = None,
        trace: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Run one task through the model and format the outcome.

        Args:
            task: What to do with ``code``; selects the prompt, the token
                budget and how the completion is post-processed.
            code: Source code the task applies to.
            language: Language of the code (target language for TRANSLATE).
            context: Optional extra context added to the prompt.
            trace: Optional error trace added to the prompt.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The value of ``self._format_result(...)``. Errors raised while
            building the prompt, generating or post-processing are logged and
            reported as ``success=False`` rather than propagated.

        Raises:
            Exception: only from ``initialize()``, which runs outside the
                guarded section when the engine is not yet initialized.
        """
        if not self.initialized:
            await self.initialize()

        try:
            prompt = self._build_prompt(task, code, language, context, trace)
            logger.info(f"StarCoder processing {task} for {language}")
            logger.debug(f"Prompt: {prompt[:300]}...")

            # NOTE(review): this call is synchronous and therefore blocks the
            # event loop while generating. Offloading it to a thread would
            # allow concurrent calls into the same Llama object - confirm
            # that is safe (or add a lock) before changing it.
            response = self.llm(
                prompt,
                max_tokens=self._TASK_MAX_TOKENS.get(task, self._DEFAULT_MAX_TOKENS),
                temperature=0.1,
                top_p=0.9,
                top_k=20,
                repeat_penalty=1.2,
                stop=list(self._STOP_SEQUENCES),
                echo=False
            )

            generated_text = response["choices"][0]["text"]

            if task in self._CODE_TASKS:
                result_code = self._extract_code_from_response(generated_text, language)
                explanation = self._extract_explanation(generated_text)
                return self._format_result(
                    success=True,
                    result=result_code or code,  # Return original code if extraction fails
                    explanation=explanation,
                    suggestions=self._generate_suggestions(task)
                )

            if task == TaskType.EXPLAIN:
                return self._format_result(
                    success=True,
                    result=None,
                    explanation=generated_text.strip(),
                    suggestions=[]
                )

            if task == TaskType.TEST:
                test_code = self._extract_code_from_response(generated_text, language)
                return self._format_result(
                    success=True,
                    result=test_code,
                    explanation=get_string("starcoder_test_explanation"),
                    suggestions=self._generate_suggestions(task)
                )

            return self._format_result(
                success=True,
                result=generated_text.strip()
            )
        except Exception as e:
            logger.error(f"StarCoder processing failed: {e}", exc_info=True)
            return self._format_result(
                success=False,
                explanation=get_string("starcoder_error", error=str(e))
            )

    def _build_prompt(self, task: TaskType, code: str, language: Language, context: Optional[str], trace: Optional[str]) -> str:
        """Build the task-specific chat prompt.

        The prompt consists of a system block, a user block holding the task
        instruction, the optional context and trace (each in its own fenced
        block), the code in a fence tagged with ``language.value``, and an
        open assistant turn for the model to complete.

        Args:
            task: Selects the instruction text; unknown tasks get a generic one.
            code: Source code to embed.
            language: Its ``value`` tags the code fence and is named in the
                TEST, TRANSLATE and BOILERPLATE instructions.
            context: Optional extra context; also names what to generate for
                BOILERPLATE.
            trace: Optional error trace.

        Returns:
            The prompt as a single string, parts joined by newlines.
        """
        system_prompt_content = (
            "You are an expert programmer and a helpful coding assistant. "
            "Provide a clear and concise response. "
            f"The user's preferred language for explanations is {settings.language}."
        )
        system_block = f"<|system|>\n{system_prompt_content}\n<|end|>"

        # Without a context there is nothing to name, so do not interpolate
        # the literal "None" into the instruction.
        if context:
            boilerplate_instruction = f"Generate boilerplate code for a {context} in {language.value}."
        else:
            boilerplate_instruction = f"Generate boilerplate code in {language.value}."

        task_instructions = {
            TaskType.FIX: (
                "The following code has an error. Analyze the code and the error trace, then provide a corrected version. "
                "Explain the fix in a comment or before the code block."
            ),
            TaskType.EXPLAIN: "Explain the following code. Describe its purpose, how it works, and any key algorithms or patterns used.",
            TaskType.REFACTOR: "Refactor the following code to improve its readability, performance, or maintainability. Explain the changes made.",
            TaskType.TEST: (f"Generate a comprehensive suite of unit tests for the following {language.value} code "
                            "using a standard testing framework (e.g., pytest for Python, Jest for JavaScript)."),
            TaskType.TRANSLATE: f"Translate the following code snippet from its current language to {language.value}. Preserve logic and comments.",
            TaskType.BOILERPLATE: boilerplate_instruction,
        }
        instruction = task_instructions.get(task, "Process the following code:")

        prompt_parts = [system_block, f"<|user|>\n{instruction}"]

        if context:
            prompt_parts.append(f"\nHere is some additional context and examples:\n```\n{context}\n```")

        if trace:
            prompt_parts.append(f"\nHere is the error trace:\n```\n{trace}\n```")

        # Single braces: the fence must carry the actual language name.
        prompt_parts.append(f"\nHere is the code:\n```{language.value}\n{code}\n```")
        prompt_parts.append("\n<|assistant|>\n")

        return "\n".join(prompt_parts)

    def _extract_code_from_response(self, text: str, language: Language) -> Optional[str]:
        """Extract the first code block from the model's response.

        A fence that is untagged or tagged with ``language.value`` is
        preferred; failing that, the first fence with any tag is used. A
        fence that is opened but never closed is accepted as well, because
        the closing fence is one of the stop sequences and is therefore
        normally missing from the completion.

        Args:
            text: Raw completion text.
            language: Expected language of the code block.

        Returns:
            The stripped code; the whole stripped text when it contains no
            fence at all; ``None`` when nothing usable was found.
        """
        flags = re.DOTALL | re.MULTILINE
        # The body ends at a closing fence at the start of a line, or at the
        # end of the text when the completion was cut before the fence.
        body_and_end = r"\n(.*?)(?:^```|\Z)"
        fence_patterns = (
            r"```(?:" + re.escape(language.value) + r")?[ \t]*" + body_and_end,
            r"```[^\n`]*" + body_and_end,
        )
        for fence_pattern in fence_patterns:
            match = re.search(fence_pattern, text, flags)
            if match:
                return match.group(1).strip() or None

        stripped = text.strip()
        if stripped and "```" not in stripped:
            # No fence anywhere: treat the whole completion as code.
            return stripped

        return None

    def _extract_explanation(self, text: str) -> Optional[str]:
        """Extract the explanation, i.e. the text before the first code block.

        Args:
            text: Raw completion text.

        Returns:
            The stripped text preceding the first code fence, or ``None``
            when there is no fence (the whole text is then treated as code by
            ``_extract_code_from_response``) or nothing precedes it.
        """
        before, fence, _ = text.partition("```")
        if not fence:
            return None
        return before.strip() or None

    def _generate_suggestions(self, task: TaskType) -> list:
        """Return the localized follow-up suggestions for ``task``.

        Args:
            task: Task that was processed.

        Returns:
            One ``get_string`` result per suggestion key registered for the
            task; an empty list for tasks without suggestions.
        """
        suggestion_keys = {
            TaskType.FIX: ["starcoder_suggestion_fix_1", "starcoder_suggestion_fix_2"],
            TaskType.REFACTOR: ["starcoder_suggestion_refactor_1", "starcoder_suggestion_refactor_2"],
            TaskType.TRANSLATE: ["starcoder_suggestion_translate_1", "starcoder_suggestion_translate_2"],
            TaskType.BOILERPLATE: ["starcoder_suggestion_boilerplate_1", "starcoder_suggestion_boilerplate_2"],
            TaskType.TEST: ["starcoder_suggestion_test_1", "starcoder_suggestion_test_2"]
        }
        return [get_string(key) for key in suggestion_keys.get(task, [])]

    async def shutdown(self):
        """Drop the model reference and mark the engine as uninitialized."""
        logger.info("Shutting down StarCoder engine")
        self.llm = None
        self.initialized = False