from io import BytesIO

import requests
import torch
from fastapi import FastAPI
from pydantic import BaseModel

from model import TransformerSeq2Seq, translate
from utils import load_tokenizers_and_embeddings
|
|
| |
|
|
| |
# --- Runtime setup: tokenizers/embeddings, model construction, weight download ---

# load_tokenizers_and_embeddings() also selects the device, so we take the
# device it reports instead of computing it separately (the previous
# standalone torch.device(...) assignment was immediately overwritten).
resources = load_tokenizers_and_embeddings()
tokenizer_vi = resources["tokenizer_vi"]
embedding_matrix_vi = resources["embedding_vi"]
tokenizer_en = resources["tokenizer_en"]
embedding_matrix_en = resources["embedding_en"]
device = resources["device"]

print("✅ Tokenizers & embeddings loaded!")

# The English embedding may be a raw weight tensor or an nn.Embedding-like
# module; read the embedding dimension accordingly.
if isinstance(embedding_matrix_en, torch.Tensor):
    embed_dim = embedding_matrix_en.size(1)
else:
    embed_dim = embedding_matrix_en.embedding_dim

max_len = 128   # maximum sequence length used by the model and translate()
batch_size = 32  # NOTE(review): unused in this file — kept for compatibility

model = TransformerSeq2Seq(
    embed_dim=embed_dim,
    vocab_size=tokenizer_vi.vocab_size,
    embedding_decoder=embedding_matrix_vi,
    num_heads=4,
    num_layers=2,
    dim_feedforward=256,
    dropout=0.1,
    freeze_decoder_emb=True,
    max_len=max_len,
)

MODEL_URL = "https://huggingface.co/nemabruh404/Machine_Translation/resolve/main/model_state_dict.pt"

# Download the checkpoint; fail fast on network/HTTP errors instead of
# handing an HTML error page to torch.load.
response = requests.get(MODEL_URL, timeout=300)
response.raise_for_status()
# NOTE(review): checkpoint is a dict containing "model_state_dict"; on
# torch >= 2.6 the weights_only default may reject non-tensor payloads —
# confirm against the torch version in deployment.
checkpoint = torch.load(BytesIO(response.content), map_location=device)

model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

print("✅ Model loaded from Hugging Face Hub")
|
|
def hf_inference_fn(inputs: str):
    """Translate one English sentence to Vietnamese with the loaded model.

    Thin wrapper around ``translate`` that binds the module-level model,
    tokenizers, source embeddings, device and maximum sequence length, so
    callers only need to pass the raw input text.
    """
    translation = translate(
        model=model,
        src_sentence=inputs,
        tokenizer_src=tokenizer_en,
        tokenizer_tgt=tokenizer_vi,
        embedding_src=embedding_matrix_en,
        device=device,
        max_len=max_len,
    )
    return translation
# FastAPI application exposing the translation model over HTTP.
# The FastAPI/BaseModel imports that used to sit here mid-file (one of them
# a duplicate of the top-of-file pydantic import) now live with the other
# imports at the top of the module.
app = FastAPI()
|
|
class TranslateRequest(BaseModel):
    """Request body for the /translate endpoint."""

    # Source-language (English) text to be translated.
    text: str
|
|
@app.post("/translate")
def translate_endpoint(req: TranslateRequest):
    """Translate the English text in the request body; returns {"translation": ...}.

    Deliberately NOT named ``translate``: a module-level ``def translate``
    shadows the ``translate`` function imported from ``model``, and since
    ``hf_inference_fn`` resolves that name at call time, every request would
    re-enter this endpoint (with mismatched keyword arguments) instead of
    reaching the real translator.
    """
    return {"translation": hf_inference_fn(req.text)}
|
|