"""Evaluate a fine-tuned translation model on a test set using sacreBLEU."""

import argparse

import evaluate
import torch
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def evaluate_model():
    """
    Load a fine-tuned model and evaluate its performance on the test set
    using the BLEU score.
    """
    parser = argparse.ArgumentParser(description="Evaluate a translation model.")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the fine-tuned model directory")
    parser.add_argument("--source_lang_file", type=str, required=True,
                        help="Path to the source language test file")
    parser.add_argument("--target_lang_file", type=str, required=True,
                        help="Path to the target language test file")
    parser.add_argument("--source_lang_tokenizer", type=str, required=True,
                        help="Source language code for the tokenizer (e.g., 'npi_Deva' for Nepali in NLLB)")
    args = parser.parse_args()

| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| |
| |
| print("Loading model, tokenizer, and evaluation metric...") |
| tokenizer = AutoTokenizer.from_pretrained(args.model_path) |
| model = AutoModelForSeq2SeqLM.from_pretrained(args.model_path).to(DEVICE) |
| bleu_metric = evaluate.load("sacrebleu") |
|
|
| |
    # Read the test data: one sentence per line in each file.
    with open(args.source_lang_file, "r", encoding="utf-8") as f:
        source_sentences = [line.strip() for line in f]
    with open(args.target_lang_file, "r", encoding="utf-8") as f:
        # sacreBLEU expects a list of references per prediction, hence the nested lists.
        reference_translations = [[line.strip()] for line in f]

| print(f"Generating translations for {len(source_sentences)} test sentences...") |
| predictions = [] |
| for sentence in tqdm(source_sentences): |
| tokenizer.src_lang = args.source_lang_tokenizer |
| inputs = tokenizer(sentence, return_tensors="pt").to(DEVICE) |
| |
| generated_tokens = model.generate( |
| **inputs, |
| forced_bos_token_id=tokenizer.vocab["eng_Latn"], |
| max_length=128 |
| ) |
| |
| translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] |
| predictions.append(translation) |
|
|
| |
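    # Note: sentences are translated one at a time for clarity. For large test
    # sets, encoding sentences in batches (tokenizer(batch, padding=True,
    # return_tensors="pt")) and calling model.generate once per batch is
    # typically much faster.
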
| print("Calculating BLEU score...") |
| results = bleu_metric.compute(predictions=predictions, references=reference_translations) |
| |
| |
| bleu_score = results["score"] |
|
|
| print("\n--- Evaluation Complete ---") |
| print(f"BLEU Score: {bleu_score:.2f}") |
| print("---------------------------") |
|
|
if __name__ == "__main__":
    evaluate_model()
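
# Example invocation (the script name, file paths, and language code below are
# illustrative, not fixed by the script itself):
#   python evaluate_model.py \
#       --model_path ./outputs/nllb-finetuned \
#       --source_lang_file data/test.npi \
#       --target_lang_file data/test.eng \
#       --source_lang_tokenizer npi_Deva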