Upload 6 files
Browse files- .gitattributes +1 -0
- checkpoint.pt +3 -0
- infer.py +61 -0
- model.pt +3 -0
- plot.py +86 -0
- train_output.log +2 -0
- training_curves.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
training_curves.png filter=lfs diff=lfs merge=lfs -text
|
checkpoint.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:038c76ebeba56444ee178bd38c2b65def32846dd6a40ac4428cc422b95e55ca4
|
| 3 |
+
size 10311296
|
infer.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
ALPHABET = "0123456789+-*=()| "
|
| 6 |
+
N_DIM = len(ALPHABET)
|
| 7 |
+
CHAR_TO_INT = {c: i for i, c in enumerate(ALPHABET)}
|
| 8 |
+
INT_TO_CHAR = {i: c for i, c in enumerate(ALPHABET)}
|
| 9 |
+
SPACE_IDX = CHAR_TO_INT.get(' ', 0)
|
| 10 |
+
CONTEXT_SIZE = 64
|
| 11 |
+
MAX_GEN = 60
|
| 12 |
+
|
| 13 |
+
def generate(model, prompt_text):
|
| 14 |
+
clean = prompt_text.replace(' ', '')
|
| 15 |
+
stripped = ''.join(c for c in clean if c in CHAR_TO_INT)
|
| 16 |
+
if '=' not in stripped:
|
| 17 |
+
return prompt_text + '?'
|
| 18 |
+
prompt = stripped[:stripped.index('=') + 1]
|
| 19 |
+
generated = ''
|
| 20 |
+
for _ in range(MAX_GEN):
|
| 21 |
+
prefix = (prompt + generated).rjust(CONTEXT_SIZE, ' ')[:CONTEXT_SIZE]
|
| 22 |
+
indices = torch.tensor([CHAR_TO_INT.get(c, SPACE_IDX) for c in prefix], dtype=torch.long)
|
| 23 |
+
x = F.one_hot(indices, num_classes=N_DIM).float().view(1, -1)
|
| 24 |
+
with torch.no_grad():
|
| 25 |
+
logits = model(x)
|
| 26 |
+
next_token = torch.argmax(logits, dim=-1, keepdim=True)
|
| 27 |
+
char = INT_TO_CHAR.get(next_token.item(), '')
|
| 28 |
+
if not char or char == ' ' or char == '|':
|
| 29 |
+
break
|
| 30 |
+
generated += char
|
| 31 |
+
return prompt_text[:prompt_text.index('=') + 1] + generated
|
| 32 |
+
|
| 33 |
+
def main():
|
| 34 |
+
if len(sys.argv) < 2:
|
| 35 |
+
print("Usage: python infer.py <model.pt> [prompt]")
|
| 36 |
+
print(" python infer.py <model.pt> (interactive mode)")
|
| 37 |
+
sys.exit(1)
|
| 38 |
+
|
| 39 |
+
model_path = sys.argv[1]
|
| 40 |
+
model = torch.jit.load(model_path)
|
| 41 |
+
model.eval()
|
| 42 |
+
print(f"Loaded model from {model_path}", file=sys.stderr)
|
| 43 |
+
|
| 44 |
+
if len(sys.argv) >= 3:
|
| 45 |
+
prompt = sys.argv[2]
|
| 46 |
+
result = generate(model, prompt)
|
| 47 |
+
print(result)
|
| 48 |
+
else:
|
| 49 |
+
while True:
|
| 50 |
+
try:
|
| 51 |
+
prompt = input("Prompt > ")
|
| 52 |
+
if prompt.lower() == 'q':
|
| 53 |
+
break
|
| 54 |
+
print(generate(model, prompt))
|
| 55 |
+
except (EOFError, KeyboardInterrupt):
|
| 56 |
+
break
|
| 57 |
+
except Exception as e:
|
| 58 |
+
print(f"Error: {e}")
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
|
| 61 |
+
main()
|
model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:58eaa26d7bdbbe02bf5dfe7070aa6922cfb681419a7eab7905e8c6f37cfb3aee
|
| 3 |
+
size 10375165
|
plot.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib
|
| 2 |
+
matplotlib.use('Agg')
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import sys
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
def load_loss_data(log_path='train_output.log'):
|
| 8 |
+
"""Load loss data from training log."""
|
| 9 |
+
steps_after, losses_after = [], []
|
| 10 |
+
with open(log_path) as f:
|
| 11 |
+
for line in f:
|
| 12 |
+
m = re.search(r'Step (\d+) .* Loss: ([\d.]+)', line)
|
| 13 |
+
if m:
|
| 14 |
+
steps_after.append(int(m.group(1)))
|
| 15 |
+
losses_after.append(float(m.group(2)))
|
| 16 |
+
if not steps_after:
|
| 17 |
+
return None, None
|
| 18 |
+
return steps_after, losses_after
|
| 19 |
+
|
| 20 |
+
EARLY_LOSS = [
|
| 21 |
+
(0, 2.9836), (500, 1.2863), (1000, 0.8944), (1500, 0.6346),
|
| 22 |
+
(2000, 0.4688), (2500, 0.3735), (3000, 0.2973), (3500, 0.2215),
|
| 23 |
+
(4000, 0.1777), (4500, 0.1588), (5000, 0.1440), (5500, 0.1289),
|
| 24 |
+
(6000, 0.1050), (6500, 0.1028), (7000, 0.1009), (7500, 0.0914),
|
| 25 |
+
(8000, 0.0778), (8500, 0.0769), (9000, 0.0704), (9500, 0.0686),
|
| 26 |
+
(10000, 0.0640), (10500, 0.0696), (11000, 0.0676), (11500, 0.0663),
|
| 27 |
+
(12000, 0.0492), (12500, 0.0590), (13000, 0.0515), (13500, 0.0495),
|
| 28 |
+
(14000, 0.0507), (14500, 0.0522), (15000, 0.0402), (15500, 0.0414),
|
| 29 |
+
(16000, 0.0484), (16500, 0.0444), (17000, 0.0380), (17500, 0.0399),
|
| 30 |
+
(18000, 0.0384), (18500, 0.0359), (19000, 0.0379), (19500, 0.0362),
|
| 31 |
+
(20000, 0.0339), (20500, 0.0338),
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
def plot(output_path='training_curves.png'):
|
| 35 |
+
steps_after, losses_after = load_loss_data()
|
| 36 |
+
steps_early = [s for s, _ in EARLY_LOSS]
|
| 37 |
+
losses_early = [l for _, l in EARLY_LOSS]
|
| 38 |
+
|
| 39 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))
|
| 40 |
+
for ax in [ax1, ax2]:
|
| 41 |
+
ax.set_facecolor('#f8f9fa')
|
| 42 |
+
|
| 43 |
+
if steps_early:
|
| 44 |
+
ax1.plot(steps_early, losses_early, color='#2563eb', linewidth=1.5,
|
| 45 |
+
alpha=0.7, label='Steps 0–20500')
|
| 46 |
+
if steps_after:
|
| 47 |
+
ax1.plot(steps_after, losses_after, color='#dc2626', linewidth=1.5,
|
| 48 |
+
alpha=0.7, label='Steps 21000–49500')
|
| 49 |
+
|
| 50 |
+
all_steps = (steps_early or []) + (steps_after or [])
|
| 51 |
+
all_losses = (losses_early or []) + (losses_after or [])
|
| 52 |
+
if all_steps:
|
| 53 |
+
ax1.fill_between(all_steps, all_losses, alpha=0.06, color='#2563eb')
|
| 54 |
+
|
| 55 |
+
ax1.axvline(x=21000, color='#888', linewidth=0.8, linestyle=':', alpha=0.6)
|
| 56 |
+
ax1.text(21000, max(all_losses) * 0.9, 'resume', fontsize=8, color='#888',
|
| 57 |
+
ha='center', va='top', style='italic')
|
| 58 |
+
|
| 59 |
+
ax1.set_xlabel('Training Step', fontsize=10)
|
| 60 |
+
ax1.set_ylabel('Cross-Entropy Loss', fontsize=10)
|
| 61 |
+
ax1.set_title('Training Loss', fontsize=12, fontweight='bold', pad=10)
|
| 62 |
+
ax1.set_yscale('log')
|
| 63 |
+
ax1.grid(True, alpha=0.25, linestyle='--')
|
| 64 |
+
ax1.legend(fontsize=8, loc='upper right')
|
| 65 |
+
|
| 66 |
+
bars = ax2.bar(['Random 500\nExpressions', 'Fixed\nBenchmark'], [91.6, 100.0],
|
| 67 |
+
color=['#2563eb', '#16a34a'], width=0.5, edgecolor='white', linewidth=1.5)
|
| 68 |
+
for bar, val in zip(bars, [91.6, 100.0]):
|
| 69 |
+
ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1.2,
|
| 70 |
+
f'{val:.1f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')
|
| 71 |
+
ax2.set_ylim(0, 108)
|
| 72 |
+
ax2.set_ylabel('Accuracy', fontsize=10)
|
| 73 |
+
ax2.set_title('Inference Accuracy', fontsize=12, fontweight='bold', pad=10)
|
| 74 |
+
ax2.grid(True, alpha=0.25, linestyle='--', axis='y')
|
| 75 |
+
ax2.spines['top'].set_visible(False)
|
| 76 |
+
ax2.spines['right'].set_visible(False)
|
| 77 |
+
|
| 78 |
+
fig.suptitle('Arithmetic Reasoner — TrueACT 1-Layer',
|
| 79 |
+
fontsize=14, fontweight='bold', y=1.02)
|
| 80 |
+
plt.tight_layout()
|
| 81 |
+
plt.savefig(output_path, dpi=150, bbox_inches='tight')
|
| 82 |
+
print(f'Saved {output_path}')
|
| 83 |
+
|
| 84 |
+
if __name__ == '__main__':
|
| 85 |
+
out = sys.argv[1] if len(sys.argv) > 1 else 'training_curves.png'
|
| 86 |
+
plot(out)
|
train_output.log
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Resumed ACT model at step 21000
|
| 2 |
+
|
training_curves.png
ADDED
|
Git LFS Details
|