File size: 2,167 Bytes
b813321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3
"""CLI evaluation script for Précis."""

import argparse
import logging
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.config import ModelConfig, DataConfig
from src.model import load_tokenizer
from src.tuning.data import create_dummy_data

from transformers import AutoModelForCausalLM
from peft import PeftModel

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse command-line arguments for the evaluation script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` as before;
            exposed as a parameter so the parser is unit-testable.

    Returns:
        argparse.Namespace with ``checkpoint`` (str, required),
        ``num_samples`` (int, default 5) and ``max_new_tokens``
        (int, default 256).
    """
    parser = argparse.ArgumentParser(description="Evaluate Précis model")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--num-samples", type=int, default=5, help="Number of samples to evaluate")
    parser.add_argument("--max-new-tokens", type=int, default=256, help="Max tokens to generate")
    return parser.parse_args(argv)


def main():
    """Load a checkpoint and log generations for a handful of dummy samples.

    Side effects: loads tokenizer/model from disk, runs sampled generation
    (do_sample + temperature 0.7), and writes input/generated/reference
    triples to the module logger. No value is returned.
    """
    args = parse_args()
    config = ModelConfig()
    data_config = DataConfig()

    logger.info("Loading checkpoint from %s", args.checkpoint)
    tokenizer = load_tokenizer(config)

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint,
        device_map="auto",
        trust_remote_code=True,
    )
    # NOTE(review): PeftModel is imported at module level but never applied.
    # If checkpoints are LoRA adapters rather than merged weights, a
    # PeftModel.from_pretrained(model, args.checkpoint) step is missing
    # here — confirm against the training pipeline's save format.
    model.eval()  # disable dropout etc. for evaluation

    # Generate on dummy samples
    samples = create_dummy_data(args.num_samples)

    for i, sample in enumerate(samples):
        prompt = data_config.format_prompt(sample["text"])
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        outputs = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.pad_token_id,
        )

        # Bug fix: the old code decoded the full sequence with
        # skip_special_tokens=True and then sliced by len(prompt). Once
        # special tokens are stripped (or decoding normalizes whitespace),
        # the decoded text no longer aligns character-for-character with
        # the raw prompt, so the slice chopped the summary at the wrong
        # point. Decoding only the newly generated token ids is exact.
        summary = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        logger.info("\n=== Sample %d ===", i + 1)
        logger.info("Input: %s...", sample["text"][:100])
        logger.info("Generated: %s", summary)
        logger.info("Reference: %s", sample["summary"])


if __name__ == "__main__":
    main()