| import torch |
| from pathlib import Path |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
| from peft import PeftModel |
|
|
|
|
class CodetteModelLoader:
    """Load a 4-bit NF4-quantized causal LM and manage PEFT/LoRA adapters.

    Wraps a Hugging Face base model (quantized via bitsandbytes) together
    with a set of named LoRA adapters that can be attached and switched at
    runtime.
    """

    def __init__(
        self,
        base_model="meta-llama/Llama-3.1-8B-Instruct",
        adapters=None,
    ):
        """
        Args:
            base_model: Hugging Face model id or local path of the base model.
            adapters: Optional mapping of adapter name -> adapter checkpoint
                path. ``None`` means no adapters (avoids a mutable default).
        """
        self.base_model_name = base_model
        self.adapters = adapters or {}
        self.model = None
        self.tokenizer = None
        self.active_adapter = None

        self._load_base_model()

    def _load_base_model(self):
        """Load the tokenizer and the 4-bit quantized base model.

        Uses NF4 double quantization with bfloat16 compute, and lets
        accelerate place layers automatically (``device_map="auto"``).
        """
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            trust_remote_code=True,
        )

        # Llama-family tokenizers ship without a pad token; reuse EOS so
        # padded batching works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            quantization_config=quant_config,
            device_map="auto",
            trust_remote_code=True,
        )

    def load_adapters(self):
        """Attach every configured adapter to the model.

        The first adapter wraps the base model in a ``PeftModel``; further
        adapters register on that wrapper. Idempotent: an already-wrapped
        model is never re-wrapped (the previous implementation's ``first``
        flag double-wrapped the model if this was called twice), and
        adapter names that are already loaded are skipped.
        """
        for name, path in self.adapters.items():
            path = str(Path(path))

            if not isinstance(self.model, PeftModel):
                # First adapter: wrap the base model and make it active.
                self.model = PeftModel.from_pretrained(
                    self.model,
                    path,
                    adapter_name=name,
                    is_trainable=False,
                )
                self.active_adapter = name
            elif name not in self.model.peft_config:
                # Additional adapters: register on the existing wrapper.
                self.model.load_adapter(
                    path,
                    adapter_name=name,
                )

    def set_active_adapter(self, name):
        """Switch inference to the named adapter.

        Args:
            name: Adapter name previously passed to :meth:`load_adapters`.

        Raises:
            RuntimeError: if no adapters have been loaded yet.
            ValueError: if ``name`` was never loaded.
        """
        # Before load_adapters() the model is a plain transformers model
        # with no `peft_config`; fail with a clear message instead of an
        # opaque AttributeError.
        if not isinstance(self.model, PeftModel):
            raise RuntimeError(
                "No adapters loaded; call load_adapters() first"
            )
        if name not in self.model.peft_config:
            raise ValueError(f"Adapter not loaded: {name}")

        self.model.set_adapter(name)
        self.active_adapter = name

    def format_messages(self, messages):
        """Render chat ``messages`` into one prompt string.

        Applies the tokenizer's chat template without tokenizing and
        appends the assistant generation prompt.
        """
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    def tokenize(self, prompt):
        """Tokenize ``prompt`` and move the tensors to the model's device."""
        return self.tokenizer(
            prompt,
            return_tensors="pt",
        ).to(self.model.device)