CompactAI commited on
Commit
239cc86
·
verified ·
1 Parent(s): 0643434

Upload 6 files

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. checkpoint.pt +3 -0
  3. infer.py +61 -0
  4. model.pt +3 -0
  5. plot.py +86 -0
  6. train_output.log +2 -0
  7. training_curves.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ training_curves.png filter=lfs diff=lfs merge=lfs -text
checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:038c76ebeba56444ee178bd38c2b65def32846dd6a40ac4428cc422b95e55ca4
3
+ size 10311296
infer.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import sys
4
+
5
+ ALPHABET = "0123456789+-*=()| "
6
+ N_DIM = len(ALPHABET)
7
+ CHAR_TO_INT = {c: i for i, c in enumerate(ALPHABET)}
8
+ INT_TO_CHAR = {i: c for i, c in enumerate(ALPHABET)}
9
+ SPACE_IDX = CHAR_TO_INT.get(' ', 0)
10
+ CONTEXT_SIZE = 64
11
+ MAX_GEN = 60
12
+
13
+ def generate(model, prompt_text):
14
+ clean = prompt_text.replace(' ', '')
15
+ stripped = ''.join(c for c in clean if c in CHAR_TO_INT)
16
+ if '=' not in stripped:
17
+ return prompt_text + '?'
18
+ prompt = stripped[:stripped.index('=') + 1]
19
+ generated = ''
20
+ for _ in range(MAX_GEN):
21
+ prefix = (prompt + generated).rjust(CONTEXT_SIZE, ' ')[:CONTEXT_SIZE]
22
+ indices = torch.tensor([CHAR_TO_INT.get(c, SPACE_IDX) for c in prefix], dtype=torch.long)
23
+ x = F.one_hot(indices, num_classes=N_DIM).float().view(1, -1)
24
+ with torch.no_grad():
25
+ logits = model(x)
26
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
27
+ char = INT_TO_CHAR.get(next_token.item(), '')
28
+ if not char or char == ' ' or char == '|':
29
+ break
30
+ generated += char
31
+ return prompt_text[:prompt_text.index('=') + 1] + generated
32
+
33
+ def main():
34
+ if len(sys.argv) < 2:
35
+ print("Usage: python infer.py <model.pt> [prompt]")
36
+ print(" python infer.py <model.pt> (interactive mode)")
37
+ sys.exit(1)
38
+
39
+ model_path = sys.argv[1]
40
+ model = torch.jit.load(model_path)
41
+ model.eval()
42
+ print(f"Loaded model from {model_path}", file=sys.stderr)
43
+
44
+ if len(sys.argv) >= 3:
45
+ prompt = sys.argv[2]
46
+ result = generate(model, prompt)
47
+ print(result)
48
+ else:
49
+ while True:
50
+ try:
51
+ prompt = input("Prompt > ")
52
+ if prompt.lower() == 'q':
53
+ break
54
+ print(generate(model, prompt))
55
+ except (EOFError, KeyboardInterrupt):
56
+ break
57
+ except Exception as e:
58
+ print(f"Error: {e}")
59
+
60
+ if __name__ == "__main__":
61
+ main()
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58eaa26d7bdbbe02bf5dfe7070aa6922cfb681419a7eab7905e8c6f37cfb3aee
3
+ size 10375165
plot.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ matplotlib.use('Agg')
3
+ import matplotlib.pyplot as plt
4
+ import sys
5
+ import re
6
+
7
+ def load_loss_data(log_path='train_output.log'):
8
+ """Load loss data from training log."""
9
+ steps_after, losses_after = [], []
10
+ with open(log_path) as f:
11
+ for line in f:
12
+ m = re.search(r'Step (\d+) .* Loss: ([\d.]+)', line)
13
+ if m:
14
+ steps_after.append(int(m.group(1)))
15
+ losses_after.append(float(m.group(2)))
16
+ if not steps_after:
17
+ return None, None
18
+ return steps_after, losses_after
19
+
20
+ EARLY_LOSS = [
21
+ (0, 2.9836), (500, 1.2863), (1000, 0.8944), (1500, 0.6346),
22
+ (2000, 0.4688), (2500, 0.3735), (3000, 0.2973), (3500, 0.2215),
23
+ (4000, 0.1777), (4500, 0.1588), (5000, 0.1440), (5500, 0.1289),
24
+ (6000, 0.1050), (6500, 0.1028), (7000, 0.1009), (7500, 0.0914),
25
+ (8000, 0.0778), (8500, 0.0769), (9000, 0.0704), (9500, 0.0686),
26
+ (10000, 0.0640), (10500, 0.0696), (11000, 0.0676), (11500, 0.0663),
27
+ (12000, 0.0492), (12500, 0.0590), (13000, 0.0515), (13500, 0.0495),
28
+ (14000, 0.0507), (14500, 0.0522), (15000, 0.0402), (15500, 0.0414),
29
+ (16000, 0.0484), (16500, 0.0444), (17000, 0.0380), (17500, 0.0399),
30
+ (18000, 0.0384), (18500, 0.0359), (19000, 0.0379), (19500, 0.0362),
31
+ (20000, 0.0339), (20500, 0.0338),
32
+ ]
33
+
34
+ def plot(output_path='training_curves.png'):
35
+ steps_after, losses_after = load_loss_data()
36
+ steps_early = [s for s, _ in EARLY_LOSS]
37
+ losses_early = [l for _, l in EARLY_LOSS]
38
+
39
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))
40
+ for ax in [ax1, ax2]:
41
+ ax.set_facecolor('#f8f9fa')
42
+
43
+ if steps_early:
44
+ ax1.plot(steps_early, losses_early, color='#2563eb', linewidth=1.5,
45
+ alpha=0.7, label='Steps 0–20500')
46
+ if steps_after:
47
+ ax1.plot(steps_after, losses_after, color='#dc2626', linewidth=1.5,
48
+ alpha=0.7, label='Steps 21000–49500')
49
+
50
+ all_steps = (steps_early or []) + (steps_after or [])
51
+ all_losses = (losses_early or []) + (losses_after or [])
52
+ if all_steps:
53
+ ax1.fill_between(all_steps, all_losses, alpha=0.06, color='#2563eb')
54
+
55
+ ax1.axvline(x=21000, color='#888', linewidth=0.8, linestyle=':', alpha=0.6)
56
+ ax1.text(21000, max(all_losses) * 0.9, 'resume', fontsize=8, color='#888',
57
+ ha='center', va='top', style='italic')
58
+
59
+ ax1.set_xlabel('Training Step', fontsize=10)
60
+ ax1.set_ylabel('Cross-Entropy Loss', fontsize=10)
61
+ ax1.set_title('Training Loss', fontsize=12, fontweight='bold', pad=10)
62
+ ax1.set_yscale('log')
63
+ ax1.grid(True, alpha=0.25, linestyle='--')
64
+ ax1.legend(fontsize=8, loc='upper right')
65
+
66
+ bars = ax2.bar(['Random 500\nExpressions', 'Fixed\nBenchmark'], [91.6, 100.0],
67
+ color=['#2563eb', '#16a34a'], width=0.5, edgecolor='white', linewidth=1.5)
68
+ for bar, val in zip(bars, [91.6, 100.0]):
69
+ ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1.2,
70
+ f'{val:.1f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')
71
+ ax2.set_ylim(0, 108)
72
+ ax2.set_ylabel('Accuracy', fontsize=10)
73
+ ax2.set_title('Inference Accuracy', fontsize=12, fontweight='bold', pad=10)
74
+ ax2.grid(True, alpha=0.25, linestyle='--', axis='y')
75
+ ax2.spines['top'].set_visible(False)
76
+ ax2.spines['right'].set_visible(False)
77
+
78
+ fig.suptitle('Arithmetic Reasoner — TrueACT 1-Layer',
79
+ fontsize=14, fontweight='bold', y=1.02)
80
+ plt.tight_layout()
81
+ plt.savefig(output_path, dpi=150, bbox_inches='tight')
82
+ print(f'Saved {output_path}')
83
+
84
+ if __name__ == '__main__':
85
+ out = sys.argv[1] if len(sys.argv) > 1 else 'training_curves.png'
86
+ plot(out)
train_output.log ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Resumed ACT model at step 21000
2
+
training_curves.png ADDED

Git LFS Details

  • SHA256: bba5c9abfca344d2ae421a2f898f28d3126cd076fc6abf84fd84dbec4cdbdfcb
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB