File size: 6,108 Bytes
ee3804e f030a27 ee3804e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | import torch
import torch.nn as nn
i_t_w = {0: '\t', 1: ' ', 2: '!', 3: '"', 4: '#', 5: '$', 6: '%', 7: '(', 8: ')', 9: '*', 10: '+', 11: ',', 12: '-', 13: '.', 14: '/', 15: '0', 16: '1', 17: '2', 18: '3', 19: '4', 20: '5', 21: '6', 22: '7', 23: '8', 24: '9', 25: ':', 26: ';', 27: '=', 28: '>', 29: '?', 30: 'A', 31: 'B', 32: 'C', 33: 'D', 34: 'E', 35: 'F', 36: 'G', 37: 'H', 38: 'I', 39: 'J', 40: 'K', 41: 'L', 42: 'M', 43: 'N', 44: 'O', 45: 'P', 46: 'Q', 47: 'R', 48: 'S', 49: 'T', 50: 'U', 51: 'V', 52: 'W', 53: 'X', 54: 'Y', 55: 'Z', 56: '[', 57: ']', 58: '_', 59: 'a', 60: 'b', 61: 'c', 62: 'd', 63: 'e', 64: 'f', 65: 'g', 66: 'h', 67: 'i', 68: 'j', 69: 'k', 70: 'l', 71: 'm', 72: 'n', 73: 'o', 74: 'p', 75: 'q', 76: 'r', 77: 's', 78: 't', 79: 'u', 80: 'v', 81: 'w', 82: 'x', 83: 'y', 84: 'z', 85: '«', 86: '°', 87: '»', 88: 'Á', 89: 'É', 90: 'Í', 91: 'Ó', 92: 'Ö', 93: 'Ú', 94: 'Ü', 95: 'à', 96: 'á', 97: 'â', 98: 'ä', 99: 'ç', 100: 'è', 101: 'é', 102: 'ê', 103: 'í', 104: 'î', 105: 'ï', 106: 'ó', 107: 'ô', 108: 'ö', 109: 'ù', 110: 'ú', 111: 'ü', 112: 'ý', 113: 'Ő', 114: 'ő', 115: 'œ', 116: 'Ű', 117: 'ű', 118: '́', 119: 'Ё', 120: 'А', 121: 'Б', 122: 'В', 123: 'Г', 124: 'Д', 125: 'Е', 126: 'Ж', 127: 'З', 128: 'И', 129: 'Й', 130: 'К', 131: 'Л', 132: 'М', 133: 'Н', 134: 'О', 135: 'П', 136: 'Р', 137: 'С', 138: 'Т', 139: 'У', 140: 'Ф', 141: 'Х', 142: 'Ц', 143: 'Ч', 144: 'Ш', 145: 'Щ', 146: 'Ы', 147: 'Ь', 148: 'Э', 149: 'Ю', 150: 'Я', 151: 'а', 152: 'б', 153: 'в', 154: 'г', 155: 'д', 156: 'е', 157: 'ж', 158: 'з', 159: 'и', 160: 'й', 161: 'к', 162: 'л', 163: 'м', 164: 'н', 165: 'о', 166: 'п', 167: 'р', 168: 'с', 169: 'т', 170: 'у', 171: 'ф', 172: 'х', 173: 'ц', 174: 'ч', 175: 'ш', 176: 'щ', 177: 'ъ', 178: 'ы', 179: 'ь', 180: 'э', 181: 'ю', 182: 'я', 183: 'ё', 184: '\u200b', 185: '‑', 186: '–', 187: '—', 188: '‘', 189: '’', 190: '“', 191: '”', 192: '„', 193: '•', 194: '…', 195: '™', 196: '←', 197: '→', 198: '⚠', 199: '⚡', 200: '✅', 201: '利', 202: '顺', 203: '️', 204: '🆓', 205: '🇷', 206: '🇺', 207: '🌍', 208: '🎯', 209: '🏆', 210: '💡', 211: '💰', 212: '📊', 213: '📍', 214: '📝', 215: '📞', 216: '📱', 217: '📲', 218: '🔄', 219: '🔍', 220: '🔓', 221: '🔧', 222: '🚀', 223: '🚨', 224: '🛠', 225: '🛡', 226: '🤔', 227: '🥇', 228: '🥈', 229: '🥉'}
w_t_i = {'\t': 0, ' ': 1, '!': 2, '"': 3, '#': 4, '$': 5, '%': 6, '(': 7, ')': 8, '*': 9, '+': 10, ',': 11, '-': 12, '.': 13, '/': 14, '0': 15, '1': 16, '2': 17, '3': 18, '4': 19, '5': 20, '6': 21, '7': 22, '8': 23, '9': 24, ':': 25, ';': 26, '=': 27, '>': 28, '?': 29, 'A': 30, 'B': 31, 'C': 32, 'D': 33, 'E': 34, 'F': 35, 'G': 36, 'H': 37, 'I': 38, 'J': 39, 'K': 40, 'L': 41, 'M': 42, 'N': 43, 'O': 44, 'P': 45, 'Q': 46, 'R': 47, 'S': 48, 'T': 49, 'U': 50, 'V': 51, 'W': 52, 'X': 53, 'Y': 54, 'Z': 55, '[': 56, ']': 57, '_': 58, 'a': 59, 'b': 60, 'c': 61, 'd': 62, 'e': 63, 'f': 64, 'g': 65, 'h': 66, 'i': 67, 'j': 68, 'k': 69, 'l': 70, 'm': 71, 'n': 72, 'o': 73, 'p': 74, 'q': 75, 'r': 76, 's': 77, 't': 78, 'u': 79, 'v': 80, 'w': 81, 'x': 82, 'y': 83, 'z': 84, '«': 85, '°': 86, '»': 87, 'Á': 88, 'É': 89, 'Í': 90, 'Ó': 91, 'Ö': 92, 'Ú': 93, 'Ü': 94, 'à': 95, 'á': 96, 'â': 97, 'ä': 98, 'ç': 99, 'è': 100, 'é': 101, 'ê': 102, 'í': 103, 'î': 104, 'ï': 105, 'ó': 106, 'ô': 107, 'ö': 108, 'ù': 109, 'ú': 110, 'ü': 111, 'ý': 112, 'Ő': 113, 'ő': 114, 'œ': 115, 'Ű': 116, 'ű': 117, '́': 118, 'Ё': 119, 'А': 120, 'Б': 121, 'В': 122, 'Г': 123, 'Д': 124, 'Е': 125, 'Ж': 126, 'З': 127, 'И': 128, 'Й': 129, 'К': 130, 'Л': 131, 'М': 132, 'Н': 133, 'О': 134, 'П': 135, 'Р': 136, 'С': 137, 'Т': 138, 'У': 139, 'Ф': 140, 'Х': 141, 'Ц': 142, 'Ч': 143, 'Ш': 144, 'Щ': 145, 'Ы': 146, 'Ь': 147, 'Э': 148, 'Ю': 149, 'Я': 150, 'а': 151, 'б': 152, 'в': 153, 'г': 154, 'д': 155, 'е': 156, 'ж': 157, 'з': 158, 'и': 159, 'й': 160, 'к': 161, 'л': 162, 'м': 163, 'н': 164, 'о': 165, 'п': 166, 'р': 167, 'с': 168, 'т': 169, 'у': 170, 'ф': 171, 'х': 172, 'ц': 173, 'ч': 174, 'ш': 175, 'щ': 176, 'ъ': 177, 'ы': 178, 'ь': 179, 'э': 180, 'ю': 181, 'я': 182, 'ё': 183, '\u200b': 184, '‑': 185, '–': 186, '—': 187, '‘': 188, '’': 189, '“': 190, '”': 191, '„': 192, '•': 193, '…': 194, '™': 195, '←': 196, '→': 197, '⚠': 198, '⚡': 199, '✅': 200, '利': 201, '顺': 202, '️': 203, '🆓': 204, '🇷': 205, '🇺': 206, '🌍': 207, '🎯': 208, '🏆': 209, '💡 💡': 210, '💰': 211, '📊': 212, '📍': 213, '📝': 214, '📞': 215, '📱': 216, '📲': 217, '🔄': 218, '🔍': 219, '🔓': 220, '🔧': 221, '🚀': 222, '🚨': 223, '🛠': 224, '🛡': 225, '🤔': 226, '🥇': 227, '🥈': 228, '🥉': 229}
class AI(nn.Module):
def __init__(self):
super().__init__()
self.embai = nn.Embedding(230, 256)
self.lsai = nn.LSTM(256, 512, batch_first=True, dropout=0.3, num_layers=1)
self.linai = nn.Linear(512, 230)
def forward(self, x):
x = self.embai(x)
if x.dim() == 2:
x = x.unsqueeze(0)
x, _ = self.lsai(x)
x = x[:, -1, :]
x = self.linai(x)
return x
model = AI()
model.load_state_dict(torch.load('C:\\Users\\name\\Downloads\\pytorch_model.pth'))
model.eval()
for _ in range(10):
a = input()
a = a[:100]
s = []
ss = []
for ch in a:
idx = w_t_i.get(ch, 1)
s.append(idx)
for _ in range(100):
s_tensor = torch.tensor(s[-100:], dtype=torch.long).unsqueeze(0)
with torch.no_grad():
otv = model(s_tensor)
if otv.dim() == 3:
last_logits = otv[0, -1, :]
else:
last_logits = otv[-1, :]
next_token = torch.argmax(last_logits, dim=-1).item()
s.append(next_token)
ss.append(next_token)
print(''.join(i_t_w[i] for i in ss))
|