BcantCode committed
Commit 65a793e · verified · 1 Parent(s): 5dccaf9

Upload src/train.py with huggingface_hub

Files changed (1)
  1. src/train.py +570 -0
src/train.py ADDED
@@ -0,0 +1,570 @@
"""
Training script for GazeInception-Lite model.

Trains both:
1. Single-eye model (89K params) - ultralight for constrained devices
2. Dual-eye model (137K params) - better accuracy with face context + lazy eye handling

Features:
- Gated Inception blocks with learned branch gating
- Coordinate Attention for spatial gaze awareness
- Synthetic data with: dark conditions, glasses, lazy eye, sensor noise
- TFLite conversion with full integer quantization
- Push to Hugging Face Hub

Based on:
- AGE framework (arxiv:2603.26945) - augmentation pipeline, multi-task approach
- Gated Compression Layers (arxiv:2303.08970) - gating mechanism
- iTracker (arxiv:1606.05814) - dual-eye + face architecture
"""

import os
import json
import time
import numpy as np
import tensorflow as tf
from tensorflow import keras
from pathlib import Path

# Import our modules
from model import build_gaze_inception_lite, build_dual_eye_model
from data_generator import SyntheticGazeDataGenerator, create_tf_dataset, create_single_eye_dataset


def euclidean_distance_metric(y_true, y_pred):
    """Euclidean distance in normalized [0,1] coordinates."""
    return tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(y_true - y_pred), axis=-1)))


def screen_error_mm(y_true, y_pred, screen_w_mm=65.0, screen_h_mm=140.0):
    """Error in mm assuming a typical phone screen (65mm x 140mm)."""
    diff = y_true - y_pred
    diff_mm = diff * tf.constant([screen_w_mm, screen_h_mm])
    return tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(diff_mm), axis=-1)))
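
# Quick sanity check for the metric above (editor's sketch; the numbers follow
# from the 65 x 140 mm screen assumption): a prediction off by 0.1 of the
# screen in each axis maps to (6.5 mm, 14.0 mm), i.e.
# sqrt(6.5**2 + 14.0**2) ~= 15.4 mm.
#
#   >>> screen_error_mm(tf.constant([[0.5, 0.5]]), tf.constant([[0.6, 0.6]]))
#   <tf.Tensor: ... numpy=15.43...>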


def train_single_eye_model(train_data, val_data, epochs=100, batch_size=128,
                           output_dir='models/single_eye'):
    """Train the lightweight single-eye model."""
    print("\n" + "="*60)
    print("Training Single-Eye GazeInception-Lite Model")
    print("="*60)

    os.makedirs(output_dir, exist_ok=True)

    model = build_gaze_inception_lite(input_shape=(64, 64, 3), num_outputs=2)

    # Cosine decay learning rate
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-3,
        decay_steps=epochs * (len(train_data['gaze']) * 2 // batch_size),
        alpha=1e-5
    )
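    # Step-count note (editor's): the "* 2" presumably reflects that each
    # synthetic face contributes both a left and a right eye sample in
    # create_single_eye_dataset. With the defaults in main() - 50,000 train
    # faces, batch 128, 80 epochs - this gives 80 * (100,000 // 128) = 62,480
    # decay steps.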
63
+
64
+ optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)
65
+
66
+ model.compile(
67
+ optimizer=optimizer,
68
+ loss='mse',
69
+ metrics=[euclidean_distance_metric]
70
+ )
71
+
72
+ # Create datasets
73
+ train_ds = create_single_eye_dataset(train_data, batch_size=batch_size, shuffle=True)
74
+ val_ds = create_single_eye_dataset(val_data, batch_size=batch_size, shuffle=False)
75
+
76
+ callbacks = [
77
+ keras.callbacks.ModelCheckpoint(
78
+ os.path.join(output_dir, 'best_model.keras'),
79
+ monitor='val_euclidean_distance_metric',
80
+ save_best_only=True, mode='min', verbose=1
81
+ ),
82
+ keras.callbacks.ReduceLROnPlateau(
83
+ monitor='val_loss', factor=0.5, patience=10, min_lr=1e-6, verbose=1
84
+ ),
85
+ keras.callbacks.EarlyStopping(
86
+ monitor='val_euclidean_distance_metric', patience=20,
87
+ restore_best_weights=True, verbose=1
88
+ ),
89
+ ]
90
+
91
+ history = model.fit(
92
+ train_ds, validation_data=val_ds,
93
+ epochs=epochs, callbacks=callbacks, verbose=1
94
+ )
95
+
96
+ # Save final model
97
+ model.save(os.path.join(output_dir, 'final_model.keras'))
98
+
99
+ return model, history
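
# Reloading note (editor's sketch): the checkpoints above embed the custom
# metric, so loading them in a fresh process needs custom_objects, e.g.
#
#   model = keras.models.load_model(
#       'models/single_eye/best_model.keras',
#       custom_objects={'euclidean_distance_metric': euclidean_distance_metric},
#   )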


def train_dual_eye_model(train_data, val_data, epochs=100, batch_size=64,
                         output_dir='models/dual_eye'):
    """Train the dual-eye model with face context."""
    print("\n" + "="*60)
    print("Training Dual-Eye GazeInception-Lite Model")
    print("="*60)

    os.makedirs(output_dir, exist_ok=True)

    model = build_dual_eye_model(
        eye_shape=(64, 64, 3), face_shape=(64, 64, 3), num_outputs=2
    )

    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-3,
        decay_steps=epochs * (len(train_data['gaze']) // batch_size),
        alpha=1e-5
    )

    optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=[euclidean_distance_metric]
    )

    # Create datasets
    train_ds = create_tf_dataset(train_data, batch_size=batch_size, shuffle=True)
    val_ds = create_tf_dataset(val_data, batch_size=batch_size, shuffle=False)

    # As above, ReduceLROnPlateau is omitted because the optimizer's learning
    # rate comes from a LearningRateSchedule.
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            os.path.join(output_dir, 'best_model.keras'),
            monitor='val_euclidean_distance_metric',
            save_best_only=True, mode='min', verbose=1
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_euclidean_distance_metric', patience=20,
            restore_best_weights=True, verbose=1
        ),
    ]

    history = model.fit(
        train_ds, validation_data=val_ds,
        epochs=epochs, callbacks=callbacks, verbose=1
    )

    model.save(os.path.join(output_dir, 'final_model.keras'))

    return model, history


def convert_to_tflite(keras_model, output_path, quantize=True, test_data=None):
    """Convert Keras model to TFLite with optional quantization."""
    print(f"\nConverting to TFLite: {output_path}")

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    # Optimization settings for mobile
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if quantize and test_data is not None:
        # Full integer quantization (fastest on mobile)
        def representative_dataset_gen():
            for i in range(min(200, len(test_data))):
                yield [test_data[i:i+1].astype(np.float32)]

        converter.representative_dataset = representative_dataset_gen
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.float32
        print(" Using INT8 quantization for maximum mobile speed")
    else:
        # Float16 weights, so the output actually matches the *_f16.tflite
        # naming used in main() (Optimize.DEFAULT alone would produce a
        # dynamic-range model instead).
        converter.target_spec.supported_types = [tf.float16]
        print(" Using float16 quantization")

    tflite_model = converter.convert()

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(tflite_model)

    size_kb = len(tflite_model) / 1024
    print(f" TFLite model size: {size_kb:.1f} KB")

    return tflite_model
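
# Editor's sketch: evaluate_tflite below feeds uint8 models with a plain
# (x * 255) cast, which implicitly assumes input quantization of
# scale ~1/255 and zero_point 0. The exact transform reads the parameters
# from the interpreter instead:
#
#   scale, zero_point = interpreter.get_input_details()[0]['quantization']
#   q = np.clip(np.round(x / scale + zero_point), 0, 255).astype(np.uint8)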


def convert_dual_eye_to_tflite(keras_model, output_path, quantize=True, test_data=None):
    """Convert dual-eye model to TFLite."""
    print(f"\nConverting dual-eye model to TFLite: {output_path}")

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if quantize and test_data is not None:
        def representative_dataset_gen():
            for i in range(min(200, len(test_data['left_eye']))):
                yield [
                    test_data['left_eye'][i:i+1].astype(np.float32),
                    test_data['right_eye'][i:i+1].astype(np.float32),
                    test_data['face'][i:i+1].astype(np.float32),
                ]

        converter.representative_dataset = representative_dataset_gen
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
            tf.lite.OpsSet.TFLITE_BUILTINS,  # float fallback for ops without int8 kernels
        ]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.float32
        print(" Using INT8 quantization for maximum mobile speed")
    else:
        # Float16 weights so the output matches the *_f16.tflite naming in main()
        converter.target_spec.supported_types = [tf.float16]
        print(" Using float16 quantization")

    tflite_model = converter.convert()

    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(tflite_model)

    size_kb = len(tflite_model) / 1024
    print(f" TFLite model size: {size_kb:.1f} KB")

    return tflite_model


def evaluate_tflite(tflite_path, test_inputs, test_labels, is_dual=False):
    """Evaluate TFLite model accuracy and inference speed."""
    print(f"\nEvaluating TFLite model: {tflite_path}")

    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print(f" Input details: {[(d['name'], d['shape'], d['dtype']) for d in input_details]}")
    print(f" Output details: {[(d['name'], d['shape'], d['dtype']) for d in output_details]}")

    predictions = []
    num_samples = min(500, len(test_labels))

    # Warmup
    for _ in range(5):
        if is_dual:
            for idx, detail in enumerate(input_details):
                if detail['dtype'] == np.uint8:
                    data = (test_inputs[idx][0:1] * 255).astype(np.uint8)
                else:
                    data = test_inputs[idx][0:1].astype(np.float32)
                interpreter.set_tensor(detail['index'], data)
        else:
            if input_details[0]['dtype'] == np.uint8:
                data = (test_inputs[0:1] * 255).astype(np.uint8)
            else:
                data = test_inputs[0:1].astype(np.float32)
            interpreter.set_tensor(input_details[0]['index'], data)
        interpreter.invoke()

    # Benchmark
    start_time = time.time()
    for i in range(num_samples):
        if is_dual:
            for idx, detail in enumerate(input_details):
                if detail['dtype'] == np.uint8:
                    data = (test_inputs[idx][i:i+1] * 255).astype(np.uint8)
                else:
                    data = test_inputs[idx][i:i+1].astype(np.float32)
                interpreter.set_tensor(detail['index'], data)
        else:
            if input_details[0]['dtype'] == np.uint8:
                data = (test_inputs[i:i+1] * 255).astype(np.uint8)
            else:
                data = test_inputs[i:i+1].astype(np.float32)
            interpreter.set_tensor(input_details[0]['index'], data)

        interpreter.invoke()
        pred = interpreter.get_tensor(output_details[0]['index'])[0]
        predictions.append(pred)

    elapsed = time.time() - start_time

    predictions = np.array(predictions)
    labels = test_labels[:num_samples]

    # Metrics
    eucl_error = np.mean(np.sqrt(np.sum((predictions - labels) ** 2, axis=-1)))

    # Screen error in mm (typical phone: 65mm x 140mm)
    diff_mm = (predictions - labels) * np.array([65.0, 140.0])
    screen_error = np.mean(np.sqrt(np.sum(diff_mm ** 2, axis=-1)))

    # Screen error in cm
    screen_error_cm = screen_error / 10.0

    avg_inference_ms = (elapsed / num_samples) * 1000

    print(f" Euclidean error (normalized): {eucl_error:.4f}")
    print(f" Screen error: {screen_error:.1f} mm ({screen_error_cm:.2f} cm)")
    print(f" Average inference time (CPU): {avg_inference_ms:.2f} ms")
    print(f" FPS (CPU): {1000 / avg_inference_ms:.1f}")

    return {
        'euclidean_error': float(eucl_error),
        'screen_error_mm': float(screen_error),
        'screen_error_cm': float(screen_error_cm),
        'avg_inference_ms': float(avg_inference_ms),
        'fps': float(1000 / avg_inference_ms),
        'num_test_samples': num_samples,
    }
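
# Interpretation note (editor's, with an assumed viewing distance the script
# does not measure): at ~300 mm from the screen, a 10 mm error subtends
# atan(10 / 300) ~= 1.9 degrees of visual angle.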


def main():
    print("="*60)
    print("GazeInception-Lite: Mobile Eye Gaze Estimation")
    print("="*60)

    # Configuration
    NUM_TRAIN = 50000
    NUM_VAL = 5000
    NUM_TEST = 3000
    EPOCHS_SINGLE = 80
    EPOCHS_DUAL = 80
    BATCH_SIZE = 128

    OUTPUT_DIR = '/app/output'
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # ==========================================
    # Generate synthetic training data
    # ==========================================
    print("\n[1/6] Generating synthetic training data...")
    print(f" Train: {NUM_TRAIN}, Val: {NUM_VAL}, Test: {NUM_TEST}")
    print(" Augmentations: dark(30%), glasses(25%), lazy_eye(15%), noise(50%)")

    gen = SyntheticGazeDataGenerator(seed=42)

    t0 = time.time()
    train_data = gen.generate_dataset(NUM_TRAIN, dark_prob=0.30, with_glasses_prob=0.25, lazy_eye_prob=0.15)
    print(f" Train data generated in {time.time()-t0:.1f}s")

    gen_val = SyntheticGazeDataGenerator(seed=123)
    val_data = gen_val.generate_dataset(NUM_VAL, dark_prob=0.30, with_glasses_prob=0.25, lazy_eye_prob=0.15)

    gen_test = SyntheticGazeDataGenerator(seed=456)
    test_data = gen_test.generate_dataset(NUM_TEST, dark_prob=0.30, with_glasses_prob=0.25, lazy_eye_prob=0.15)

    # Also generate condition-specific test sets for robustness evaluation
    gen_dark = SyntheticGazeDataGenerator(seed=789)
    test_dark = gen_dark.generate_dataset(1000, dark_prob=1.0, with_glasses_prob=0.0, lazy_eye_prob=0.0)

    gen_glasses = SyntheticGazeDataGenerator(seed=101)
    test_glasses = gen_glasses.generate_dataset(1000, dark_prob=0.0, with_glasses_prob=1.0, lazy_eye_prob=0.0)

    gen_lazy = SyntheticGazeDataGenerator(seed=202)
    test_lazy = gen_lazy.generate_dataset(1000, dark_prob=0.0, with_glasses_prob=0.0, lazy_eye_prob=1.0)

    print(f" All data generated. Total time: {time.time()-t0:.1f}s")

    # ==========================================
    # Train Single-Eye Model
    # ==========================================
    print("\n[2/6] Training single-eye model...")
    single_model, single_history = train_single_eye_model(
        train_data, val_data,
        epochs=EPOCHS_SINGLE, batch_size=BATCH_SIZE,
        output_dir=os.path.join(OUTPUT_DIR, 'single_eye')
    )

    # ==========================================
    # Train Dual-Eye Model
    # ==========================================
    print("\n[3/6] Training dual-eye model...")
    dual_model, dual_history = train_dual_eye_model(
        train_data, val_data,
        epochs=EPOCHS_DUAL, batch_size=64,
        output_dir=os.path.join(OUTPUT_DIR, 'dual_eye')
    )

    # ==========================================
    # Convert to TFLite
    # ==========================================
    print("\n[4/6] Converting models to TFLite...")

    # Single eye - float16
    single_tflite_f16_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_single_f16.tflite')
    convert_to_tflite(single_model, single_tflite_f16_path, quantize=False)

    # Single eye - INT8 quantized
    single_tflite_int8_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_single_int8.tflite')
    all_eyes = np.concatenate([test_data['left_eye'], test_data['right_eye']], axis=0)
    convert_to_tflite(single_model, single_tflite_int8_path, quantize=True, test_data=all_eyes)

    # Dual eye - float16
    dual_tflite_f16_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_dual_f16.tflite')
    convert_dual_eye_to_tflite(dual_model, dual_tflite_f16_path, quantize=False)

    # Dual eye - INT8 quantized
    dual_tflite_int8_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_dual_int8.tflite')
    convert_dual_eye_to_tflite(dual_model, dual_tflite_int8_path, quantize=True, test_data=test_data)

    # ==========================================
    # Evaluate TFLite models
    # ==========================================
    print("\n[5/6] Evaluating TFLite models...")

    results = {}

    # Single eye evaluation: the first NUM_TEST rows of all_eyes are the
    # left-eye crops, so they line up one-to-one with test_data['gaze'].
    print("\n--- Single Eye Model (Float16) ---")
    results['single_f16'] = evaluate_tflite(
        single_tflite_f16_path, all_eyes[:3000], test_data['gaze']
    )

    print("\n--- Single Eye Model (INT8) ---")
    results['single_int8'] = evaluate_tflite(
        single_tflite_int8_path, all_eyes[:3000], test_data['gaze']
    )

    # Dual eye evaluation
    print("\n--- Dual Eye Model (Float16) ---")
    dual_inputs = [test_data['left_eye'], test_data['right_eye'], test_data['face']]
    results['dual_f16'] = evaluate_tflite(
        dual_tflite_f16_path, dual_inputs, test_data['gaze'], is_dual=True
    )

    print("\n--- Dual Eye Model (INT8) ---")
    results['dual_int8'] = evaluate_tflite(
        dual_tflite_int8_path, dual_inputs, test_data['gaze'], is_dual=True
    )

    # Condition-specific evaluation (dual model, float16)
    print("\n--- Robustness Evaluation (Dual Eye, Float16) ---")
    print("\n [Dark conditions]")
    dark_inputs = [test_dark['left_eye'], test_dark['right_eye'], test_dark['face']]
    results['dual_f16_dark'] = evaluate_tflite(
        dual_tflite_f16_path, dark_inputs, test_dark['gaze'], is_dual=True
    )

    print("\n [With glasses]")
    glasses_inputs = [test_glasses['left_eye'], test_glasses['right_eye'], test_glasses['face']]
    results['dual_f16_glasses'] = evaluate_tflite(
        dual_tflite_f16_path, glasses_inputs, test_glasses['gaze'], is_dual=True
    )

    print("\n [Lazy eye / strabismus]")
    lazy_inputs = [test_lazy['left_eye'], test_lazy['right_eye'], test_lazy['face']]
    results['dual_f16_lazy_eye'] = evaluate_tflite(
        dual_tflite_f16_path, lazy_inputs, test_lazy['gaze'], is_dual=True
    )

    # ==========================================
    # Save results and metadata
    # ==========================================
    print("\n[6/6] Saving results...")

    # Model card metadata
    metadata = {
        'model_name': 'GazeInception-Lite',
        'task': 'eye-gaze-estimation',
        'description': 'Lightweight TFLite model for mobile eye gaze estimation on phone screens',
        'architecture': {
            'type': 'Gated Inception Network with Coordinate Attention',
            'single_eye_params': int(single_model.count_params()),
            'dual_eye_params': int(dual_model.count_params()),
            'input_size': '64x64x3',
            'features': [
                'Gated Inception blocks (learned branch gating to skip useless compute)',
                'Coordinate Attention for spatial gaze awareness',
                'Depthwise separable convolutions for efficiency',
                'Dual-eye processing with shared weights (handles lazy eye)',
                'Face context branch (head pose proxy)'
            ]
        },
        'training': {
            'dataset': 'Synthetic (50K train, 5K val, 3K test)',
            'augmentations': [
                'Dark/low-light conditions (30% probability, 15-50% brightness)',
                'Glasses overlay synthesis (25% probability, 10 frame styles)',
                'Lazy eye/strabismus simulation (15% probability)',
                'CMOS sensor noise (50% probability)',
                'Illumination perturbation (directional light gradients)',
                'Diverse skin tones (12 variations)',
                'Diverse eye colors (7 variations)'
            ],
            'optimizer': 'Adam with Cosine Decay LR',
            'initial_lr': 1e-3,
            'loss': 'MSE',
            'epochs': f'{EPOCHS_SINGLE} (single) / {EPOCHS_DUAL} (dual)',
        },
        'tflite_models': {
            'single_eye_f16': {
                'file': 'gaze_inception_lite_single_f16.tflite',
                'size_kb': os.path.getsize(single_tflite_f16_path) / 1024,
                'quantization': 'float16',
            },
            'single_eye_int8': {
                'file': 'gaze_inception_lite_single_int8.tflite',
                'size_kb': os.path.getsize(single_tflite_int8_path) / 1024,
                'quantization': 'int8',
            },
            'dual_eye_f16': {
                'file': 'gaze_inception_lite_dual_f16.tflite',
                'size_kb': os.path.getsize(dual_tflite_f16_path) / 1024,
                'quantization': 'float16',
            },
            'dual_eye_int8': {
                'file': 'gaze_inception_lite_dual_int8.tflite',
                'size_kb': os.path.getsize(dual_tflite_int8_path) / 1024,
                'quantization': 'int8',
            },
        },
        'evaluation_results': results,
        'references': [
            'AGE Framework - arxiv:2603.26945',
            'Gated Compression Layers - arxiv:2303.08970',
            'iTracker / GazeCapture - arxiv:1606.05814',
            'Coordinate Attention - Hou et al. 2021',
            'MobileNetV2 - arxiv:1801.04381',
        ]
    }

    with open(os.path.join(OUTPUT_DIR, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Save training history
    for name, hist in [('single', single_history), ('dual', dual_history)]:
        hist_dict = {k: [float(v) for v in vals] for k, vals in hist.history.items()}
        with open(os.path.join(OUTPUT_DIR, f'{name}_history.json'), 'w') as f:
            json.dump(hist_dict, f, indent=2)

    # Print summary
    print("\n" + "="*60)
    print("TRAINING COMPLETE - SUMMARY")
    print("="*60)

    print("\nSingle-Eye Model:")
    print(f" Parameters: {single_model.count_params():,}")
    print(f" F16 TFLite: {os.path.getsize(single_tflite_f16_path)/1024:.1f} KB")
    print(f" INT8 TFLite: {os.path.getsize(single_tflite_int8_path)/1024:.1f} KB")
    if 'single_int8' in results:
        r = results['single_int8']
        print(f" Screen error: {r['screen_error_mm']:.1f} mm")
        print(f" Inference: {r['avg_inference_ms']:.2f} ms ({r['fps']:.0f} FPS)")

    print("\nDual-Eye Model:")
    print(f" Parameters: {dual_model.count_params():,}")
    print(f" F16 TFLite: {os.path.getsize(dual_tflite_f16_path)/1024:.1f} KB")
    print(f" INT8 TFLite: {os.path.getsize(dual_tflite_int8_path)/1024:.1f} KB")
    if 'dual_int8' in results:
        r = results['dual_int8']
        print(f" Screen error: {r['screen_error_mm']:.1f} mm")
        print(f" Inference: {r['avg_inference_ms']:.2f} ms ({r['fps']:.0f} FPS)")

    print("\nRobustness (Dual Eye):")
    for condition in ['dark', 'glasses', 'lazy_eye']:
        key = f'dual_f16_{condition}'
        if key in results:
            r = results[key]
            print(f" {condition}: {r['screen_error_mm']:.1f} mm error")

    print(f"\nOutput directory: {OUTPUT_DIR}")
    print(f"Files: {os.listdir(OUTPUT_DIR)}")


if __name__ == '__main__':
    main()