Coverage for app/engines/starcoder.py: 78%

88 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-30 09:36 +0100

1""" 

2StarCoder engine implementation 

3 

4Uses StarCoder2-3B (quantized) for code generation, fixing, and refactoring. 

5Loaded via llama-cpp-python (GGUF format). 

6""" 

7import logging 

8import re 

9from typing import Dict, Any, Optional 

10from pathlib import Path 

11 

12from app.engines.base import BaseEngine 

13from app.models.schemas import TaskType, Language 

14from app.config import settings 

15from app.utils.localization import get_string 

16 

17logger = logging.getLogger(__name__) 

18 

19 

class StarCoderEngine(BaseEngine):
    """StarCoder2-3B engine for code generation, fixing, and refactoring.

    The model is loaded in GGUF format via llama-cpp-python and held as a
    single in-process instance on ``self.llm``.
    """

    def __init__(self):
        super().__init__(
            name="starcoder",
            model_path=str(settings.starcoder_path)
        )
        # Set by initialize(); None until the model has been loaded.
        self.llm = None

    async def initialize(self):
        """Load the StarCoder model from disk.

        Raises:
            FileNotFoundError: if the GGUF file is not present at
                ``self.model_path``.
        """
        if self.initialized:
            logger.info("StarCoder already initialized")
            return

        logger.info(f"Loading StarCoder from {self.model_path}")

        try:
            if not Path(self.model_path).exists():
                raise FileNotFoundError(
                    f"Model file not found: {self.model_path}\n"
                    f"Please run: python scripts/download_models.py"
                )
            # Lazy import: llama_cpp is only required once a model is loaded,
            # so the module itself can be imported without it installed.
            from llama_cpp import Llama
            self.llm = Llama(
                model_path=self.model_path,
                n_ctx=settings.n_ctx,
                n_threads=settings.n_threads,
                n_batch=512,
                verbose=False
            )
            self.initialized = True
            logger.info("StarCoder loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load StarCoder: {e}")
            raise

    async def process(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str] = None,
        trace: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Run a single task through StarCoder and format the result.

        Args:
            task: The operation to perform (fix, explain, refactor, ...).
            code: The user's source code (or description, for boilerplate).
            language: Target programming language of the code.
            context: Optional extra context/examples to include in the prompt.
            trace: Optional error trace (used for FIX tasks).

        Returns:
            A result dict from ``self._format_result``; on failure ``success``
            is False and ``explanation`` carries a localized error message.
        """
        if not self.initialized:
            await self.initialize()

        try:
            prompt = self._build_prompt(task, code, language, context, trace)
            logger.info(f"StarCoder processing {task} for {language}")
            logger.debug(f"Prompt: {prompt[:300]}...")

            # Per-task token budgets: generation-heavy tasks get more room.
            task_max_tokens = {
                TaskType.FIX: 512,
                TaskType.EXPLAIN: 512,
                TaskType.REFACTOR: 1024,
                TaskType.TEST: 1024,
                TaskType.TRANSLATE: 1024,
                TaskType.BOILERPLATE: 2048
            }
            max_tokens = task_max_tokens.get(task, 512)

            response = self.llm(
                prompt,
                max_tokens=max_tokens,
                temperature=0.1,
                top_p=0.9,
                top_k=20,
                repeat_penalty=1.2,
                # Stop sequences keep the model from inventing follow-up exercises.
                stop=["\n```\n", "```\n\n", "\nExample", "\nNow fix", "\nBuggy code", "Exercise"],
                echo=False
            )

            generated_text = response["choices"][0]["text"]

            if task in [TaskType.FIX, TaskType.REFACTOR, TaskType.TRANSLATE, TaskType.BOILERPLATE]:
                result_code = self._extract_code_from_response(generated_text, language)
                explanation = self._extract_explanation(generated_text)
                return self._format_result(
                    success=True,
                    result=result_code or code,  # Return original code if extraction fails
                    explanation=explanation,
                    suggestions=self._generate_suggestions(task)
                )
            elif task == TaskType.EXPLAIN:
                return self._format_result(
                    success=True,
                    result=None,
                    explanation=generated_text.strip(),
                    suggestions=[]
                )
            elif task == TaskType.TEST:
                test_code = self._extract_code_from_response(generated_text, language)
                return self._format_result(
                    success=True,
                    result=test_code,
                    explanation=get_string("starcoder_test_explanation"),
                    suggestions=self._generate_suggestions(task)
                )
            else:
                return self._format_result(
                    success=True,
                    result=generated_text.strip()
                )
        except Exception as e:
            logger.error(f"StarCoder processing failed: {e}", exc_info=True)
            return self._format_result(
                success=False,
                explanation=get_string("starcoder_error", error=str(e))
            )

    def _build_prompt(self, task: TaskType, code: str, language: Language, context: Optional[str], trace: Optional[str]) -> str:
        """Build a task-specific prompt in the <|system|>/<|user|>/<|assistant|> chat format."""

        system_prompt_content = (
            "You are an expert programmer and a helpful coding assistant. "
            "Provide a clear and concise response. "
            f"The user's preferred language for explanations is {settings.language}."
        )

        # FIX: dropped a no-op .replace('\n', '\n') that had no effect.
        system_block = f"<|system|>\n{system_prompt_content}\n<|end|>"

        task_instructions = {
            TaskType.FIX: (
                "The following code has an error. Analyze the code and the error trace, then provide a corrected version. "
                "Explain the fix in a comment or before the code block."
            ),
            TaskType.EXPLAIN: "Explain the following code. Describe its purpose, how it works, and any key algorithms or patterns used.",
            TaskType.REFACTOR: "Refactor the following code to improve its readability, performance, or maintainability. Explain the changes made.",
            TaskType.TEST: (f"Generate a comprehensive suite of unit tests for the following {language.value} code "
                            "using a standard testing framework (e.g., pytest for Python, Jest for JavaScript)."),
            TaskType.TRANSLATE: f"Translate the following code snippet from its current language to {language.value}. Preserve logic and comments.",
            TaskType.BOILERPLATE: f"Generate boilerplate code for a {context} in {language.value}."
        }

        instruction = task_instructions.get(task, "Process the following code:")

        prompt_parts = []
        prompt_parts.append(system_block)
        prompt_parts.append(f"<|user|>\n{instruction}")

        if context:
            prompt_parts.append(f"\nHere is some additional context and examples:\n```\n{context}\n```")

        if trace:
            prompt_parts.append(f"\nHere is the error trace:\n```\n{trace}\n```")

        # FIX: the original used doubled braces ({{language.value}}), which put
        # the literal text "{language.value}" into the prompt instead of the
        # actual language name. Single braces interpolate correctly.
        prompt_parts.append(f"\nHere is the code:\n```{language.value}\n{code}\n```")
        prompt_parts.append("\n<|assistant|>\n")

        return "\n".join(prompt_parts)

    def _extract_code_from_response(self, text: str, language: Language) -> Optional[str]:
        """Extract the first fenced code block from the model's response.

        Falls back to the raw text when the response contains no fence at all
        (the model sometimes answers with bare code); returns None otherwise.
        """
        pattern = re.compile(r"```(?:" + re.escape(language.value) + r")?\s*\n(.*?)\n```", re.DOTALL)
        match = pattern.search(text)
        if match:
            return match.group(1).strip()

        if text.strip() and "```" not in text:
            return text.strip()

        return None

    def _extract_explanation(self, text: str) -> Optional[str]:
        """Extract explanation text preceding the first code fence, if any."""
        # FIX: pass maxsplit as a keyword — the positional form is deprecated
        # since Python 3.13.
        parts = re.split(r"```.*", text, maxsplit=1)
        if parts and parts[0].strip():
            return parts[0].strip()
        return None

    def _generate_suggestions(self, task: TaskType) -> list:
        """Return localized follow-up suggestions for the given task."""
        suggestion_keys = {
            TaskType.FIX: ["starcoder_suggestion_fix_1", "starcoder_suggestion_fix_2"],
            TaskType.REFACTOR: ["starcoder_suggestion_refactor_1", "starcoder_suggestion_refactor_2"],
            TaskType.TRANSLATE: ["starcoder_suggestion_translate_1", "starcoder_suggestion_translate_2"],
            TaskType.BOILERPLATE: ["starcoder_suggestion_boilerplate_1", "starcoder_suggestion_boilerplate_2"],
            TaskType.TEST: ["starcoder_suggestion_test_1", "starcoder_suggestion_test_2"]
        }
        return [get_string(key) for key in suggestion_keys.get(task, [])]

    async def shutdown(self):
        """Release the model and mark the engine uninitialized."""
        logger.info("Shutting down StarCoder engine")
        if self.llm:
            # Dropping the reference lets llama-cpp free the model memory.
            self.llm = None
        self.initialized = False