import os
import re
import sys

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# FLORES-200 language codes mapped to the ISO 639-1 codes that M2M-100 expects
langs_supported = {
    "eng_Latn": "en",
    "ben_Beng": "bn",
    "guj_Gujr": "gu",
    "hin_Deva": "hi",
    "kan_Knda": "kn",
    "mal_Mlym": "ml",
    "mar_Deva": "mr",
    "npi_Deva": "ne",
    "ory_Orya": "or",
    "pan_Guru": "pa",
    "snd_Arab": "sd",
    "tam_Taml": "ta",
    "urd_Arab": "ur",
}


def predict(batch, tokenizer, model, bos_token_id):
    # tokenize the batch and move the tensors to the model's device
    encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
    # beam search, forcing the target-language code as the first generated token
    generated_tokens = model.generate(
        **encoded_batch,
        num_beams=5,
        max_length=256,
        min_length=0,
        forced_bos_token_id=bos_token_id,
    )
    # strip special tokens (language codes, padding, EOS) when decoding
    hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return hypothesis
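# NOTE: model.generate() already runs under torch.no_grad(), so predict() does
# not need an explicit inference-mode context of its own.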


def main(devtest_data_dir, batch_size):
    # load the M2M-100 12B tokenizer and model; half precision is assumed here,
    # since the 12B checkpoint is unlikely to fit in fp32 on a single GPU
    model_name = "facebook/m2m100-12B-last-ckpt"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).half().to("cuda")
    model.eval()
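    # NOTE: if the checkpoint does not fit on one GPU, passing
    # `device_map="auto"` to from_pretrained (with `accelerate` installed)
    # can shard it across the available devices instead.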

    # iterate over the language-pair subdirectories of `devtest_data_dir`
    for pair in sorted(os.listdir(devtest_data_dir)):
        if "-" not in pair:
            continue

        src_lang, tgt_lang = pair.split("-")

        # skip pairs where either side is not supported by M2M-100
        if src_lang not in langs_supported or tgt_lang not in langs_supported:
            print(f"Skipping {src_lang}-{tgt_lang} ...")
            continue

        # ------------------------------------------------------------------
        #                 forward direction: src_lang -> tgt_lang
        # ------------------------------------------------------------------
        print(f"Evaluating {src_lang}-{tgt_lang} ...")

        infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
        outfname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}.pred.m2m100")

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        # remember whether the input ended with a trailing newline so the
        # output file can mirror it exactly
        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        # set the source language for tokenization
        tokenizer.src_lang = langs_supported[src_lang]

        # translate in batches; the forced BOS token id fixes the target language
        hypothesis = []
        bos_token_id = tokenizer.lang_code_to_id[langs_supported[tgt_lang]]
        for i in tqdm(range(0, len(src_sents), batch_size)):
            batch = src_sents[i : i + batch_size]
            hypothesis += predict(batch, tokenizer, model, bos_token_id)

        assert len(hypothesis) == len(src_sents)

        # collapse internal whitespace so every prediction stays on one line
        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))
            # restore the trailing newline if the input file had one
            if add_new_line:
                f.write("\n")

        # ------------------------------------------------------------------
        #                 reverse direction: tgt_lang -> src_lang
        # ------------------------------------------------------------------
        infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
        outfname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}.pred.m2m100")

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        tokenizer.src_lang = langs_supported[tgt_lang]

        hypothesis = []
        bos_token_id = tokenizer.lang_code_to_id[langs_supported[src_lang]]
        for i in tqdm(range(0, len(src_sents), batch_size)):
            batch = src_sents[i : i + batch_size]
            hypothesis += predict(batch, tokenizer, model, bos_token_id)

        assert len(hypothesis) == len(src_sents)

        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))
            if add_new_line:
                f.write("\n")


if __name__ == "__main__":
    devtest_data_dir = sys.argv[1]
    batch_size = int(sys.argv[2])

    # the 12B model is run on GPU; bail out early if CUDA is unavailable
    if not torch.cuda.is_available():
        print("No GPU available")
        sys.exit(1)

    main(devtest_data_dir, batch_size)
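# Example invocation (the script and data-directory names are illustrative):
#   python m2m100_inference.py flores_devtest/ 16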