import torch
from transformers import AutoTokenizer, GPT2LMHeadModel


# Special tokens used to frame a proofreading example: <origin> marks the
# original (possibly erroneous) sentence and <correct> marks the corrected
# sentence the model is trained to produce.
O_TKN = '<origin>'
C_TKN = '<correct>'
BOS = '</s>'
EOS = '</s>'
PAD = '<pad>'
MASK = '<unused0>'
SENT = '<unused1>'
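

# A rough sketch, not part of the original script: the token-by-token greedy
# loop in chat() below can also be expressed with Hugging Face's generate()
# API. The function name and the max_new_tokens budget are illustrative
# assumptions, not values from the source.
def proofread_with_generate(sentence, tokenizer, model, max_new_tokens=128):
    """Greedily decode a corrected version of `sentence` with generate()."""
    input_ids = tokenizer.encode(O_TKN + sentence + C_TKN, return_tensors='pt')
    with torch.no_grad():
        output_ids = model.generate(input_ids,
                                    max_new_tokens=max_new_tokens,
                                    do_sample=False,
                                    eos_token_id=tokenizer.eos_token_id,
                                    pad_token_id=tokenizer.pad_token_id)
    # Keep only the tokens generated after the prompt and drop special tokens.
    return tokenizer.decode(output_ids[0][input_ids.shape[-1]:],
                            skip_special_tokens=True).strip()
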
def chat():
    # Tokenizer for SKT's KoGPT2, registered with the custom special tokens above.
    tokenizer = AutoTokenizer.from_pretrained('skt/kogpt2-base-v2',
                                              bos_token=BOS, eos_token=EOS, unk_token='<unk>',
                                              pad_token=PAD, mask_token=MASK)
    # Fine-tuned proofreading checkpoint.
    model = GPT2LMHeadModel.from_pretrained('Moo/kogpt2-proofreader')
    model.eval()
    with torch.no_grad():
        while True:
            q = input('원래문장: ').strip()  # prompt for the original sentence ("원래문장")
            if q == 'quit':
                break
            a = ''
            # Greedy decoding: append the most likely next token until EOS.
            while True:
                input_ids = torch.LongTensor(
                    tokenizer.encode(O_TKN + q + C_TKN + a)).unsqueeze(dim=0)
                pred = model(input_ids)
                gen = tokenizer.convert_ids_to_tokens(
                    torch.argmax(pred.logits, dim=-1).squeeze().tolist())[-1]
                if gen == EOS:
                    break
                a += gen.replace('▁', ' ')  # '▁' marks a word boundary in the SentencePiece vocab
            print(f"교정: {a.strip()}")  # print the correction ("교정")

if __name__ == "__main__":
    chat()