import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

O_TKN = '<origin>'    # marks the start of the original (uncorrected) sentence
C_TKN = '<correct>'   # marks the start of the corrected sentence
BOS = '</s>'
EOS = '</s>'
PAD = '<pad>'
MASK = '<unused0>'
SENT = '<unused1>'

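# Prompt format used by chat() below (inferred from the code, not documented
# in the original): a single string '<origin>{sentence}<correct>' is fed to
# the model, which then generates the corrected sentence followed by EOS.
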
def chat():
    tokenizer = AutoTokenizer.from_pretrained('skt/kogpt2-base-v2',
                                              bos_token=BOS, eos_token=EOS,
                                              unk_token='<unk>',
                                              pad_token=PAD, mask_token=MASK)
    model = GPT2LMHeadModel.from_pretrained('Moo/kogpt2-proofreader')
    model.eval()
    with torch.no_grad():
        while True:
            q = input('원래문장: ').strip()  # prompt: "original sentence"
            if q == 'quit':
                break
            a = ''
            # Greedy decoding one token at a time: re-encode the full prompt
            # plus everything generated so far, take the argmax of the logits
            # at the last position, and stop when the model emits EOS.
            while True:
                input_ids = torch.LongTensor(
                    tokenizer.encode(O_TKN + q + C_TKN + a)).unsqueeze(dim=0)
                pred = model(input_ids)
                gen = tokenizer.convert_ids_to_tokens(
                    torch.argmax(pred.logits, dim=-1).squeeze().tolist())[-1]
                if gen == EOS:
                    break
                a += gen.replace('▁', ' ')  # '▁' is SentencePiece's word-boundary marker
            print(f"교정: {a.strip()}")  # "corrected:"

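# --- Optional: the same decoding via model.generate() ---
# A minimal sketch, not part of the original script. With do_sample=False,
# transformers' generate() performs the same greedy (argmax) decoding as the
# manual loop in chat(), but caches past key/values instead of re-encoding
# the whole prompt each step. chat_with_generate and the 128-token cap are
# illustrative assumptions, not values from the original.
def chat_with_generate():
    tokenizer = AutoTokenizer.from_pretrained('skt/kogpt2-base-v2',
                                              bos_token=BOS, eos_token=EOS,
                                              unk_token='<unk>',
                                              pad_token=PAD, mask_token=MASK)
    model = GPT2LMHeadModel.from_pretrained('Moo/kogpt2-proofreader')
    model.eval()
    q = input('원래문장: ').strip()  # prompt: "original sentence"
    input_ids = tokenizer.encode(O_TKN + q + C_TKN, return_tensors='pt')
    with torch.no_grad():
        output_ids = model.generate(input_ids,
                                    max_new_tokens=128,  # assumed length cap
                                    do_sample=False,     # greedy, as above
                                    eos_token_id=tokenizer.eos_token_id,
                                    pad_token_id=tokenizer.pad_token_id)
    # Keep only the tokens generated after the prompt, dropping special tokens.
    a = tokenizer.decode(output_ids[0, input_ids.shape[-1]:],
                         skip_special_tokens=True)
    print(f"교정: {a.strip()}")  # "corrected:"
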
if __name__ == "__main__":
    chat()