Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- Dockerfile +38 -0
- README.md +31 -7
- cmd/server/main.go +250 -0
- go.mod +38 -0
- go.sum +66 -0
- models/latest_checkpoint.json +3 -0
- pkg/model/model.go +375 -0
- pkg/model/tokenizer.go +124 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
models/latest_checkpoint.json filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Build stage
|
| 2 |
+
FROM golang:1.25-bullseye AS builder
|
| 3 |
+
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Copy go.mod and go.sum first for caching
|
| 7 |
+
COPY go.mod go.sum ./
|
| 8 |
+
RUN go mod download
|
| 9 |
+
|
| 10 |
+
# Copy the rest of the code
|
| 11 |
+
COPY . .
|
| 12 |
+
|
| 13 |
+
# Build the server binary
|
| 14 |
+
RUN CGO_ENABLED=0 GOOS=linux go build -o server ./cmd/server/main.go
|
| 15 |
+
|
| 16 |
+
# Final stage
|
| 17 |
+
FROM debian:bullseye-slim
|
| 18 |
+
|
| 19 |
+
WORKDIR /app
|
| 20 |
+
|
| 21 |
+
# Install CA certificates for external downloads if needed
|
| 22 |
+
RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/*
|
| 23 |
+
|
| 24 |
+
# Copy the binary from builder
|
| 25 |
+
COPY --from=builder /app/server .
|
| 26 |
+
|
| 27 |
+
# Copy the models directory for weights
|
| 28 |
+
COPY ./models ./models
|
| 29 |
+
|
| 30 |
+
# Set environment variables
|
| 31 |
+
ENV PORT=7860
|
| 32 |
+
ENV MODEL_PATH=models/latest_checkpoint.json
|
| 33 |
+
|
| 34 |
+
# Expose the port
|
| 35 |
+
EXPOSE 7860
|
| 36 |
+
|
| 37 |
+
# Run the server
|
| 38 |
+
CMD ["./server"]
|
README.md
CHANGED
|
@@ -1,12 +1,36 @@
|
|
| 1 |
---
|
| 2 |
-
title: MicroGPT API
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: indigo
|
| 6 |
-
sdk:
|
| 7 |
-
|
| 8 |
-
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: MicroGPT OpenAI API
|
| 3 |
+
emoji: 🚀
|
| 4 |
+
colorFrom: blue
|
| 5 |
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
|
|
|
| 8 |
pinned: false
|
| 9 |
+
license: mit
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# MicroGPT OpenAI-Compatible API
|
| 13 |
+
|
| 14 |
+
This Space hosts a Go-based inference server for the **MicroGPT** model, providing an OpenAI-compatible API.
|
| 15 |
+
|
| 16 |
+
## API Endpoints
|
| 17 |
+
|
| 18 |
+
- **`POST /v1/chat/completions`**: standard OpenAI chat format.
|
| 19 |
+
- **`GET /v1/models`**: returns model metadata.
|
| 20 |
+
|
| 21 |
+
## Local Test
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
curl -X POST http://localhost:7860/v1/chat/completions \
|
| 25 |
+
-H "Content-Type: application/json" \
|
| 26 |
+
-d '{
|
| 27 |
+
"model": "microgpt",
|
| 28 |
+
"messages": [
|
| 29 |
+
{"role": "user", "content": "Help me prioritize my day"}
|
| 30 |
+
]
|
| 31 |
+
}'
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Repository
|
| 35 |
+
|
| 36 |
+
Built with [MicroGPT Go Edition](https://github.com/Traves-Theberge/microgpt-tui-go).
|
cmd/server/main.go
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"fmt"
|
| 6 |
+
"log"
|
| 7 |
+
"net/http"
|
| 8 |
+
"os"
|
| 9 |
+
"strings"
|
| 10 |
+
"time"
|
| 11 |
+
|
| 12 |
+
"microgpt-go/pkg/model"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
type ChatMessage struct {
|
| 16 |
+
Role string `json:"role"`
|
| 17 |
+
Content string `json:"content"`
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
type ChatCompletionRequest struct {
|
| 21 |
+
Model string `json:"model"`
|
| 22 |
+
Messages []ChatMessage `json:"messages"`
|
| 23 |
+
Temperature float64 `json:"temperature"`
|
| 24 |
+
MaxTokens int `json:"max_tokens"`
|
| 25 |
+
TopP float64 `json:"top_p"`
|
| 26 |
+
Stream bool `json:"stream"`
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
type ChatCompletionResponse struct {
|
| 30 |
+
ID string `json:"id"`
|
| 31 |
+
Object string `json:"object"`
|
| 32 |
+
Created int64 `json:"created"`
|
| 33 |
+
Model string `json:"model"`
|
| 34 |
+
Choices []struct {
|
| 35 |
+
Message ChatMessage `json:"message"`
|
| 36 |
+
Index int `json:"index"`
|
| 37 |
+
FinishReason string `json:"finish_reason"`
|
| 38 |
+
} `json:"choices"`
|
| 39 |
+
Usage struct {
|
| 40 |
+
PromptTokens int `json:"prompt_tokens"`
|
| 41 |
+
CompletionTokens int `json:"completion_tokens"`
|
| 42 |
+
TotalTokens int `json:"total_tokens"`
|
| 43 |
+
} `json:"usage"`
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
var (
|
| 47 |
+
gpt func(tokenID, posID int, keys, values [][][]*model.Value) []*model.Value
|
| 48 |
+
tokenizer model.TokenizerRuntime
|
| 49 |
+
config model.TrainingCheckpointConfig
|
| 50 |
+
state map[string][][]*model.Value
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
func initModel() {
|
| 54 |
+
ckptPath := os.Getenv("MODEL_PATH")
|
| 55 |
+
if ckptPath == "" {
|
| 56 |
+
ckptPath = "models/latest_checkpoint.json"
|
| 57 |
+
}
|
| 58 |
+
log.Printf("Loading model from %s...", ckptPath)
|
| 59 |
+
ckpt, err := model.LoadCheckpoint(ckptPath)
|
| 60 |
+
if err != nil {
|
| 61 |
+
log.Fatalf("Failed to load checkpoint: %v", err)
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
tokenizer, err = model.TokenizerFromCheckpoint(ckpt)
|
| 65 |
+
if err != nil {
|
| 66 |
+
log.Fatalf("Failed to load tokenizer: %v", err)
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
state = model.ImportState(ckpt.State)
|
| 70 |
+
config = ckpt.Config
|
| 71 |
+
gpt = model.BuildGPT(state, config.NLayer, config.NEmbd, config.NHead)
|
| 72 |
+
log.Println("Model loaded successfully.")
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
func handleChat(w http.ResponseWriter, r *http.Request) {
|
| 76 |
+
if r.Method != http.MethodPost {
|
| 77 |
+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
| 78 |
+
return
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
var req ChatCompletionRequest
|
| 82 |
+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
| 83 |
+
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
| 84 |
+
return
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
if req.Temperature <= 0 {
|
| 88 |
+
req.Temperature = 0.5
|
| 89 |
+
}
|
| 90 |
+
if req.TopP <= 0 {
|
| 91 |
+
req.TopP = 0.9
|
| 92 |
+
}
|
| 93 |
+
if req.MaxTokens <= 0 {
|
| 94 |
+
req.MaxTokens = 128
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
// Simple prompt construction from messages
|
| 98 |
+
var promptBuilder strings.Builder
|
| 99 |
+
for _, msg := range req.Messages {
|
| 100 |
+
role := "User"
|
| 101 |
+
if msg.Role == "assistant" {
|
| 102 |
+
role = "Assistant"
|
| 103 |
+
}
|
| 104 |
+
fmt.Fprintf(&promptBuilder, "%s: %s\n", role, msg.Content)
|
| 105 |
+
}
|
| 106 |
+
promptBuilder.WriteString("Assistant: ")
|
| 107 |
+
promptText := promptBuilder.String()
|
| 108 |
+
|
| 109 |
+
promptTokens := tokenizer.EncodeDoc(promptText)
|
| 110 |
+
if len(promptTokens) > config.BlockSize-1 {
|
| 111 |
+
promptTokens = promptTokens[len(promptTokens)-(config.BlockSize-1):]
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
keys := make([][][]*model.Value, config.NLayer)
|
| 115 |
+
values := make([][][]*model.Value, config.NLayer)
|
| 116 |
+
tokenID := tokenizer.BosID
|
| 117 |
+
pos := 0
|
| 118 |
+
|
| 119 |
+
// Process prompt tokens (pre-fill KV cache)
|
| 120 |
+
for _, nextID := range promptTokens {
|
| 121 |
+
if pos >= config.BlockSize {
|
| 122 |
+
break
|
| 123 |
+
}
|
| 124 |
+
_ = gpt(tokenID, pos, keys, values)
|
| 125 |
+
tokenID = nextID
|
| 126 |
+
pos++
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
// Generate response
|
| 130 |
+
completionTokens := 0
|
| 131 |
+
outTokens := make([]int, 0, req.MaxTokens)
|
| 132 |
+
recent := make([]int, 0, 64)
|
| 133 |
+
stopSeqs := []string{"\nUser:", "\nAssistant:"}
|
| 134 |
+
|
| 135 |
+
for pos < config.BlockSize && completionTokens < req.MaxTokens {
|
| 136 |
+
logits := gpt(tokenID, pos, keys, values)
|
| 137 |
+
recentSet := map[int]bool{}
|
| 138 |
+
for _, id := range recent {
|
| 139 |
+
recentSet[id] = true
|
| 140 |
+
}
|
| 141 |
+
weights := model.NextTokenWeights(logits, req.Temperature, 40, req.TopP, recentSet, 1.1)
|
| 142 |
+
tokenID = model.SampleWeighted(weights)
|
| 143 |
+
|
| 144 |
+
if tokenID == tokenizer.BosID {
|
| 145 |
+
break
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
outTokens = append(outTokens, tokenID)
|
| 149 |
+
recent = append(recent, tokenID)
|
| 150 |
+
if len(recent) > 64 {
|
| 151 |
+
recent = recent[len(recent)-64:]
|
| 152 |
+
}
|
| 153 |
+
completionTokens++
|
| 154 |
+
pos++
|
| 155 |
+
|
| 156 |
+
// Check for stop sequences in decoded text
|
| 157 |
+
fullText := tokenizer.DecodeTokens(outTokens)
|
| 158 |
+
stopFound := false
|
| 159 |
+
for _, stop := range stopSeqs {
|
| 160 |
+
if strings.Contains(fullText, stop) {
|
| 161 |
+
stopFound = true
|
| 162 |
+
break
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
if stopFound {
|
| 166 |
+
break
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
responseText := strings.TrimSpace(tokenizer.DecodeTokens(outTokens))
|
| 171 |
+
// Clean up any trailing stop sequence markers
|
| 172 |
+
for _, stop := range stopSeqs {
|
| 173 |
+
if idx := strings.Index(responseText, strings.TrimSpace(stop)); idx >= 0 {
|
| 174 |
+
responseText = responseText[:idx]
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
resp := ChatCompletionResponse{
|
| 179 |
+
ID: fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
|
| 180 |
+
Object: "chat.completion",
|
| 181 |
+
Created: time.Now().Unix(),
|
| 182 |
+
Model: "microgpt",
|
| 183 |
+
Choices: []struct {
|
| 184 |
+
Message ChatMessage `json:"message"`
|
| 185 |
+
Index int `json:"index"`
|
| 186 |
+
FinishReason string `json:"finish_reason"`
|
| 187 |
+
}{
|
| 188 |
+
{
|
| 189 |
+
Message: ChatMessage{
|
| 190 |
+
Role: "assistant",
|
| 191 |
+
Content: strings.TrimSpace(responseText),
|
| 192 |
+
},
|
| 193 |
+
Index: 0,
|
| 194 |
+
FinishReason: "stop",
|
| 195 |
+
},
|
| 196 |
+
},
|
| 197 |
+
}
|
| 198 |
+
resp.Usage.PromptTokens = len(promptTokens)
|
| 199 |
+
resp.Usage.CompletionTokens = completionTokens
|
| 200 |
+
resp.Usage.TotalTokens = resp.Usage.PromptTokens + resp.Usage.CompletionTokens
|
| 201 |
+
|
| 202 |
+
w.Header().Set("Content-Type", "application/json")
|
| 203 |
+
json.NewEncoder(w).Encode(resp)
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
func handleModels(w http.ResponseWriter, r *http.Request) {
|
| 207 |
+
resp := struct {
|
| 208 |
+
Object string `json:"object"`
|
| 209 |
+
Data []struct {
|
| 210 |
+
ID string `json:"id"`
|
| 211 |
+
Object string `json:"object"`
|
| 212 |
+
Created int64 `json:"created"`
|
| 213 |
+
OwnedBy string `json:"owned_by"`
|
| 214 |
+
} `json:"data"`
|
| 215 |
+
}{
|
| 216 |
+
Object: "list",
|
| 217 |
+
Data: []struct {
|
| 218 |
+
ID string `json:"id"`
|
| 219 |
+
Object string `json:"object"`
|
| 220 |
+
Created int64 `json:"created"`
|
| 221 |
+
OwnedBy string `json:"owned_by"`
|
| 222 |
+
}{
|
| 223 |
+
{
|
| 224 |
+
ID: "microgpt",
|
| 225 |
+
Object: "model",
|
| 226 |
+
Created: time.Now().Unix(),
|
| 227 |
+
OwnedBy: "microgpt",
|
| 228 |
+
},
|
| 229 |
+
},
|
| 230 |
+
}
|
| 231 |
+
w.Header().Set("Content-Type", "application/json")
|
| 232 |
+
json.NewEncoder(w).Encode(resp)
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
func main() {
|
| 236 |
+
initModel()
|
| 237 |
+
|
| 238 |
+
http.HandleFunc("/v1/chat/completions", handleChat)
|
| 239 |
+
http.HandleFunc("/v1/models", handleModels)
|
| 240 |
+
|
| 241 |
+
port := os.Getenv("PORT")
|
| 242 |
+
if port == "" {
|
| 243 |
+
port = "7860" // Standard port for HF Spaces
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
log.Printf("Starting OpenAI-compatible server on port %s...", port)
|
| 247 |
+
if err := http.ListenAndServe(":"+port, nil); err != nil {
|
| 248 |
+
log.Fatalf("Failed to start server: %v", err)
|
| 249 |
+
}
|
| 250 |
+
}
|
go.mod
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
module microgpt-go
|
| 2 |
+
|
| 3 |
+
go 1.25
|
| 4 |
+
|
| 5 |
+
require (
|
| 6 |
+
github.com/charmbracelet/bubbles v1.0.0
|
| 7 |
+
github.com/charmbracelet/bubbletea v1.3.10
|
| 8 |
+
github.com/charmbracelet/harmonica v0.2.0
|
| 9 |
+
github.com/charmbracelet/lipgloss v1.1.0
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
require (
|
| 13 |
+
github.com/atotto/clipboard v0.1.4 // indirect
|
| 14 |
+
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
| 15 |
+
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
| 16 |
+
github.com/charmbracelet/x/ansi v0.11.6 // indirect
|
| 17 |
+
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
| 18 |
+
github.com/charmbracelet/x/term v0.2.2 // indirect
|
| 19 |
+
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
| 20 |
+
github.com/clipperhouse/stringish v0.1.1 // indirect
|
| 21 |
+
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
| 22 |
+
github.com/dlclark/regexp2 v1.10.0 // indirect
|
| 23 |
+
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
| 24 |
+
github.com/google/uuid v1.3.0 // indirect
|
| 25 |
+
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
| 26 |
+
github.com/mattn/go-isatty v0.0.20 // indirect
|
| 27 |
+
github.com/mattn/go-localereader v0.0.1 // indirect
|
| 28 |
+
github.com/mattn/go-runewidth v0.0.19 // indirect
|
| 29 |
+
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
| 30 |
+
github.com/muesli/cancelreader v0.2.2 // indirect
|
| 31 |
+
github.com/muesli/termenv v0.16.0 // indirect
|
| 32 |
+
github.com/pkoukk/tiktoken-go v0.1.8 // indirect
|
| 33 |
+
github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect
|
| 34 |
+
github.com/rivo/uniseg v0.4.7 // indirect
|
| 35 |
+
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
| 36 |
+
golang.org/x/sys v0.38.0 // indirect
|
| 37 |
+
golang.org/x/text v0.3.8 // indirect
|
| 38 |
+
)
|
go.sum
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
| 2 |
+
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
| 3 |
+
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
| 4 |
+
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
| 5 |
+
github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY=
|
| 6 |
+
github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E=
|
| 7 |
+
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
|
| 8 |
+
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
|
| 9 |
+
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
| 10 |
+
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
| 11 |
+
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
| 12 |
+
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
| 13 |
+
github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ=
|
| 14 |
+
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
|
| 15 |
+
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
| 16 |
+
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
| 17 |
+
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
| 18 |
+
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
| 19 |
+
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
| 20 |
+
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
| 21 |
+
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ=
|
| 22 |
+
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
|
| 23 |
+
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
| 24 |
+
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
| 25 |
+
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
| 26 |
+
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
| 27 |
+
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
| 28 |
+
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
| 29 |
+
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
| 30 |
+
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
| 31 |
+
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
| 32 |
+
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
| 33 |
+
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
| 34 |
+
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
| 35 |
+
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
| 36 |
+
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
| 37 |
+
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
| 38 |
+
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
| 39 |
+
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
| 40 |
+
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
| 41 |
+
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
| 42 |
+
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
| 43 |
+
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
| 44 |
+
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
| 45 |
+
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
| 46 |
+
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
| 47 |
+
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
| 48 |
+
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
| 49 |
+
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
| 50 |
+
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
| 51 |
+
github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
|
| 52 |
+
github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
| 53 |
+
github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
|
| 54 |
+
github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
|
| 55 |
+
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
| 56 |
+
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
| 57 |
+
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
| 58 |
+
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
| 59 |
+
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
| 60 |
+
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
| 61 |
+
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
| 62 |
+
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
| 63 |
+
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
| 64 |
+
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
| 65 |
+
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
|
| 66 |
+
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
models/latest_checkpoint.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:964f39971833a67b2ec3a3cdd1376586aa3d3cc2b55cb11f8dc581c27a304720
|
| 3 |
+
size 19575802
|
pkg/model/model.go
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package model
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"fmt"
|
| 6 |
+
"math"
|
| 7 |
+
"math/rand"
|
| 8 |
+
"os"
|
| 9 |
+
"sort"
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
// Value represents a scalar for autograd
|
| 13 |
+
type Value struct {
|
| 14 |
+
Data float64
|
| 15 |
+
Grad float64
|
| 16 |
+
Children []*Value
|
| 17 |
+
LocalGrads []float64
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
func V(x float64) *Value {
|
| 21 |
+
return &Value{Data: x}
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
func Add(a, b *Value) *Value {
|
| 25 |
+
out := &Value{Data: a.Data + b.Data, Children: []*Value{a, b}, LocalGrads: []float64{1, 1}}
|
| 26 |
+
return out
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
func Sub(a, b *Value) *Value {
|
| 30 |
+
out := &Value{Data: a.Data - b.Data, Children: []*Value{a, b}, LocalGrads: []float64{1, -1}}
|
| 31 |
+
return out
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
func Mul(a, b *Value) *Value {
|
| 35 |
+
out := &Value{Data: a.Data * b.Data, Children: []*Value{a, b}, LocalGrads: []float64{b.Data, a.Data}}
|
| 36 |
+
return out
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
func Pow(a *Value, p float64) *Value {
|
| 40 |
+
out := &Value{Data: math.Pow(a.Data, p), Children: []*Value{a}, LocalGrads: []float64{p * math.Pow(a.Data, p-1)}}
|
| 41 |
+
return out
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
func Div(a, b *Value) *Value {
|
| 45 |
+
return Mul(a, Pow(b, -1))
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
func Neg(a *Value) *Value {
|
| 49 |
+
return Mul(a, V(-1))
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
func Log(a *Value) *Value {
|
| 53 |
+
out := &Value{Data: math.Log(a.Data), Children: []*Value{a}, LocalGrads: []float64{1 / a.Data}}
|
| 54 |
+
return out
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
func Exp(a *Value) *Value {
|
| 58 |
+
out := &Value{Data: math.Exp(a.Data), Children: []*Value{a}, LocalGrads: []float64{math.Exp(a.Data)}}
|
| 59 |
+
return out
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
func ReLU(a *Value) *Value {
|
| 63 |
+
val := 0.0
|
| 64 |
+
grad := 0.0
|
| 65 |
+
if a.Data > 0 {
|
| 66 |
+
val = a.Data
|
| 67 |
+
grad = 1
|
| 68 |
+
}
|
| 69 |
+
out := &Value{Data: val, Children: []*Value{a}, LocalGrads: []float64{grad}}
|
| 70 |
+
return out
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
func Backward(out *Value) {
|
| 74 |
+
topo := make([]*Value, 0)
|
| 75 |
+
visited := make(map[*Value]bool)
|
| 76 |
+
var buildTopo func(*Value)
|
| 77 |
+
buildTopo = func(v *Value) {
|
| 78 |
+
if !visited[v] {
|
| 79 |
+
visited[v] = true
|
| 80 |
+
for _, child := range v.Children {
|
| 81 |
+
buildTopo(child)
|
| 82 |
+
}
|
| 83 |
+
topo = append(topo, v)
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
buildTopo(out)
|
| 87 |
+
|
| 88 |
+
for _, v := range topo {
|
| 89 |
+
v.Grad = 0
|
| 90 |
+
}
|
| 91 |
+
out.Grad = 1
|
| 92 |
+
for i := len(topo) - 1; i >= 0; i-- {
|
| 93 |
+
v := topo[i]
|
| 94 |
+
for j, child := range v.Children {
|
| 95 |
+
child.Grad += v.LocalGrads[j] * v.Grad
|
| 96 |
+
}
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
func linear(x []*Value, w [][]*Value) []*Value {
|
| 101 |
+
nout := len(w)
|
| 102 |
+
nin := len(x)
|
| 103 |
+
out := make([]*Value, nout)
|
| 104 |
+
for i := 0; i < nout; i++ {
|
| 105 |
+
s := V(0)
|
| 106 |
+
for j := 0; j < nin; j++ {
|
| 107 |
+
s = Add(s, Mul(x[j], w[i][j]))
|
| 108 |
+
}
|
| 109 |
+
out[i] = s
|
| 110 |
+
}
|
| 111 |
+
return out
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
func softmax(logits []*Value) []*Value {
|
| 115 |
+
maxVal := -math.MaxFloat64
|
| 116 |
+
for _, l := range logits {
|
| 117 |
+
if l.Data > maxVal {
|
| 118 |
+
maxVal = l.Data
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
exps := make([]*Value, len(logits))
|
| 122 |
+
sumExp := V(0)
|
| 123 |
+
for i, l := range logits {
|
| 124 |
+
exps[i] = Exp(Sub(l, V(maxVal)))
|
| 125 |
+
sumExp = Add(sumExp, exps[i])
|
| 126 |
+
}
|
| 127 |
+
out := make([]*Value, len(logits))
|
| 128 |
+
invSum := Div(V(1), sumExp)
|
| 129 |
+
for i := range exps {
|
| 130 |
+
out[i] = Mul(exps[i], invSum)
|
| 131 |
+
}
|
| 132 |
+
return out
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
func rmsnorm(x []*Value) []*Value {
|
| 136 |
+
meanSq := V(0)
|
| 137 |
+
for _, v := range x {
|
| 138 |
+
meanSq = Add(meanSq, Pow(v, 2))
|
| 139 |
+
}
|
| 140 |
+
meanSq = Mul(V(1/float64(len(x))), meanSq)
|
| 141 |
+
invStd := Div(V(1), Pow(Add(meanSq, V(1e-6)), 0.5))
|
| 142 |
+
out := make([]*Value, len(x))
|
| 143 |
+
for i, v := range x {
|
| 144 |
+
out[i] = Mul(v, invStd)
|
| 145 |
+
}
|
| 146 |
+
return out
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
// TrainingCheckpoint structs
|
| 150 |
+
type TrainingCheckpoint struct {
|
| 151 |
+
Version int `json:"version"`
|
| 152 |
+
CreatedAt string `json:"created_at"`
|
| 153 |
+
Config TrainingCheckpointConfig `json:"config"`
|
| 154 |
+
Tokenization string `json:"tokenization,omitempty"`
|
| 155 |
+
BPEEncoding string `json:"bpe_encoding,omitempty"`
|
| 156 |
+
BPETokenIDs []int `json:"bpe_token_ids,omitempty"`
|
| 157 |
+
Vocab []string `json:"vocab,omitempty"`
|
| 158 |
+
State map[string][][]float64 `json:"state"`
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
type TrainingCheckpointConfig struct {
|
| 162 |
+
NLayer int `json:"n_layer"`
|
| 163 |
+
NEmbd int `json:"n_embd"`
|
| 164 |
+
NHead int `json:"n_head"`
|
| 165 |
+
BlockSize int `json:"block_size"`
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
func ImportState(src map[string][][]float64) map[string][][]*Value {
|
| 169 |
+
out := make(map[string][][]*Value, len(src))
|
| 170 |
+
for name, mat := range src {
|
| 171 |
+
rows := make([][]*Value, len(mat))
|
| 172 |
+
for i, row := range mat {
|
| 173 |
+
r := make([]*Value, len(row))
|
| 174 |
+
for j, v := range row {
|
| 175 |
+
r[j] = V(v)
|
| 176 |
+
}
|
| 177 |
+
rows[i] = r
|
| 178 |
+
}
|
| 179 |
+
out[name] = rows
|
| 180 |
+
}
|
| 181 |
+
return out
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
func LoadCheckpoint(path string) (TrainingCheckpoint, error) {
|
| 185 |
+
b, err := os.ReadFile(path)
|
| 186 |
+
if err != nil {
|
| 187 |
+
return TrainingCheckpoint{}, err
|
| 188 |
+
}
|
| 189 |
+
var ckpt TrainingCheckpoint
|
| 190 |
+
if err := json.Unmarshal(b, &ckpt); err != nil {
|
| 191 |
+
return TrainingCheckpoint{}, err
|
| 192 |
+
}
|
| 193 |
+
if ckpt.Config.NLayer < 1 || ckpt.Config.NEmbd < 1 || ckpt.Config.NHead < 1 || ckpt.Config.BlockSize < 2 {
|
| 194 |
+
return TrainingCheckpoint{}, fmt.Errorf("invalid checkpoint config")
|
| 195 |
+
}
|
| 196 |
+
if ckpt.Config.NEmbd%ckpt.Config.NHead != 0 {
|
| 197 |
+
return TrainingCheckpoint{}, fmt.Errorf("invalid checkpoint: n_embd must be divisible by n_head")
|
| 198 |
+
}
|
| 199 |
+
return ckpt, nil
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
func BuildGPT(state map[string][][]*Value, nLayer, nEmbd, nHead int) func(tokenID, posID int, keys, values [][][]*Value) []*Value {
|
| 203 |
+
headDim := nEmbd / nHead
|
| 204 |
+
return func(tokenID, posID int, keys, values [][][]*Value) []*Value {
|
| 205 |
+
tokEmb := state["wte"][tokenID]
|
| 206 |
+
posEmb := state["wpe"][posID]
|
| 207 |
+
x := make([]*Value, len(tokEmb))
|
| 208 |
+
for i := range tokEmb {
|
| 209 |
+
x[i] = Add(tokEmb[i], posEmb[i])
|
| 210 |
+
}
|
| 211 |
+
x = rmsnorm(x)
|
| 212 |
+
|
| 213 |
+
for li := 0; li < nLayer; li++ {
|
| 214 |
+
xResidual := x
|
| 215 |
+
x = rmsnorm(x)
|
| 216 |
+
q := linear(x, state[fmt.Sprintf("layer%d.attn_wq", li)])
|
| 217 |
+
k := linear(x, state[fmt.Sprintf("layer%d.attn_wk", li)])
|
| 218 |
+
v := linear(x, state[fmt.Sprintf("layer%d.attn_wv", li)])
|
| 219 |
+
keys[li] = append(keys[li], k)
|
| 220 |
+
values[li] = append(values[li], v)
|
| 221 |
+
|
| 222 |
+
xAttn := make([]*Value, 0, nEmbd)
|
| 223 |
+
for h := 0; h < nHead; h++ {
|
| 224 |
+
hs := h * headDim
|
| 225 |
+
qH := q[hs : hs+headDim]
|
| 226 |
+
|
| 227 |
+
kH := make([][]*Value, len(keys[li]))
|
| 228 |
+
vH := make([][]*Value, len(values[li]))
|
| 229 |
+
for t := 0; t < len(keys[li]); t++ {
|
| 230 |
+
kH[t] = keys[li][t][hs : hs+headDim]
|
| 231 |
+
vH[t] = values[li][t][hs : hs+headDim]
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
attnLogits := make([]*Value, len(kH))
|
| 235 |
+
for t := 0; t < len(kH); t++ {
|
| 236 |
+
score := V(0)
|
| 237 |
+
for j := 0; j < headDim; j++ {
|
| 238 |
+
score = Add(score, Mul(qH[j], kH[t][j]))
|
| 239 |
+
}
|
| 240 |
+
attnLogits[t] = Div(score, V(math.Sqrt(float64(headDim))))
|
| 241 |
+
}
|
| 242 |
+
attnWeights := softmax(attnLogits)
|
| 243 |
+
|
| 244 |
+
headOut := make([]*Value, headDim)
|
| 245 |
+
for j := 0; j < headDim; j++ {
|
| 246 |
+
s := V(0)
|
| 247 |
+
for t := 0; t < len(vH); t++ {
|
| 248 |
+
s = Add(s, Mul(attnWeights[t], vH[t][j]))
|
| 249 |
+
}
|
| 250 |
+
headOut[j] = s
|
| 251 |
+
}
|
| 252 |
+
xAttn = append(xAttn, headOut...)
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
x = linear(xAttn, state[fmt.Sprintf("layer%d.attn_wo", li)])
|
| 256 |
+
for i := range x {
|
| 257 |
+
x[i] = Add(x[i], xResidual[i])
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
xResidual = x
|
| 261 |
+
x = rmsnorm(x)
|
| 262 |
+
x = linear(x, state[fmt.Sprintf("layer%d.mlp_fc1", li)])
|
| 263 |
+
for i := range x {
|
| 264 |
+
x[i] = ReLU(x[i])
|
| 265 |
+
}
|
| 266 |
+
x = linear(x, state[fmt.Sprintf("layer%d.mlp_fc2", li)])
|
| 267 |
+
for i := range x {
|
| 268 |
+
x[i] = Add(x[i], xResidual[i])
|
| 269 |
+
}
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
return linear(x, state["lm_head"])
|
| 273 |
+
}
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
// Sampling functions
|
| 277 |
+
func SampleWeighted(weights []float64) int {
|
| 278 |
+
sum := 0.0
|
| 279 |
+
for _, w := range weights {
|
| 280 |
+
sum += w
|
| 281 |
+
}
|
| 282 |
+
r := rand.Float64() * sum
|
| 283 |
+
running := 0.0
|
| 284 |
+
for i, w := range weights {
|
| 285 |
+
running += w
|
| 286 |
+
if r <= running {
|
| 287 |
+
return i
|
| 288 |
+
}
|
| 289 |
+
}
|
| 290 |
+
return len(weights) - 1
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
func SoftmaxFloat(logits []float64) []float64 {
|
| 294 |
+
maxLogit := -math.MaxFloat64
|
| 295 |
+
for _, l := range logits {
|
| 296 |
+
if l > maxLogit {
|
| 297 |
+
maxLogit = l
|
| 298 |
+
}
|
| 299 |
+
}
|
| 300 |
+
sum := 0.0
|
| 301 |
+
out := make([]float64, len(logits))
|
| 302 |
+
for i, l := range logits {
|
| 303 |
+
out[i] = math.Exp(l - maxLogit)
|
| 304 |
+
sum += out[i]
|
| 305 |
+
}
|
| 306 |
+
for i := range out {
|
| 307 |
+
out[i] /= sum
|
| 308 |
+
}
|
| 309 |
+
return out
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
func NextTokenWeights(logits []*Value, temperature float64, topK int, topP float64, recent map[int]bool, repetitionPenalty float64) []float64 {
|
| 313 |
+
l := make([]float64, len(logits))
|
| 314 |
+
for i, v := range logits {
|
| 315 |
+
l[i] = v.Data
|
| 316 |
+
if recent[i] {
|
| 317 |
+
if l[i] >= 0 {
|
| 318 |
+
l[i] /= repetitionPenalty
|
| 319 |
+
} else {
|
| 320 |
+
l[i] *= repetitionPenalty
|
| 321 |
+
}
|
| 322 |
+
}
|
| 323 |
+
l[i] /= temperature
|
| 324 |
+
}
|
| 325 |
+
w := SoftmaxFloat(l)
|
| 326 |
+
if topK > 0 {
|
| 327 |
+
w = ApplyTopK(w, topK)
|
| 328 |
+
}
|
| 329 |
+
if topP > 0 && topP < 1.0 {
|
| 330 |
+
w = ApplyTopP(w, topP)
|
| 331 |
+
}
|
| 332 |
+
return w
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
func ApplyTopK(weights []float64, k int) []float64 {
|
| 336 |
+
if k >= len(weights) {
|
| 337 |
+
return weights
|
| 338 |
+
}
|
| 339 |
+
type kv struct {
|
| 340 |
+
i int
|
| 341 |
+
w float64
|
| 342 |
+
}
|
| 343 |
+
arr := make([]kv, len(weights))
|
| 344 |
+
for i, w := range weights {
|
| 345 |
+
arr[i] = kv{i, w}
|
| 346 |
+
}
|
| 347 |
+
sort.Slice(arr, func(i, j int) bool { return arr[i].w > arr[j].w })
|
| 348 |
+
out := make([]float64, len(weights))
|
| 349 |
+
for i := 0; i < k; i++ {
|
| 350 |
+
out[arr[i].i] = arr[i].w
|
| 351 |
+
}
|
| 352 |
+
return out
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
func ApplyTopP(weights []float64, p float64) []float64 {
|
| 356 |
+
type kv struct {
|
| 357 |
+
i int
|
| 358 |
+
w float64
|
| 359 |
+
}
|
| 360 |
+
arr := make([]kv, len(weights))
|
| 361 |
+
for i, w := range weights {
|
| 362 |
+
arr[i] = kv{i, w}
|
| 363 |
+
}
|
| 364 |
+
sort.Slice(arr, func(i, j int) bool { return arr[i].w > arr[j].w })
|
| 365 |
+
out := make([]float64, len(weights))
|
| 366 |
+
sum := 0.0
|
| 367 |
+
for i := 0; i < len(arr); i++ {
|
| 368 |
+
sum += arr[i].w
|
| 369 |
+
out[arr[i].i] = arr[i].w
|
| 370 |
+
if sum >= p {
|
| 371 |
+
break
|
| 372 |
+
}
|
| 373 |
+
}
|
| 374 |
+
return out
|
| 375 |
+
}
|
pkg/model/tokenizer.go
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package model
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"fmt"
|
| 5 |
+
"strings"
|
| 6 |
+
|
| 7 |
+
tiktoken "github.com/pkoukk/tiktoken-go"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
type TokenizerRuntime struct {
|
| 11 |
+
Mode string
|
| 12 |
+
CharToLocal map[rune]int
|
| 13 |
+
LocalToChar []rune
|
| 14 |
+
BpeEncoding string
|
| 15 |
+
Bpe *tiktoken.Tiktoken
|
| 16 |
+
BpeToLocal map[int]int
|
| 17 |
+
LocalToBPE []int
|
| 18 |
+
UnkID int
|
| 19 |
+
BosID int
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
func (t TokenizerRuntime) VocabSize() int {
|
| 23 |
+
if t.Mode == "bpe_cl100k" {
|
| 24 |
+
return len(t.LocalToBPE) + 2
|
| 25 |
+
}
|
| 26 |
+
return len(t.LocalToChar) + 1
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
func (t TokenizerRuntime) EncodeDoc(doc string) []int {
|
| 30 |
+
if t.Mode == "bpe_cl100k" {
|
| 31 |
+
raw := t.Bpe.EncodeOrdinary(doc)
|
| 32 |
+
out := make([]int, 0, len(raw))
|
| 33 |
+
for _, id := range raw {
|
| 34 |
+
if local, ok := t.BpeToLocal[id]; ok {
|
| 35 |
+
out = append(out, local)
|
| 36 |
+
} else {
|
| 37 |
+
out = append(out, t.UnkID)
|
| 38 |
+
}
|
| 39 |
+
}
|
| 40 |
+
return out
|
| 41 |
+
}
|
| 42 |
+
out := make([]int, 0, len(doc))
|
| 43 |
+
for _, r := range doc {
|
| 44 |
+
if id, ok := t.CharToLocal[r]; ok {
|
| 45 |
+
out = append(out, id)
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
return out
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
func (t TokenizerRuntime) DecodeTokens(tokens []int) string {
|
| 52 |
+
if t.Mode == "bpe_cl100k" {
|
| 53 |
+
raw := make([]int, 0, len(tokens))
|
| 54 |
+
for _, local := range tokens {
|
| 55 |
+
if local >= 0 && local < len(t.LocalToBPE) {
|
| 56 |
+
raw = append(raw, t.LocalToBPE[local])
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
return t.Bpe.Decode(raw)
|
| 60 |
+
}
|
| 61 |
+
out := make([]rune, 0, len(tokens))
|
| 62 |
+
for _, id := range tokens {
|
| 63 |
+
if id >= 0 && id < len(t.LocalToChar) {
|
| 64 |
+
out = append(out, t.LocalToChar[id])
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
return string(out)
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
func TokenizerFromCheckpoint(ckpt TrainingCheckpoint) (TokenizerRuntime, error) {
|
| 71 |
+
if ckpt.Tokenization == "bpe_cl100k" || len(ckpt.BPETokenIDs) > 0 {
|
| 72 |
+
encName := strings.TrimSpace(ckpt.BPEEncoding)
|
| 73 |
+
if encName == "" {
|
| 74 |
+
encName = "cl100k_base"
|
| 75 |
+
}
|
| 76 |
+
enc, err := tiktoken.GetEncoding(encName)
|
| 77 |
+
if err != nil {
|
| 78 |
+
return TokenizerRuntime{}, err
|
| 79 |
+
}
|
| 80 |
+
localToBPE := append([]int(nil), ckpt.BPETokenIDs...)
|
| 81 |
+
bpeToLocal := make(map[int]int, len(localToBPE))
|
| 82 |
+
for i, id := range localToBPE {
|
| 83 |
+
bpeToLocal[id] = i
|
| 84 |
+
}
|
| 85 |
+
return TokenizerRuntime{
|
| 86 |
+
Mode: "bpe_cl100k",
|
| 87 |
+
BpeEncoding: encName,
|
| 88 |
+
Bpe: enc,
|
| 89 |
+
BpeToLocal: bpeToLocal,
|
| 90 |
+
LocalToBPE: localToBPE,
|
| 91 |
+
UnkID: len(localToBPE),
|
| 92 |
+
BosID: len(localToBPE) + 1,
|
| 93 |
+
}, nil
|
| 94 |
+
}
|
| 95 |
+
uchars, err := stringsToRunes(ckpt.Vocab)
|
| 96 |
+
if err != nil {
|
| 97 |
+
return TokenizerRuntime{}, err
|
| 98 |
+
}
|
| 99 |
+
if len(uchars) == 0 {
|
| 100 |
+
return TokenizerRuntime{}, fmt.Errorf("checkpoint has empty character vocab")
|
| 101 |
+
}
|
| 102 |
+
charToLocal := make(map[rune]int, len(uchars))
|
| 103 |
+
for i, r := range uchars {
|
| 104 |
+
charToLocal[r] = i
|
| 105 |
+
}
|
| 106 |
+
return TokenizerRuntime{
|
| 107 |
+
Mode: "char",
|
| 108 |
+
CharToLocal: charToLocal,
|
| 109 |
+
LocalToChar: uchars,
|
| 110 |
+
BosID: len(uchars),
|
| 111 |
+
}, nil
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
func stringsToRunes(ss []string) ([]rune, error) {
|
| 115 |
+
out := make([]rune, 0, len(ss))
|
| 116 |
+
for _, s := range ss {
|
| 117 |
+
r := []rune(s)
|
| 118 |
+
if len(r) != 1 {
|
| 119 |
+
return nil, fmt.Errorf("invalid vocab token %q: expected one rune", s)
|
| 120 |
+
}
|
| 121 |
+
out = append(out, r[0])
|
| 122 |
+
}
|
| 123 |
+
return out, nil
|
| 124 |
+
}
|