Coverage for app \ engines \ codet5.py: 16%

91 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-30 09:36 +0100

1""" 

2CodeT5 engine implementation 

3 

4Uses CodeT5-small for code explanation and translation. 

5Loaded via HuggingFace transformers. 

6""" 

7import logging 

8from typing import Dict, Any, Optional 

9from pathlib import Path 

10 

11from app.engines.base import BaseEngine 

12from app.models.schemas import TaskType, Language 

13from app.config import settings 

14from app.utils.localization import get_string 

15 

16logger = logging.getLogger(__name__) 

17 

18 

19class CodeT5Engine(BaseEngine): 

20 """CodeT5-small engine for explanations and translation""" 

21 

22 def __init__(self): 

23 super().__init__( 

24 name="codet5", 

25 model_path=str(settings.codet5_path) 

26 ) 

27 self.tokenizer = None 

28 self.model_instance = None 

29 self.torch = None 

30 

31 async def initialize(self): 

32 """Load CodeT5 model""" 

33 if self.initialized: 

34 logger.info("CodeT5 already initialized") 

35 return 

36 

37 logger.info(f"Loading CodeT5 from {self.model_path}") 

38 

39 try: 

40 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 

41 import torch 

42 self.torch = torch 

43 

44 model_name = "Salesforce/codet5-small" 

45 if Path(self.model_path).exists(): 

46 model_name = str(self.model_path) 

47 else: 

48 logger.warning(f"Local model not found at {self.model_path}, using default: {model_name}") 

49 

50 self.tokenizer = AutoTokenizer.from_pretrained(model_name) 

51 self.model_instance = AutoModelForSeq2SeqLM.from_pretrained(model_name) 

52 self.model_instance.eval() 

53 

54 if self.torch.cuda.is_available(): 

55 self.model_instance = self.model_instance.cuda() 

56 logger.info("CodeT5 loaded on GPU") 

57 else: 

58 logger.info("CodeT5 loaded on CPU") 

59 

60 self.initialized = True 

61 logger.info("CodeT5 loaded successfully") 

62 

63 except Exception as e: 

64 logger.error(f"Failed to load CodeT5: {e}") 

65 raise 

66 

67 async def process( 

68 self, 

69 task: TaskType, 

70 code: str, 

71 language: Language, 

72 context: Optional[str] = None, 

73 trace: Optional[str] = None, 

74 **kwargs 

75 ) -> Dict[str, Any]: 

76 """Process task with CodeT5""" 

77 if not self.initialized: 

78 await self.initialize() 

79 

80 try: 

81 prompt = self._build_codet5_prompt(task, code, language, context, trace) 

82 logger.info(f"CodeT5 processing {task} for {language}") 

83 logger.debug(f"Prompt: {prompt[:200]}...") 

84 

85 inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True) 

86 if self.torch.cuda.is_available(): 

87 inputs = {k: v.cuda() for k, v in inputs.items()} 

88 

89 with self.torch.no_grad(): 

90 outputs = self.model_instance.generate( 

91 **inputs, 

92 max_length=settings.max_tokens, 

93 temperature=settings.temperature, 

94 num_beams=2, 

95 early_stopping=True 

96 ) 

97 generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) 

98 

99 if task == TaskType.EXPLAIN: 

100 return self._format_result( 

101 success=True, 

102 explanation=generated_text.strip(), 

103 suggestions=self._get_explanation_suggestions() 

104 ) 

105 elif task == TaskType.TRANSLATE: 

106 return self._format_result( 

107 success=True, 

108 result=generated_text.strip(), 

109 explanation=get_string("codet5_translate_explanation"), 

110 suggestions=[get_string("codet5_translate_suggestion")] 

111 ) 

112 else: 

113 return self._format_result(success=True, result=generated_text.strip()) 

114 

115 except Exception as e: 

116 logger.error(f"CodeT5 processing failed: {e}", exc_info=True) 

117 return self._format_result( 

118 success=False, 

119 explanation=get_string("codet5_error", error=str(e)) 

120 ) 

121 

122 def _build_codet5_prompt( 

123 self, 

124 task: TaskType, 

125 code: str, 

126 language: Language, 

127 context: Optional[str], 

128 trace: Optional[str] 

129 ) -> str: 

130 """Build an improved, task-specific prompt for CodeT5.""" 

131 

132 base_instruction = f"As an expert programmer, please perform the following task in {settings.language}." 

133 

134 if task == TaskType.EXPLAIN: 

135 if trace: 

136 instruction = ( 

137 f"{base_instruction} Explain the root cause of the following error trace " 

138 f"in the context of the provided {language.value} code." 

139 ) 

140 return f"{instruction}\n\nError Trace:\n{trace}\n\nCode:\n{code}" 

141 else: 

142 instruction = ( 

143 f"{base_instruction} Provide a concise summary of the following " 

144 f"{language.value} code. Describe its purpose and functionality." 

145 ) 

146 return f"{instruction}\n\nCode:\n{code}" 

147 

148 elif task == TaskType.TRANSLATE: 

149 target_language = "the target language" 

150 if context: 

151 match = re.search(r"to (\w+)", context, re.IGNORECASE) 

152 if match: 

153 target_language = match.group(1) 

154 

155 instruction = f"Translate the following {language.value} code to {target_language}." 

156 return f"{instruction}\n\n{code}" 

157 

158 else: 

159 return f"Process the following {language.value} code:\n{code}" 

160 

161 def _get_explanation_suggestions(self) -> list: 

162 """Get suggestions for explanation tasks using localized strings.""" 

163 return [ 

164 get_string("codet5_explanation_suggestion_1"), 

165 get_string("codet5_explanation_suggestion_2") 

166 ] 

167 

168 async def shutdown(self): 

169 """Cleanup CodeT5""" 

170 logger.info("Shutting down CodeT5 engine") 

171 if self.model_instance: 

172 del self.model_instance 

173 self.model_instance = None 

174 if self.tokenizer: 

175 del self.tokenizer 

176 self.tokenizer = None 

177 if self.torch: 

178 del self.torch 

179 self.torch = None 

180 self.initialized = False