| | """ |
| | Universal CodeAct Interactive Demo |
| | Supports: CUDA (NVIDIA), MLX (Apple Silicon), CPU |
| | Auto-detects best available backend |
| | """ |
| | import re |
| | import sys |
| | import os |
| | import argparse |
| | from io import StringIO |
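
# Example invocations (illustrative; assumes this file is saved as demo.py and
# that the chosen backend's dependencies, e.g. mlx-lm or torch + transformers,
# are installed):
#   python demo.py                                   # auto-detect backend
#   python demo.py --backend cuda
#   python demo.py --backend mlx --adapter ./models/codeact-mlx-qwen2.5-3b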


def detect_backend():
    """Auto-detect the best available backend."""
    # Prefer MLX (Apple Silicon) if mlx-lm is installed.
    try:
        import mlx.core as mx  # noqa: F401  (imported only to check availability)
        return "mlx"
    except ImportError:
        pass

    # Next, prefer an NVIDIA GPU via PyTorch.
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"
    except ImportError:
        pass

    # Fall back to Apple's Metal (MPS) backend if PyTorch exposes it.
    try:
        import torch
        if torch.backends.mps.is_available():
            return "mps"
    except (ImportError, AttributeError):
        pass

    return "cpu"


class MLXBackend:
    """Generation backend using mlx-lm on Apple Silicon."""

    def __init__(self, model_name, adapter_path=None):
        from mlx_lm import load, generate
        self.generate_fn = generate

        if adapter_path and os.path.exists(adapter_path):
            print(f"Loading MLX model with adapter: {adapter_path}")
            self.model, self.tokenizer = load(model_name, adapter_path=adapter_path)
        else:
            print(f"Loading MLX model: {model_name}")
            self.model, self.tokenizer = load(model_name)

    def generate(self, prompt, max_tokens=400):
        return self.generate_fn(
            self.model,
            self.tokenizer,
            prompt=prompt,
            max_tokens=max_tokens,
            verbose=False
        )


class PyTorchBackend:
    """Generation backend using PyTorch + transformers (CUDA, MPS, or CPU)."""

    def __init__(self, model_name, device="auto", adapter_path=None):
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        if device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
            elif torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = device

        print(f"Loading PyTorch model on {self.device}: {model_name}")

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )

        # Half precision on GPU/MPS, full precision on CPU.
        dtype = torch.float16 if self.device in ["cuda", "mps"] else torch.float32

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map=self.device if self.device == "cuda" else None,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        # device_map already places the model on CUDA; move it manually otherwise.
        if self.device != "cuda":
            self.model = self.model.to(self.device)

        if adapter_path and os.path.exists(adapter_path):
            try:
                from peft import PeftModel
                print(f"Loading LoRA adapter: {adapter_path}")
                self.model = PeftModel.from_pretrained(self.model, adapter_path)
            except ImportError:
                print("Warning: peft not installed, skipping adapter")

    def generate(self, prompt, max_tokens=400):
        import torch

        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.7,
                do_sample=True,
                top_p=0.95,
                pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens (skip the prompt).
        response = self.tokenizer.decode(
            outputs[0][len(inputs['input_ids'][0]):],
            skip_special_tokens=True
        )
        return response
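
# Both backends expose the same minimal interface (illustrative usage):
#   backend = PyTorchBackend("Qwen/Qwen2.5-3B", device="cpu")
#   text = backend.generate("Hello", max_tokens=32)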


def execute_code(code):
    """Execute Python code and capture output"""
    stdout_buffer = StringIO()
    stderr_buffer = StringIO()
    old_stdout, old_stderr = sys.stdout, sys.stderr

    try:
        sys.stdout = stdout_buffer
        sys.stderr = stderr_buffer
        namespace = {}
        exec(code, namespace)
        output = stdout_buffer.getvalue()
        errors = stderr_buffer.getvalue()
        return {"success": True, "output": output.strip() or None, "error": errors.strip() or None}
    except Exception as e:
        return {"success": False, "output": None, "error": str(e)}
    finally:
        sys.stdout, sys.stderr = old_stdout, old_stderr
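
# Illustrative return shapes (not executed here): execute_code("print(1 + 1)")
# should yield {"success": True, "output": "2", "error": None}, while
# execute_code("1 / 0") should yield {"success": False, "output": None,
# "error": "division by zero"}.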


class CodeActDemo:
    def __init__(self, backend="auto", model_name=None, adapter_path=None):
        # Default model and adapter.
        if model_name is None:
            model_name = "Qwen/Qwen2.5-3B"

        if adapter_path is None:
            adapter_path = "./models/codeact-mlx-qwen2.5-3b"

        # Resolve "auto" to a concrete backend.
        if backend == "auto":
            backend = detect_backend()

        print(f"\n{'='*60}")
        print("CodeAct Interactive Demo")
        print(f"Backend: {backend.upper()}")
        print(f"{'='*60}\n")

        self.backend_name = backend

        if backend == "mlx":
            self.backend = MLXBackend(model_name, adapter_path)
        else:
            self.backend = PyTorchBackend(model_name, device=backend, adapter_path=adapter_path)

        self.tokenizer = self.backend.tokenizer if hasattr(self.backend, 'tokenizer') else None
        self.conversation_history = []

        self.system_prompt = """You are a helpful AI assistant that executes Python code.
Use these tags:
- <thought>reasoning</thought> for thinking
- <execute>code</execute> for code
- <solution>answer</solution> for final answer
- <feedback>assessment</feedback> for self-evaluation"""
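
        # For illustration only: a well-formed model turn under this protocol
        # might look like
        #   <thought>I need to add the numbers first.</thought>
        #   <execute>print(2 + 3)</execute>
        #   <solution>5</solution>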

        print("Model loaded successfully!\n")

    def parse_response(self, response):
        """Extract the content of each supported tag from a model response."""
        parts = {'thought': None, 'execute': None, 'solution': None, 'feedback': None}
        for tag in parts:
            match = re.search(f'<{tag}>(.*?)</{tag}>', response, re.DOTALL)
            if match:
                parts[tag] = match.group(1).strip()
        return parts
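
    # Sketch of the expected parse: for a response like
    #   "<thought>Add them.</thought><execute>print(2 + 3)</execute>"
    # parse_response returns {'thought': 'Add them.', 'execute': 'print(2 + 3)',
    # 'solution': None, 'feedback': None}.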

    def build_prompt(self, user_input, execution_result=None):
        """Build a prompt that includes the conversation history."""
        messages = [{"role": "system", "content": self.system_prompt}]
        messages.extend(self.conversation_history)

        if execution_result:
            content = f"Previous execution result: {execution_result}\n\nUser: {user_input}"
        else:
            content = user_input

        messages.append({"role": "user", "content": content})

        # Use the tokenizer's chat template when available; otherwise fall back
        # to a plain role-prefixed transcript.
        if hasattr(self.backend, 'tokenizer') and hasattr(self.backend.tokenizer, 'apply_chat_template'):
            return self.backend.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            return "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"

    def chat(self, user_input, execution_result=None):
        """Generate a model response for the given user input."""
        prompt = self.build_prompt(user_input, execution_result)
        return self.backend.generate(prompt, max_tokens=400)

    def run(self):
        """Run the interactive loop."""
        print("="*60)
        print(f"Running on: {self.backend_name.upper()}")
        print("="*60)
        print("\nCommands:")
        print(" - Type your question and press Enter")
        print(" - 'clear' - Clear conversation history")
        print(" - 'quit' - Exit")
        print("="*60 + "\n")

        last_execution_result = None

        while True:
            try:
                user_input = input("\nYou: ").strip()

                if not user_input:
                    continue

                if user_input.lower() in ['quit', 'exit', 'q']:
                    print("\nGoodbye!")
                    break

                if user_input.lower() == 'clear':
                    self.conversation_history = []
                    last_execution_result = None
                    print("Conversation cleared")
                    continue

                print("\n[Generating...]", end=" ", flush=True)
                response = self.chat(user_input, last_execution_result)
                print("Done!\n")

                parts = self.parse_response(response)

                if parts['thought']:
                    print(f"Thought:\n{parts['thought']}\n")

                if parts['execute']:
                    print(f"Code:\n```python\n{parts['execute']}\n```\n")
                    print("Executing...\n")

                    result = execute_code(parts['execute'])

                    if result["success"]:
                        if result["output"]:
                            print(f"Output:\n{result['output']}")
                            last_execution_result = f"Output: {result['output']}"

                            print("\n" + "-"*40)
                            feedback = input("Is this correct? (y/n/skip): ").strip().lower()

                            if feedback == 'n':
                                print("\nMarked as incorrect")
                                last_execution_result += " [INCORRECT]"
                            elif feedback == 'y':
                                print("\nCorrect!")
                                last_execution_result = None
                            else:
                                last_execution_result = None

                            self.conversation_history.append({"role": "user", "content": user_input})
                            self.conversation_history.append({"role": "assistant", "content": response})
                        else:
                            print("Code executed (no output)")
                            last_execution_result = None

                        # Successful execution may still have written to stderr.
                        if result["error"]:
                            print(f"Warnings: {result['error']}")
                    else:
                        print(f"Error: {result['error']}")
                        last_execution_result = f"Error: {result['error']}"

                if parts['solution']:
                    print(f"\nSolution:\n{parts['solution']}")

                if parts['feedback']:
                    print(f"\nFeedback:\n{parts['feedback']}")

                if not any(parts.values()):
                    print(f"Response:\n{response[:500]}")

                # Keep only the most recent turns to bound prompt length.
                if len(self.conversation_history) > 10:
                    self.conversation_history = self.conversation_history[-10:]

                print("\n" + "="*60)

            except KeyboardInterrupt:
                print("\n\nInterrupted. Goodbye!")
                break
            except Exception as e:
                print(f"\nError: {e}")
                import traceback
                traceback.print_exc()


def main():
    parser = argparse.ArgumentParser(description="CodeAct Interactive Demo")
    parser.add_argument("--backend", choices=["auto", "cuda", "mps", "mlx", "cpu"],
                        default="auto", help="Backend to use (default: auto)")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-3B",
                        help="Model name or path")
    parser.add_argument("--adapter", type=str, default=None,
                        help="Path to LoRA adapter")

    args = parser.parse_args()

    demo = CodeActDemo(
        backend=args.backend,
        model_name=args.model,
        adapter_path=args.adapter
    )
    demo.run()


if __name__ == "__main__":
    main()