1"""
2CodeT5 engine implementation
4Uses CodeT5-small for code explanation and translation.
5Loaded via HuggingFace transformers.
6"""
import logging
import re
from pathlib import Path
from typing import Any, Dict, Optional

from app.config import settings
from app.engines.base import BaseEngine
from app.models.schemas import Language, TaskType
from app.utils.localization import get_string
16logger = logging.getLogger(__name__)
class CodeT5Engine(BaseEngine):
    """CodeT5-small engine for code explanation and translation.

    Wraps a HuggingFace seq2seq checkpoint. Heavy ML dependencies
    (transformers, torch) are imported lazily inside initialize() so the
    engine object can be constructed cheaply even when they are absent.
    """

    def __init__(self):
        super().__init__(
            name="codet5",
            model_path=str(settings.codet5_path)
        )
        # Populated by initialize(); all None until the model is loaded.
        self.tokenizer = None
        self.model_instance = None
        self.torch = None  # module handle kept so other methods avoid re-importing

    async def initialize(self):
        """Load the CodeT5 tokenizer and model.

        Prefers the local checkpoint at ``self.model_path`` and falls back
        to the ``Salesforce/codet5-small`` hub model when the path does not
        exist. Idempotent: returns immediately if already initialized.

        Raises:
            Exception: any failure from transformers/torch loading is
                logged and re-raised to the caller.
        """
        if self.initialized:
            logger.info("CodeT5 already initialized")
            return

        logger.info(f"Loading CodeT5 from {self.model_path}")

        try:
            # Lazy imports: keep module import side-effect free and make the
            # dependency optional until an engine is actually used.
            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
            import torch
            self.torch = torch

            model_name = "Salesforce/codet5-small"
            if Path(self.model_path).exists():
                model_name = str(self.model_path)
            else:
                logger.warning(f"Local model not found at {self.model_path}, using default: {model_name}")

            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model_instance = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            self.model_instance.eval()  # inference only: disable dropout etc.

            if self.torch.cuda.is_available():
                self.model_instance = self.model_instance.cuda()
                logger.info("CodeT5 loaded on GPU")
            else:
                logger.info("CodeT5 loaded on CPU")

            self.initialized = True
            logger.info("CodeT5 loaded successfully")

        except Exception as e:
            logger.error(f"Failed to load CodeT5: {e}")
            raise

    async def process(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str] = None,
        trace: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Run a task through CodeT5 and return a formatted result dict.

        Lazily initializes the model on first use. Errors are caught and
        returned as a failed result rather than raised.

        Args:
            task: task type; EXPLAIN and TRANSLATE get tailored output,
                anything else returns the raw generation.
            code: source code to process.
            language: language of ``code``.
            context: optional extra context (e.g. "to Java" for translation).
            trace: optional error trace used by EXPLAIN prompts.

        Returns:
            Result dict produced by ``self._format_result``.
        """
        if not self.initialized:
            await self.initialize()

        try:
            prompt = self._build_codet5_prompt(task, code, language, context, trace)
            logger.info(f"CodeT5 processing {task} for {language}")
            logger.debug(f"Prompt: {prompt[:200]}...")

            # 512 is the CodeT5 encoder's maximum input length; truncate
            # rather than fail on long inputs.
            inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            if self.torch.cuda.is_available():
                # Inputs must live on the same device as the model.
                inputs = {k: v.cuda() for k, v in inputs.items()}

            with self.torch.no_grad():
                # NOTE(review): temperature has no effect with beam search
                # unless do_sample=True; transformers only warns. Kept as-is
                # to preserve current behavior — confirm intent.
                outputs = self.model_instance.generate(
                    **inputs,
                    max_length=settings.max_tokens,
                    temperature=settings.temperature,
                    num_beams=2,
                    early_stopping=True
                )
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            if task == TaskType.EXPLAIN:
                return self._format_result(
                    success=True,
                    explanation=generated_text.strip(),
                    suggestions=self._get_explanation_suggestions()
                )
            elif task == TaskType.TRANSLATE:
                return self._format_result(
                    success=True,
                    result=generated_text.strip(),
                    explanation=get_string("codet5_translate_explanation"),
                    suggestions=[get_string("codet5_translate_suggestion")]
                )
            else:
                return self._format_result(success=True, result=generated_text.strip())

        except Exception as e:
            logger.error(f"CodeT5 processing failed: {e}", exc_info=True)
            return self._format_result(
                success=False,
                explanation=get_string("codet5_error", error=str(e))
            )

    def _build_codet5_prompt(
        self,
        task: TaskType,
        code: str,
        language: Language,
        context: Optional[str],
        trace: Optional[str]
    ) -> str:
        """Build a task-specific prompt string for CodeT5.

        EXPLAIN prompts differ depending on whether an error trace is
        supplied; TRANSLATE tries to extract the target language from a
        "to <lang>" phrase in ``context``; any other task gets a generic
        instruction.
        """
        base_instruction = f"As an expert programmer, please perform the following task in {settings.language}."

        if task == TaskType.EXPLAIN:
            if trace:
                instruction = (
                    f"{base_instruction} Explain the root cause of the following error trace "
                    f"in the context of the provided {language.value} code."
                )
                return f"{instruction}\n\nError Trace:\n{trace}\n\nCode:\n{code}"
            else:
                instruction = (
                    f"{base_instruction} Provide a concise summary of the following "
                    f"{language.value} code. Describe its purpose and functionality."
                )
                return f"{instruction}\n\nCode:\n{code}"

        elif task == TaskType.TRANSLATE:
            target_language = "the target language"
            if context:
                # BUGFIX: `re` was used here without being imported, so any
                # TRANSLATE call with context raised NameError. Fixed by
                # importing `re` at module level.
                match = re.search(r"to (\w+)", context, re.IGNORECASE)
                if match:
                    target_language = match.group(1)

            instruction = f"Translate the following {language.value} code to {target_language}."
            return f"{instruction}\n\n{code}"

        else:
            return f"Process the following {language.value} code:\n{code}"

    def _get_explanation_suggestions(self) -> list:
        """Return localized follow-up suggestions for explanation results."""
        return [
            get_string("codet5_explanation_suggestion_1"),
            get_string("codet5_explanation_suggestion_2")
        ]

    async def shutdown(self):
        """Release the model, tokenizer, and torch handle.

        Drops references so the (potentially GPU-resident) model can be
        garbage-collected, then marks the engine uninitialized so a later
        initialize() can reload it.
        """
        logger.info("Shutting down CodeT5 engine")
        if self.model_instance:
            del self.model_instance
            self.model_instance = None
        if self.tokenizer:
            del self.tokenizer
            self.tokenizer = None
        if self.torch:
            del self.torch
            self.torch = None
        self.initialized = False