import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration


def pretty_print(text, prompt=True):
    """Print a prompt or a generated derivation with blank lines between clauses.

    Prompts are split on ', ' and on the connective ' and '; generated
    derivations are split on 'and' alone.
    """
    s = ""
    if prompt:
        for section in text.split(', '):
            premises = section.split(" and ")
            if len(premises) > 1:
                for premise in premises[:-1]:
                    s += premise + "\n\n\n" + "and" + "\n\n\n"
                s += premises[-1] + "\n\n\n"
            else:
                s += section + "\n\n\n"
    else:
        for equation in text.split("and"):
            s += equation + "\n\n\n"
    print(s[:-3])  # drop the trailing "\n\n\n"


def load_model(model_id):
    """Load the T5 tokenizer and model, moving the model to GPU if available."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokenizer = T5Tokenizer.from_pretrained(model_id)
    model = T5ForConditionalGeneration.from_pretrained(model_id).to(device)
    return tokenizer, model


def inference(prompt, tokenizer, model):
    """Generate a derivation for `prompt` and repair backslashes in the output."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_ids = tokenizer.encode(prompt, return_tensors='pt', max_length=512, truncation=True).to(device)
    # note: early_stopping only takes effect when generating with num_beams > 1
    output = model.generate(input_ids=input_ids, max_length=512, early_stopping=True)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Decoding can detach a LaTeX command from its backslash ("\ frac"), so
    # first rejoin any "\ " sequences, then copy the backslash back onto bare
    # tokens that appear elsewhere in the output with one.
    derivation = generated_text.replace("\\ ", "\\")
    partial_symbols = derivation.split(" ")
    backslash_syms = {i for i in partial_symbols if "\\" in i}
    for i, sym in enumerate(partial_symbols):
        for b_sym in backslash_syms:
            if b_sym.replace("\\", "") == sym:
                partial_symbols[i] = b_sym
    return " ".join(partial_symbols)
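

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original pipeline: the checkpoint
    # id below is a placeholder ('t5-base' is a generic pretrained model, not
    # one fine-tuned for derivations), and the prompt is purely illustrative.
    tokenizer, model = load_model('t5-base')
    prompt = 'given x = y and y = 2 z , derive x'
    derivation = inference(prompt, tokenizer, model)
    pretty_print(prompt, prompt=True)
    pretty_print(derivation, prompt=False)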