license_statement = "NLP demo software by HyperbeeAI. Copyright © 2023 Hyperbee.AI Inc. All rights reserved. hello@hyperbee.ai"
print("imported utils.py")
print(license_statement)
print("")
|
|
import time
import torch
import datasets
import layers
from tokenizers import Tokenizer
from tqdm import tqdm
|
|
# Module-level tokenizers, loaded from assets/ at the bottom of this file
tokenizer_en = None
tokenizer_es = None
|
|
def tokenize_es(text):
    # 48 is the fixed sequence length used throughout this module; keep two
    # slots free for the <BOS>/<EOS> markers added around the source sentence
    return tokenizer_es.encode(text).ids[:48 - 2]


def tokenize_en(text):
    # Keep one slot free for the <BOS> marker on the target side
    return tokenizer_en.encode(text).ids[:48 - 1]
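
# Usage sketch (the exact ids depend on the vocabularies shipped in assets/):
#
#   ids = tokenize_es("¿Dónde está la biblioteca?")  # at most 46 subword ids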
|
|
def translate_sentence(sentence, src_field, trg_field, model, device):
    model.eval()
    if isinstance(sentence, str):
        tokens = tokenize_es(sentence)
    else:
        tokens = sentence

    # Wrap with <BOS>/<EOS> and right-pad to the fixed sequence length of 48
    tokens = [src_field.init_token] + tokens + [src_field.eos_token] \
             + [src_field.pad_token] * (48 - 2 - len(tokens))
    src_tensor = torch.LongTensor(tokens).unsqueeze(0).to(device)

    with torch.no_grad():
        enc_out, _ = model.encoder(src_tensor)

    # Target buffer: <BOS> followed by padding, filled in greedily below
    trg_indexes = [trg_field.init_token] + [trg_field.pad_token] * (48 - 1)

    for i in range(48 - 1):
        # The decoder sees a sliding window of at most 8 target tokens,
        # offset by start_idx into the full target buffer
        start_idx = max(0, i - 7)
        trg_tensor = torch.LongTensor(trg_indexes[start_idx:start_idx + 8]).unsqueeze(0).to(device)

        with torch.no_grad():
            output, _, _ = model.decoder(trg_tensor, enc_out, start_idx)

        # Greedy decoding: take the most likely token at step i's position
        # inside the window (i itself while the window fills, then always 7)
        pred_token = output.argmax(2)[:, min(i, 7)].item()
        trg_indexes[i + 1] = pred_token
        if pred_token == trg_field.eos_token:
            break

    # Drop <BOS> and everything from <EOS> onward; if no <EOS> was produced,
    # keep the whole buffer after <BOS>
    try:
        trg_indexes = trg_indexes[1:trg_indexes.index(trg_field.eos_token)]
    except ValueError:
        trg_indexes = trg_indexes[1:]

    trg_tokens = tokenizer_en.decode(trg_indexes, skip_special_tokens=False)

    return trg_tokens
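
# Usage sketch (assumes a trained model from layers.py and field objects that
# carry the special-token ids used above; the names SRC/TRG are illustrative):
#
#   translation = translate_sentence("¿Dónde está la biblioteca?",
#                                    SRC, TRG, model, device)
#   print(translation)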
|
|
|
|
def calculate_bleu(data, src_field, trg_field, model, device, spiece=False, output_file=f"test.{time.time()}.out"):
    # Note: `data` is currently unused, and the default output_file name is
    # fixed once at import time (default arguments are evaluated at definition)
    if spiece:
        from tokenizers import pre_tokenizers
        pre_tokenizer = pre_tokenizers.Digits(individual_digits=True)
    else:
        pre_tokenizer = tokenizer_en.pre_tokenizer

    trgs = []
    pred_trgs = []
    print('Evaluate on bleu:')
    with open("news-comm-v15/news-comm-v15-all-test.es") as src_file, \
         open("news-comm-v15/news-comm-v15-all-test.en") as trg_file:
        for src, trg in tqdm(zip(src_file, trg_file)):
            # Skip empty or near-empty lines
            if len(src) < 3 or len(trg) < 3:
                continue

            # Normalize and pre-tokenize the reference with the target tokenizer
            normalized = pre_tokenizer.pre_tokenize_str(tokenizer_en.normalizer.normalize_str(trg))

            # Skip references longer than the model's fixed sequence length
            if len(normalized) > 48:
                continue

            # pre_tokenize_str returns (token, span) pairs; keep the tokens only
            trgs.append([" ".join(token for token, _ in normalized)])

            pred_trg = translate_sentence(src, src_field, trg_field, model, device)
            pred_trgs.append(pred_trg)

    with open(output_file, "w") as fo:
        fo.write("\n".join(pred_trgs))

    sacrebleu = datasets.load_metric('sacrebleu')
    return sacrebleu.compute(predictions=pred_trgs, references=trgs)
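
# Usage sketch (same assumptions as translate_sentence above); the returned
# dict from the sacrebleu metric carries the corpus score under "score":
#
#   bleu = calculate_bleu(None, SRC, TRG, model, device)
#   print(bleu["score"])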
|
|
# Load the pretrained subword tokenizers shipped with the repo
tokenizer_es = Tokenizer.from_file("assets/es.json")
tokenizer_en = Tokenizer.from_file("assets/en.json")
TRG_PAD_IDX = tokenizer_en.token_to_id("<PAD>")
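
# TRG_PAD_IDX is presumably consumed elsewhere in the repo, e.g. to mask
# padding in the training loss (an assumption, sketched for illustration):
#
#   criterion = torch.nn.CrossEntropyLoss(ignore_index=TRG_PAD_IDX)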
|
|