| import torch |
| from transformers import set_seed, pipeline |
| from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
| import time |
|
|
| |
def translate_helsinki_nlp(s: str, src_iso: str, dest_iso: str) -> str:
    """
    Translate a text with one of Helsinki-NLP's Opus-MT checkpoints.

    Parameters
    ----------
    s: str
        The text to translate
    src_iso: str
        Source language code; inserted verbatim into the checkpoint name
    dest_iso: str
        Destination language code; inserted verbatim into the checkpoint name

    Returns
    ----------
    translation: str
        The translated text
    """
    # Re-seed on every call, before the pipeline is built and run.
    set_seed(555)

    # The checkpoint is addressed as "Helsinki-NLP/opus-mt-<src>-<dest>".
    checkpoint = f"Helsinki-NLP/opus-mt-{src_iso}-{dest_iso}"
    opus_pipeline = pipeline("translation", model=checkpoint)

    # A single text goes in, so only the first result entry is read.
    outputs = opus_pipeline(s)
    first_result = outputs[0]
    return first_result["translation_text"]
|
|
| |
def translate_masakhane(s: str, src_iso: str, dest_iso: str) -> str:
    """
    Translate a text with one of Masakhane's fine-tuned M2M100 news checkpoints.

    Parameters
    ----------
    s: str
        The text to translate
    src_iso: str
        Source language code; inserted verbatim into the checkpoint name
    dest_iso: str
        Destination language code; inserted verbatim into the checkpoint name

    Returns
    ----------
    translation: str
        The translated text
    """
    # Re-seed on every call, before the model is loaded and run.
    set_seed(555)

    # Model and tokenizer are loaded from the same checkpoint:
    # "masakhane/m2m100_418m_<src>_<dest>_news".
    checkpoint = f"masakhane/m2m100_418m_{src_iso}_{dest_iso}_news"
    m2m_model = M2M100ForConditionalGeneration.from_pretrained(checkpoint)
    m2m_tokenizer = M2M100Tokenizer.from_pretrained(checkpoint)

    # Tokenize to PyTorch tensors, generate, then decode without special tokens.
    model_inputs = m2m_tokenizer(s, return_tensors="pt")
    output_ids = m2m_model.generate(**model_inputs)
    decoded = m2m_tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    # Only one text was encoded, so the first decoded entry is the translation.
    return decoded[0]
|
|
| |
def translate_facebook(s: str, src_iso: str, dest_iso: str, *, max_length: int = 30) -> str:
    """
    Translate a text with Meta's distilled NLLB-200 (600M) model.

    Parameters
    ----------
    s: str
        The text to translate
    src_iso: str
        The ISO-3 code of the source language. "_Latn" is appended to form
        the NLLB language code, so only Latin-script codes can be addressed.
    dest_iso: str
        The ISO-3 code of the destination language. "_Latn" is appended in
        the same way as for the source language.
    max_length: int, optional
        Upper bound passed to ``model.generate`` for the generated sequence.
        Longer translations are cut off at this length. The default of 30
        preserves the previous hard-coded value; raise it for longer inputs.

    Returns
    ----------
    translation: str
        The translated text
    """
    # Re-seed on every call, before the model is loaded and run.
    set_seed(555)

    checkpoint = "facebook/nllb-200-distilled-600M"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, src_lang=f"{src_iso}_Latn")
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

    encoded = tokenizer(s, return_tensors="pt")

    # The target language is selected by forcing its language-code token
    # as the first generated token.
    target_token_id = tokenizer.convert_tokens_to_ids(f"{dest_iso}_Latn")
    translated_tokens = model.generate(
        **encoded,
        forced_bos_token_id=target_token_id,
        max_length=max_length,
    )

    # Only one text was encoded, so the first decoded entry is the translation.
    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
|
|
|
| |
def translate(s: str, src_iso: str, dest_iso: str) -> str:
    """
    Translate the text using all available models (Meta, Masakhane, and Helsinki NLP where applicable).

    Meta's NLLB model is always used. HelsinkiNLP's Opus model is added for
    the pairs mos-eng, eng-mos and fra-mos; Masakhane's M2M model is added
    for the pairs mos-fra and fra-mos. The elapsed wall-clock time is
    printed to standard output before returning.

    Parameters
    ----------
    s: str
        The text
    src_iso: str
        The ISO-3 code of the source language (matched case-insensitively)
    dest_iso: str
        The ISO-3 code of the destination language (matched case-insensitively)

    Returns
    ----------
    translation: str
        The translated text, concatenated over different models, each
        section preceded by a heading naming the model
    """
    start_time = time.time()

    # Normalise case once, up front, so the pair look-ups below and every
    # back-end see the same codes (previously only src_iso was lower-cased,
    # and only after the case-sensitive pair match had already happened).
    src_code = src_iso.lower()
    dest_code = dest_iso.lower()
    iso_pair = f"{src_code}-{dest_code}"

    # The Opus and M2M checkpoints are addressed with two-letter codes for
    # English and French; any other code is passed through unchanged.
    short_codes = {"eng": "en", "fra": "fr"}
    short_src = short_codes.get(src_code, src_code)
    short_dest = short_codes.get(dest_code, dest_code)

    sections = [
        "Meta's NLLB translation is:\n\n" + translate_facebook(s, src_code, dest_code)
    ]

    if iso_pair in ("mos-eng", "eng-mos", "fra-mos"):
        opus_text = translate_helsinki_nlp(s, short_src, short_dest)
        sections.append(f"HelsinkiNLP's Opus translation is:\n\n {opus_text}")

    if iso_pair in ("mos-fra", "fra-mos"):
        m2m_text = translate_masakhane(s, short_src, short_dest)
        sections.append("Masakhane's M2M translation is:\n\n" + m2m_text)

    print("Time elapsed: ", int(time.time() - start_time), " seconds")

    # Sections are separated by three newlines, as before.
    return "\n\n\n".join(sections)
| |