| const express = require('express'); |
| const ort = require('onnxruntime-node'); |
| const tiktoken = require('js-tiktoken'); |
| const cors = require('cors'); |
| const path = require('path'); |
|
|
// Express app with permissive CORS and JSON body parsing (the /chat
// route reads its sampling parameters from the JSON request body).
const app = express();
app.use(cors());
app.use(express.json());

// GPT-2 BPE tokenizer (js-tiktoken): encodes prompts, decodes sampled ids.
const enc = tiktoken.getEncoding("gpt2");
// Shared ONNX inference session; stays null until initModel() succeeds,
// and /chat answers 503 while it is null.
let session = null;
|
|
/**
 * Locates the quantized ONNX model next to this file and initialises the
 * module-level inference session. All failures are logged, never thrown:
 * `session` simply stays null so /chat keeps responding 503 until a
 * restart with the model file present.
 */
async function initModel() {
  console.log("--- DEBUG: File-Check ---");
  const fs = require('fs');
  try {
    const modelPath = path.join(__dirname, 'SmaLLMPro_350M_int8.onnx');
    console.log("Searched model path:", modelPath);

    // Guard clause: bail out early when the model file is missing.
    if (!fs.existsSync(modelPath)) {
      console.error("File not found!");
      return;
    }

    session = await ort.InferenceSession.create(modelPath);
    console.log("Model loaded :D!");
  } catch (e) {
    console.error("Error:", e.message);
  }
}
initModel();
|
|
// --- Generation constants -------------------------------------------------
const CONTEXT_LEN = 1024;  // fixed sequence length fed to the ONNX graph
const VOCAB_SIZE = 50304;  // logits per position (padded GPT-2 vocab)
const EOS_TOKEN = 50256;   // GPT-2 <|endoftext|>; terminates generation

/**
 * Runs one forward pass over the (left-padded) context window and returns
 * the final position's logits as a plain number array.
 * NOTE(review): assumes the graph's input is named "input" and that
 * left-padding with token id 0 is acceptable — confirm against the export.
 * @param {number[]} tokens - full token history; only the last CONTEXT_LEN are used
 * @returns {Promise<number[]>} next-token logits (length VOCAB_SIZE)
 */
async function runModelStep(tokens) {
  const ctx = tokens.slice(-CONTEXT_LEN);

  // Left-pad with zeros so the tensor shape is always [1, CONTEXT_LEN].
  const padded = new BigInt64Array(CONTEXT_LEN).fill(0n);
  for (let j = 0; j < ctx.length; j++) {
    padded[CONTEXT_LEN - ctx.length + j] = BigInt(ctx[j]);
  }

  const tensor = new ort.Tensor('int64', padded, [1, CONTEXT_LEN]);
  const results = await session.run({ input: tensor });
  const outputName = session.outputNames[0];

  // Only the last position matters for next-token prediction.
  return Array.from(results[outputName].data.slice(-VOCAB_SIZE));
}

/**
 * Applies a repetition penalty in place: logits of already-seen tokens are
 * shrunk toward zero (divided when positive, multiplied when negative) so
 * repeats become less likely.
 * @param {number[]} logits - mutated in place
 * @param {number[]} tokens - token ids seen so far
 * @param {number} penalty - values > 1.0 discourage repetition
 */
function applyRepetitionPenalty(logits, tokens, penalty) {
  for (const token of tokens) {
    if (token >= VOCAB_SIZE) continue; // skip ids outside the logit range
    if (logits[token] > 0) {
      logits[token] /= penalty;
    } else {
      logits[token] *= penalty;
    }
  }
}

/**
 * Temperature-scaled softmax followed by top-k sampling.
 * @param {number[]} logits
 * @param {number} temp - sampling temperature; NaN or <= 0 falls back to 1.0
 * @param {number} topK - candidates kept; invalid values keep the full vocab
 * @returns {number} sampled token id
 */
function sampleTopK(logits, temp, topK) {
  // Guard: temp <= 0 (or NaN) would yield Infinity/NaN probabilities.
  const t = temp > 0 ? temp : 1.0;
  const scaled = logits.map((l) => l / t);

  // Numerically stable softmax. A plain loop replaces Math.max(...scaled),
  // which can exceed the engine's argument limit for vocab-sized arrays.
  let maxLogit = -Infinity;
  for (const s of scaled) {
    if (s > maxLogit) maxLogit = s;
  }
  const exps = scaled.map((s) => Math.exp(s - maxLogit));
  const sumExps = exps.reduce((a, b) => a + b, 0);
  const probs = exps.map((e) => e / sumExps);

  // Keep the k most probable candidates, then sample from the
  // renormalized truncated distribution.
  let candidates = probs.map((p, i) => ({ p, i }));
  candidates.sort((a, b) => b.p - a.p);
  const k = Number.isInteger(topK) && topK > 0 ? topK : candidates.length;
  candidates = candidates.slice(0, k);

  const totalProb = candidates.reduce((acc, c) => acc + c.p, 0);
  let r = Math.random() * totalProb;
  let next = candidates[0].i; // fallback in case of float round-off
  for (const c of candidates) {
    r -= c.p;
    if (r <= 0) {
      next = c.i;
      break;
    }
  }
  return next;
}

/**
 * POST /chat — streams generated tokens to the client as Server-Sent Events
 * (`data: {"token": "..."}` frames). Body: { prompt, temp, topK, maxLen,
 * penalty }. Missing fields now get sane defaults instead of silently
 * poisoning the math with NaN or producing a zero-iteration loop.
 */
app.post('/chat', async (req, res) => {
  if (!session) return res.status(503).json({ error: "Model loading ..." });

  // Stop burning inference cycles as soon as the client disconnects.
  let clientConnected = true;
  res.on('close', () => {
    clientConnected = false;
    console.log("Connection closed.");
  });

  // Defaults keep generation well-defined for partial request bodies
  // (previously: undefined penalty turned every logit into NaN, and a
  // missing maxLen yielded an empty stream).
  const { prompt = '', temp = 1.0, topK = 40, maxLen = 200, penalty = 1.0 } = req.body ?? {};

  res.setHeader('Content-Type', 'text/event-stream');
  res.setHeader('Cache-Control', 'no-cache');

  const formattedPrompt = `Instruction:\n${prompt}\n\nResponse:\n`;
  const tokens = enc.encode(formattedPrompt);

  try {
    // `step` (not `i`) avoids shadowing by the padding loop inside
    // runModelStep — the original nested two loops both named `i`.
    for (let step = 0; step < maxLen; step++) {
      if (!clientConnected) {
        console.log("Inference stopped, because client disconnected.");
        break;
      }

      const logits = await runModelStep(tokens);

      if (penalty !== 1.0) {
        applyRepetitionPenalty(logits, tokens, penalty);
      }

      const nextToken = sampleTopK(logits, temp, topK);

      // End of sequence: stop without emitting the EOS marker itself.
      if (nextToken === EOS_TOKEN) break;

      tokens.push(nextToken);
      const newText = enc.decode([nextToken]);
      if (clientConnected) {
        res.write(`data: ${JSON.stringify({ token: newText })}\n\n`);
      }

      // Yield to the event loop so 'close' events and other requests run.
      await new Promise((resolve) => setTimeout(resolve, 1));
    }
  } catch (err) {
    console.error("Error:", err);
    res.write(`data: ${JSON.stringify({ error: err.message })}\n\n`);
  } finally {
    res.end();
  }
});
|
|
// Simple liveness probe for the root path.
app.get('/', (req, res) => {
  res.send("SmaLLMPro Backend is Running");
});

// Bind on all interfaces (required inside containers / hosted spaces).
app.listen(7860, '0.0.0.0');
|
|