import os
import string
import sys
from collections import OrderedDict

import torch
import torch.nn.functional as F

from model import ChatGCLM, MAX_SEQ_LEN
|
|
# Pick the first checkpoint named Turing_*.pt in the working directory.
# NOTE: os.listdir order is platform-dependent, so "first" is arbitrary
# when several checkpoints exist — same behavior as before.
MODEL_PATH = next(
    (f for f in os.listdir(".") if f.startswith("Turing_") and f.endswith(".pt")),
    None,
)

if MODEL_PATH is None:
    print("Error: No model checkpoint found!")
    print("Please train the model first with: python3 train.py")
    # sys.exit is always available; the bare exit() builtin is injected by
    # the site module and missing under `python -S` or in frozen builds.
    sys.exit(1)
|
|
EOS_ID = 2   # reserved id marking end-of-sequence
OFFSET = 3   # ids 0..OFFSET-1 are reserved for special tokens
CHARS = string.printable

# Precomputed char -> id table: CHARS.index(c) inside the encode loop was
# O(len(CHARS)) per character; a dict lookup is O(1).
_CHAR_TO_ID = {c: i + OFFSET for i, c in enumerate(CHARS)}


def encode(text):
    """Map *text* to a list of token ids, dropping unsupported characters."""
    return [_CHAR_TO_ID[c] for c in text if c in _CHAR_TO_ID]


def decode(ids):
    """Map token *ids* back to text, skipping special ids (< OFFSET)."""
    return "".join([CHARS[i - OFFSET] for i in ids if i >= OFFSET])
|
|
def load_model(device):
    """Construct ChatGCLM and load the checkpoint at MODEL_PATH onto *device*.

    Returns the model in eval mode, or None when the checkpoint file is
    missing or empty.
    """
    model = ChatGCLM(len(CHARS) + OFFSET).to(device)

    # Guard clause: bail out early if the checkpoint is missing or empty.
    if not (os.path.exists(MODEL_PATH) and os.path.getsize(MODEL_PATH) > 0):
        print(f"Error: Could not load model from {MODEL_PATH}")
        return None

    print(f"Loading model from: {MODEL_PATH}")
    ckpt = torch.load(MODEL_PATH, map_location=device)

    # Checkpoints may wrap the weights under a well-known key, or be the
    # raw state dict itself.
    state_dict = ckpt
    if isinstance(ckpt, dict):
        for wrapper_key in ('model_state_dict', 'state_dict'):
            if wrapper_key in ckpt:
                state_dict = ckpt[wrapper_key]
                break

    # Strip a DataParallel-style 'module.' prefix when any key carries one.
    key_list = list(state_dict.keys())
    if any(k.startswith('module.') for k in key_list):
        stripped = OrderedDict()
        for k, v in state_dict.items():
            stripped[k[len('module.'):] if k.startswith('module.') else k] = v
        state_dict = stripped

    # strict=False tolerates partial checkpoints; surface any mismatches.
    result = model.load_state_dict(state_dict, strict=False)
    missing = getattr(result, 'missing_keys', None)
    if missing:
        print(f"Warning: missing keys when loading state_dict: {missing}")
    unexpected = getattr(result, 'unexpected_keys', None)
    if unexpected:
        print(f"Warning: unexpected keys in state_dict: {unexpected}")

    model.eval()
    return model
|
|
@torch.no_grad()
def generate(model, prompt, device, max_new_tokens=200, temperature=0.8, top_k=50):
    """Autoregressively sample up to *max_new_tokens* tokens after *prompt*.

    Args:
        model: module mapping (B, T) int64 ids -> (B, T, V) logits
            (assumed from the indexing below — confirm against ChatGCLM).
        prompt: seed text; characters outside CHARS are dropped by encode().
        device: torch device for the id tensor.
        max_new_tokens: hard cap on the number of sampled tokens.
        temperature: logit divisor; values <= 0 fall back to greedy argmax
            (the old code divided by zero and sampled from NaNs).
        top_k: keep only the k most likely tokens; None or 0 disables
            filtering (the old code crashed on top_k=0).

    Returns:
        The generated continuation (prompt excluded) as a string. Also
        streams the text to stdout as it is produced.
    """
    model.eval()
    input_ids = encode(prompt)
    x = torch.tensor([input_ids], dtype=torch.long, device=device)

    print(f"\n{'='*70}")
    print(f"PROMPT: {prompt}")
    print(f"{'='*70}")
    print("GENERATED TEXT:")
    print(prompt, end="", flush=True)

    generated_tokens = []
    for _ in range(max_new_tokens):
        # Crop the context to the model's maximum sequence length; the
        # slice is a no-op while the sequence is still short enough.
        ctx = x[:, -MAX_SEQ_LEN:]
        logits = model(ctx)
        next_token_logits = logits[:, -1, :]

        if temperature <= 0:
            # Degenerate temperature: deterministic argmax instead of a
            # division by zero followed by softmax over inf/NaN.
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
        else:
            next_token_logits = next_token_logits / temperature
            if top_k:  # falsy (None or 0) disables top-k filtering
                k = min(top_k, next_token_logits.size(-1))
                v, _ = torch.topk(next_token_logits, k)
                # masked_fill avoids mutating the model's output in place.
                next_token_logits = next_token_logits.masked_fill(
                    next_token_logits < v[:, [-1]], -float('Inf'))
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

        idx = next_token.item()
        if idx == EOS_ID:
            break

        x = torch.cat((x, next_token), dim=1)
        generated_tokens.append(idx)
        print(decode([idx]), end="", flush=True)

    print(f"\n{'='*70}\n")
    return decode(generated_tokens)
|
|
if __name__ == "__main__":
    # Prefer CUDA, then Apple MPS, then CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    model = load_model(device)

    if model is None:
        sys.exit(1)

    test_prompts = [
        "Once upon a time",
        "The future of AI is",
        "In a world where",
    ]

    print("\n" + "="*70)
    print("ChatGCLM Text Generation Demo")
    print("="*70)

    for prompt in test_prompts:
        generate(model, prompt, device, max_new_tokens=150, temperature=0.8, top_k=50)

    print("\n" + "="*70)
    print("Interactive Mode - Enter your own prompts!")
    print("="*70)

    try:
        while True:
            user_prompt = input("\nEnter prompt (or 'exit' to quit): ")
            if user_prompt.lower() == 'exit':
                break
            if user_prompt.strip():
                generate(model, user_prompt, device, max_new_tokens=200, temperature=0.8, top_k=50)
    except (EOFError, KeyboardInterrupt):
        # input() raises EOFError when stdin closes (e.g. piped input) and
        # KeyboardInterrupt on Ctrl-C; exit cleanly instead of tracebacking.
        print()