| import torch |
| from train import get_model, greedy_decode, get_or_build_tokenizer |
| from config import get_config |
|
|
# Default source-language sentence to translate when running this script directly.
INPUT_TEXT = "sun rises in the night"
|
|
def inference(input_text: str = INPUT_TEXT, model_filename: str = "weights/tmodel_19.pt") -> str:
    """Translate *input_text* with a trained transformer checkpoint and print the result.

    Args:
        input_text: Source-language sentence to translate. Defaults to the
            module-level ``INPUT_TEXT`` so existing callers are unaffected.
        model_filename: Path to the checkpoint saved by the training loop.
            The checkpoint is expected to hold a ``"model_state_dict"`` key.

    Returns:
        The decoded target-language text (also printed to stdout).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = get_config()

    # Dataset argument is None: the tokenizers must already exist on disk,
    # since there is no corpus here to train new ones from.
    tokenizer_src = get_or_build_tokenizer(config, None, config["lang_src"])
    tokenizer_tgt = get_or_build_tokenizer(config, None, config["lang_target"])

    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

    # map_location keeps CUDA-saved checkpoints loadable on CPU-only machines.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    model.eval()

    # Build the encoder input: [SOS] + token ids + [EOS], batch dimension of 1.
    tokens = tokenizer_src.encode(input_text).ids
    tokens = [tokenizer_src.token_to_id("[SOS]")] + tokens + [tokenizer_src.token_to_id("[EOS]")]
    encoder_input = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
    # Mask shape (1, 1, seq_len): True where the token is not padding.
    # (No padding is actually added here, so the mask is all True; kept for
    # interface compatibility with greedy_decode.)
    encoder_mask = (encoder_input != tokenizer_src.token_to_id("[PAD]")).unsqueeze(0).unsqueeze(0).to(device)

    # no_grad: inference only — skip autograd graph construction to save
    # memory and time during the token-by-token greedy decode.
    with torch.no_grad():
        model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, config["seq_len"], device)
    output_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

    print("Source:", input_text)
    print("Predicted:", output_text)
    return output_text
|
|
|
|
# Script entry point: translate the default INPUT_TEXT and print the result.
if __name__ == "__main__":
    inference()
|
|