"""
Training script for GazeInception-Lite model.

Trains both:
1. Single-eye model (89K params) - ultralight for constrained devices
2. Dual-eye model (137K params) - better accuracy with face context + lazy eye handling

Features:
- Gated Inception blocks with learned branch gating
- Coordinate Attention for spatial gaze awareness
- Synthetic data with: dark conditions, glasses, lazy eye, sensor noise
- TFLite conversion to float16 and full-integer (INT8) quantized models
- Writes models, metadata and training histories to the output directory
  (uploading to the Hugging Face Hub is not performed by this script)

Based on:
- AGE framework (arxiv:2603.26945) - augmentation pipeline, multi-task approach
- Gated Compression Layers (arxiv:2303.08970) - gating mechanism
- iTracker (arxiv:1606.05814) - dual-eye + face architecture
"""

import os
import json
import time
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow import keras

# Import our modules
from model import build_gaze_inception_lite, build_dual_eye_model
from data_generator import SyntheticGazeDataGenerator, create_tf_dataset, create_single_eye_dataset


# Typical phone screen size used to express normalized gaze errors in millimetres.
SCREEN_W_MM = 65.0
SCREEN_H_MM = 140.0

# Maximum number of samples fed to the INT8 calibrator.
_MAX_CALIBRATION_SAMPLES = 200

# Name fragments used to map interpreter inputs back to (left_eye, right_eye, face).
_DUAL_INPUT_KEYWORDS = ('left', 'right', 'face')


# ==========================================
# Metrics
# ==========================================

def euclidean_distance_metric(y_true, y_pred):
    """Euclidean distance in normalized [0,1] coordinates.

    Returns the batch mean of the per-sample L2 distance over the last axis.
    """
    return tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(y_true - y_pred), axis=-1)))


def screen_error_mm(y_true, y_pred, screen_w_mm=SCREEN_W_MM, screen_h_mm=SCREEN_H_MM):
    """Error in mm assuming typical phone screen (65mm x 140mm).

    The normalized (x, y) difference is scaled by the screen width/height before
    taking the per-sample L2 distance; the batch mean is returned.
    """
    diff = y_true - y_pred
    diff_mm = diff * tf.constant([screen_w_mm, screen_h_mm])
    return tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(diff_mm), axis=-1)))


# ==========================================
# Training
# ==========================================

def _steps_per_epoch(num_examples, batch_size):
    """Optimizer steps per epoch; never 0 so CosineDecay always gets decay_steps > 0."""
    return max(1, num_examples // batch_size)


def _make_callbacks(output_dir):
    """Checkpoint + early-stopping callbacks shared by both training runs.

    ReduceLROnPlateau is deliberately absent: the optimizer's learning rate is a
    CosineDecay schedule, which a callback cannot overwrite (Keras 3 raises
    TypeError when ReduceLROnPlateau tries to lower it), and the schedule
    already anneals the learning rate.
    """
    return [
        keras.callbacks.ModelCheckpoint(
            os.path.join(output_dir, 'best_model.keras'),
            monitor='val_euclidean_distance_metric',
            save_best_only=True,
            mode='min',
            verbose=1,
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_euclidean_distance_metric',
            patience=20,
            restore_best_weights=True,
            verbose=1,
        ),
    ]


def _compile_and_fit(model, train_ds, val_ds, epochs, steps_per_epoch, output_dir):
    """Compile with Adam + cosine-decay LR and MSE loss, fit, save the final model.

    Returns the Keras History object from ``model.fit``.
    """
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-3,
        decay_steps=epochs * steps_per_epoch,
        alpha=1e-5,
    )
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss='mse',
        metrics=[euclidean_distance_metric],
    )
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=epochs,
        callbacks=_make_callbacks(output_dir),
        verbose=1,
    )
    model.save(os.path.join(output_dir, 'final_model.keras'))
    return history


def train_single_eye_model(train_data, val_data, epochs=100, batch_size=128,
                           output_dir='models/single_eye'):
    """Train the lightweight single-eye model.

    Args:
        train_data / val_data: dataset dicts from SyntheticGazeDataGenerator
            (only ``'gaze'`` is read directly here; the rest is consumed by
            ``create_single_eye_dataset``).
        epochs, batch_size: training hyper-parameters.
        output_dir: where ``best_model.keras`` and ``final_model.keras`` go.

    Returns:
        (model, history)
    """
    print("\n" + "=" * 60)
    print("Training Single-Eye GazeInception-Lite Model")
    print("=" * 60)

    os.makedirs(output_dir, exist_ok=True)

    model = build_gaze_inception_lite(input_shape=(64, 64, 3), num_outputs=2)

    train_ds = create_single_eye_dataset(train_data, batch_size=batch_size, shuffle=True)
    val_ds = create_single_eye_dataset(val_data, batch_size=batch_size, shuffle=False)

    # NOTE(review): the x2 assumes the single-eye dataset yields both eyes of each
    # sample (two examples per gaze label) — confirm against data_generator.
    steps = _steps_per_epoch(len(train_data['gaze']) * 2, batch_size)
    history = _compile_and_fit(model, train_ds, val_ds, epochs, steps, output_dir)
    return model, history


def train_dual_eye_model(train_data, val_data, epochs=100, batch_size=64,
                         output_dir='models/dual_eye'):
    """Train the dual-eye model with face context.

    Same arguments/return as ``train_single_eye_model``; the datasets are built
    with ``create_tf_dataset`` (left eye, right eye and face inputs).
    """
    print("\n" + "=" * 60)
    print("Training Dual-Eye GazeInception-Lite Model")
    print("=" * 60)

    os.makedirs(output_dir, exist_ok=True)

    model = build_dual_eye_model(
        eye_shape=(64, 64, 3),
        face_shape=(64, 64, 3),
        num_outputs=2,
    )

    train_ds = create_tf_dataset(train_data, batch_size=batch_size, shuffle=True)
    val_ds = create_tf_dataset(val_data, batch_size=batch_size, shuffle=False)

    steps = _steps_per_epoch(len(train_data['gaze']), batch_size)
    history = _compile_and_fit(model, train_ds, val_ds, epochs, steps, output_dir)
    return model, history


# ==========================================
# TFLite conversion
# ==========================================

def _calibration_indices(num_available, max_samples=_MAX_CALIBRATION_SAMPLES):
    """Evenly spaced indices over the whole array, so calibration sees all of it.

    (Taking only the first rows would, e.g., calibrate the single-eye model on
    left eyes only, since its calibration array is left eyes followed by right.)
    """
    count = min(max_samples, num_available)
    if count <= 0:
        return []
    return np.linspace(0, num_available - 1, num=count).astype(int).tolist()


def _use_float16(converter):
    """Configure float16 post-training quantization.

    Optimize.DEFAULT alone would produce a dynamic-range (int8-weight) model;
    float16 additionally requires ``supported_types = [tf.float16]``.
    """
    converter.target_spec.supported_types = [tf.float16]
    print("  Using float16 weight quantization")


def _write_tflite(converter, output_path):
    """Run the conversion, write the flatbuffer to ``output_path``, report its size."""
    tflite_model = converter.convert()

    out_dir = os.path.dirname(output_path)
    if out_dir:  # os.makedirs('') raises for a bare file name
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(tflite_model)

    size_kb = len(tflite_model) / 1024
    print(f"  TFLite model size: {size_kb:.1f} KB")
    return tflite_model


def convert_to_tflite(keras_model, output_path, quantize=True, test_data=None):
    """Convert Keras model to TFLite with optional quantization.

    - ``quantize=False``: float16 weights, float32 I/O.
    - ``quantize=True`` with ``test_data``: full INT8 (uint8 input, float32
      output), calibrated on up to 200 evenly spaced rows of ``test_data``.
    - ``quantize=True`` without ``test_data``: dynamic-range quantization.

    Returns the serialized TFLite model bytes.
    """
    print(f"\nConverting to TFLite: {output_path}")

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    # Optimization settings for mobile
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if not quantize:
        _use_float16(converter)
    elif test_data is not None:
        # Full integer quantization (fastest on mobile)
        calib_idx = _calibration_indices(len(test_data))

        def representative_dataset_gen():
            for i in calib_idx:
                yield [test_data[i:i + 1].astype(np.float32)]

        converter.representative_dataset = representative_dataset_gen
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
        ]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.float32
        print("  Using INT8 quantization for maximum mobile speed")
    else:
        print("  No representative data given; using dynamic-range quantization")

    return _write_tflite(converter, output_path)


def convert_dual_eye_to_tflite(keras_model, output_path, quantize=True, test_data=None):
    """Convert dual-eye model to TFLite.

    Same modes as ``convert_to_tflite``; ``test_data`` is a dataset dict whose
    ``'left_eye'``, ``'right_eye'`` and ``'face'`` arrays are fed to the
    calibrator in that order. The INT8 path allows float builtin ops as a
    fallback for ops without an int8 kernel.
    """
    print(f"\nConverting dual-eye model to TFLite: {output_path}")

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if not quantize:
        _use_float16(converter)
    elif test_data is not None:
        calib_idx = _calibration_indices(len(test_data['left_eye']))

        def representative_dataset_gen():
            for i in calib_idx:
                yield [
                    test_data['left_eye'][i:i + 1].astype(np.float32),
                    test_data['right_eye'][i:i + 1].astype(np.float32),
                    test_data['face'][i:i + 1].astype(np.float32),
                ]

        converter.representative_dataset = representative_dataset_gen
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
            tf.lite.OpsSet.TFLITE_BUILTINS,  # fallback for unsupported ops
        ]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.float32
        print("  Using INT8 quantization for maximum mobile speed")
    else:
        print("  No representative data given; using dynamic-range quantization")

    return _write_tflite(converter, output_path)


# ==========================================
# TFLite evaluation
# ==========================================

def _prepare_input(array, detail):
    """Cast/quantize a float batch for the interpreter input described by ``detail``.

    Integer inputs are quantized with the tensor's own (scale, zero_point),
    rounded and clipped to the dtype range instead of assuming ``x * 255``.
    """
    dtype = detail['dtype']
    if np.issubdtype(dtype, np.integer):
        scale, zero_point = detail['quantization']
        if not scale:
            # No quantization params recorded: fall back to the previous
            # assumption that float inputs lie in [0, 1] and map to [0, 255].
            scale, zero_point = 1.0 / 255.0, 0
        info = np.iinfo(dtype)
        quantized = np.round(array.astype(np.float32) / scale + zero_point)
        return np.clip(quantized, info.min, info.max).astype(dtype)
    return array.astype(np.float32)


def _read_output(interpreter, detail):
    """Fetch an output tensor, dequantizing it if the model emits integers."""
    out = interpreter.get_tensor(detail['index'])
    if np.issubdtype(out.dtype, np.integer):
        scale, zero_point = detail['quantization']
        if scale:
            return (out.astype(np.float32) - zero_point) * scale
    return out


def _order_dual_inputs(input_details):
    """Best-effort reorder of interpreter inputs into (left_eye, right_eye, face).

    The TFLite interpreter does not promise to keep the Keras input order, and
    all three inputs share one shape, so a positional mismatch would be silent.
    If each keyword matches exactly one distinct input name, that mapping is
    used; otherwise the interpreter's own order is kept (previous behaviour).
    NOTE(review): relies on the Keras input names containing these keywords —
    confirm against model.build_dual_eye_model.
    """
    ordered = []
    for keyword in _DUAL_INPUT_KEYWORDS:
        matches = [d for d in input_details if keyword in d['name'].lower()]
        if len(matches) != 1:
            return list(input_details)
        ordered.append(matches[0])
    distinct = {d['index'] for d in ordered}
    if len(distinct) != len(ordered) or len(ordered) != len(input_details):
        return list(input_details)
    return ordered


def evaluate_tflite(tflite_path, test_inputs, test_labels, is_dual=False):
    """Evaluate TFLite model accuracy and inference speed.

    Args:
        tflite_path: path of the .tflite file.
        test_inputs: a float array (single-eye) or a list
            ``[left_eye, right_eye, face]`` (``is_dual=True``).
        test_labels: normalized gaze targets aligned with ``test_inputs``.
        is_dual: whether the model has the three dual-eye inputs.

    Uses at most the first 500 samples. Timing covers set_tensor/invoke/
    get_tensor only; input casting/quantization is done beforehand.

    Returns:
        dict of error metrics (normalized, mm, cm) and CPU latency/FPS.

    Raises:
        ValueError: if ``test_labels`` is empty.
    """
    print(f"\nEvaluating TFLite model: {tflite_path}")

    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print(f"  Input details: {[(d['name'], d['shape'], d['dtype']) for d in input_details]}")
    print(f"  Output details: {[(d['name'], d['shape'], d['dtype']) for d in output_details]}")

    num_samples = min(500, len(test_labels))
    if num_samples == 0:
        raise ValueError("evaluate_tflite needs at least one labelled test sample")

    if is_dual:
        bindings = list(zip(_order_dual_inputs(input_details), test_inputs))
    else:
        bindings = [(input_details[0], test_inputs)]

    # Pre-build every feed so the benchmark measures the interpreter, not numpy.
    feeds = [
        [(detail['index'], _prepare_input(arr[i:i + 1], detail)) for detail, arr in bindings]
        for i in range(num_samples)
    ]

    def run(feed):
        for tensor_index, data in feed:
            interpreter.set_tensor(tensor_index, data)
        interpreter.invoke()

    # Warmup
    for _ in range(5):
        run(feeds[0])

    # Benchmark
    predictions = []
    start_time = time.perf_counter()
    for feed in feeds:
        run(feed)
        predictions.append(_read_output(interpreter, output_details[0])[0])
    elapsed = time.perf_counter() - start_time

    predictions = np.asarray(predictions, dtype=np.float64)
    labels = np.asarray(test_labels[:num_samples], dtype=np.float64)
    diff = predictions - labels

    # Metrics
    eucl_error = np.mean(np.sqrt(np.sum(diff ** 2, axis=-1)))
    # Screen error in mm (typical phone: 65mm x 140mm)
    diff_mm = diff * np.array([SCREEN_W_MM, SCREEN_H_MM])
    screen_error = np.mean(np.sqrt(np.sum(diff_mm ** 2, axis=-1)))
    screen_error_cm = screen_error / 10.0

    avg_inference_ms = (elapsed / num_samples) * 1000
    fps = 1000 / avg_inference_ms if avg_inference_ms > 0 else float('inf')

    print(f"  Euclidean error (normalized): {eucl_error:.4f}")
    print(f"  Screen error: {screen_error:.1f} mm ({screen_error_cm:.2f} cm)")
    print(f"  Average inference time (CPU): {avg_inference_ms:.2f} ms")
    print(f"  FPS (CPU): {fps:.1f}")

    return {
        'euclidean_error': float(eucl_error),
        'screen_error_mm': float(screen_error),
        'screen_error_cm': float(screen_error_cm),
        'avg_inference_ms': float(avg_inference_ms),
        'fps': float(fps),
        'num_test_samples': num_samples,
    }


# ==========================================
# Pipeline
# ==========================================

def _dual_inputs(data):
    """Dual-eye model inputs in (left_eye, right_eye, face) order."""
    return [data['left_eye'], data['right_eye'], data['face']]


def _size_kb(path):
    """File size in KiB."""
    return os.path.getsize(path) / 1024


def main():
    """Generate data, train both models, convert/evaluate TFLite, save metadata."""
    print("=" * 60)
    print("GazeInception-Lite: Mobile Eye Gaze Estimation")
    print("=" * 60)

    # Configuration
    NUM_TRAIN = 50000
    NUM_VAL = 5000
    NUM_TEST = 3000
    EPOCHS_SINGLE = 80
    EPOCHS_DUAL = 80
    BATCH_SIZE = 128
    BATCH_SIZE_DUAL = 64
    OUTPUT_DIR = '/app/output'
    # Augmentation mix shared by the train/val/test splits.
    AUG_MIX = dict(dark_prob=0.30, with_glasses_prob=0.25, lazy_eye_prob=0.15)

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # ==========================================
    # Generate synthetic training data
    # ==========================================
    print("\n[1/6] Generating synthetic training data...")
    print(f"  Train: {NUM_TRAIN}, Val: {NUM_VAL}, Test: {NUM_TEST}")
    print(f"  Augmentations: dark(30%), glasses(25%), lazy_eye(15%), noise(50%)")

    gen = SyntheticGazeDataGenerator(seed=42)
    t0 = time.time()
    train_data = gen.generate_dataset(NUM_TRAIN, **AUG_MIX)
    print(f"  Train data generated in {time.time() - t0:.1f}s")

    val_data = SyntheticGazeDataGenerator(seed=123).generate_dataset(NUM_VAL, **AUG_MIX)
    test_data = SyntheticGazeDataGenerator(seed=456).generate_dataset(NUM_TEST, **AUG_MIX)

    # Also generate condition-specific test sets for robustness evaluation
    test_dark = SyntheticGazeDataGenerator(seed=789).generate_dataset(
        1000, dark_prob=1.0, with_glasses_prob=0.0, lazy_eye_prob=0.0)
    test_glasses = SyntheticGazeDataGenerator(seed=101).generate_dataset(
        1000, dark_prob=0.0, with_glasses_prob=1.0, lazy_eye_prob=0.0)
    test_lazy = SyntheticGazeDataGenerator(seed=202).generate_dataset(
        1000, dark_prob=0.0, with_glasses_prob=0.0, lazy_eye_prob=1.0)

    print(f"  All data generated. Total time: {time.time() - t0:.1f}s")

    # ==========================================
    # Train Single-Eye Model
    # ==========================================
    print("\n[2/6] Training single-eye model...")
    single_model, single_history = train_single_eye_model(
        train_data, val_data,
        epochs=EPOCHS_SINGLE,
        batch_size=BATCH_SIZE,
        output_dir=os.path.join(OUTPUT_DIR, 'single_eye'),
    )

    # ==========================================
    # Train Dual-Eye Model
    # ==========================================
    print("\n[3/6] Training dual-eye model...")
    dual_model, dual_history = train_dual_eye_model(
        train_data, val_data,
        epochs=EPOCHS_DUAL,
        batch_size=BATCH_SIZE_DUAL,
        output_dir=os.path.join(OUTPUT_DIR, 'dual_eye'),
    )

    # ==========================================
    # Convert to TFLite
    # ==========================================
    print("\n[4/6] Converting models to TFLite...")

    # Single eye - float16
    single_tflite_f16_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_single_f16.tflite')
    convert_to_tflite(single_model, single_tflite_f16_path, quantize=False)

    # Single eye - INT8 quantized (calibrated on both left and right eyes)
    single_tflite_int8_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_single_int8.tflite')
    all_eyes = np.concatenate([test_data['left_eye'], test_data['right_eye']], axis=0)
    convert_to_tflite(single_model, single_tflite_int8_path, quantize=True, test_data=all_eyes)

    # Dual eye - float16
    dual_tflite_f16_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_dual_f16.tflite')
    convert_dual_eye_to_tflite(dual_model, dual_tflite_f16_path, quantize=False)

    # Dual eye - INT8 quantized
    dual_tflite_int8_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_dual_int8.tflite')
    convert_dual_eye_to_tflite(dual_model, dual_tflite_int8_path, quantize=True, test_data=test_data)

    # ==========================================
    # Evaluate TFLite models
    # ==========================================
    print("\n[5/6] Evaluating TFLite models...")
    results = {}

    # The leading rows of all_eyes are the left eyes, aligned 1:1 with the labels.
    single_eval_inputs = all_eyes[:len(test_data['gaze'])]

    print("\n--- Single Eye Model (Float16) ---")
    results['single_f16'] = evaluate_tflite(
        single_tflite_f16_path, single_eval_inputs, test_data['gaze'])

    print("\n--- Single Eye Model (INT8) ---")
    results['single_int8'] = evaluate_tflite(
        single_tflite_int8_path, single_eval_inputs, test_data['gaze'])

    dual_inputs = _dual_inputs(test_data)

    print("\n--- Dual Eye Model (Float16) ---")
    results['dual_f16'] = evaluate_tflite(
        dual_tflite_f16_path, dual_inputs, test_data['gaze'], is_dual=True)

    print("\n--- Dual Eye Model (INT8) ---")
    results['dual_int8'] = evaluate_tflite(
        dual_tflite_int8_path, dual_inputs, test_data['gaze'], is_dual=True)

    # Condition-specific evaluation (dual model, float16)
    print("\n--- Robustness Evaluation (Dual Eye, Float16) ---")
    robustness_sets = (
        ("Dark conditions", 'dark', test_dark),
        ("With glasses", 'glasses', test_glasses),
        ("Lazy eye / strabismus", 'lazy_eye', test_lazy),
    )
    for title, condition, data in robustness_sets:
        print(f"\n  [{title}]")
        results[f'dual_f16_{condition}'] = evaluate_tflite(
            dual_tflite_f16_path, _dual_inputs(data), data['gaze'], is_dual=True)

    # ==========================================
    # Save results and metadata
    # ==========================================
    print("\n[6/6] Saving results...")

    # Model card metadata
    metadata = {
        'model_name': 'GazeInception-Lite',
        'task': 'eye-gaze-estimation',
        'description': 'Lightweight TFLite model for mobile eye gaze estimation on phone screens',
        'architecture': {
            'type': 'Gated Inception Network with Coordinate Attention',
            'single_eye_params': int(single_model.count_params()),
            'dual_eye_params': int(dual_model.count_params()),
            'input_size': '64x64x3',
            'features': [
                'Gated Inception blocks (learned branch gating to skip useless compute)',
                'Coordinate Attention for spatial gaze awareness',
                'Depthwise separable convolutions for efficiency',
                'Dual-eye processing with shared weights (handles lazy eye)',
                'Face context branch (head pose proxy)',
            ],
        },
        'training': {
            'dataset': (f'Synthetic ({NUM_TRAIN // 1000}K train, '
                        f'{NUM_VAL // 1000}K val, {NUM_TEST // 1000}K test)'),
            'augmentations': [
                'Dark/low-light conditions (30% probability, 15-50% brightness)',
                'Glasses overlay synthesis (25% probability, 10 frame styles)',
                'Lazy eye/strabismus simulation (15% probability)',
                'CMOS sensor noise (50% probability)',
                'Illumination perturbation (directional light gradients)',
                'Diverse skin tones (12 variations)',
                'Diverse eye colors (7 variations)',
            ],
            'optimizer': 'Adam with Cosine Decay LR',
            'initial_lr': 1e-3,
            'loss': 'MSE',
            'epochs': f'{EPOCHS_SINGLE} (single) / {EPOCHS_DUAL} (dual)',
        },
        'tflite_models': {
            'single_eye_f16': {
                'file': 'gaze_inception_lite_single_f16.tflite',
                'size_kb': _size_kb(single_tflite_f16_path),
                'quantization': 'float16',
            },
            'single_eye_int8': {
                'file': 'gaze_inception_lite_single_int8.tflite',
                'size_kb': _size_kb(single_tflite_int8_path),
                'quantization': 'int8',
            },
            'dual_eye_f16': {
                'file': 'gaze_inception_lite_dual_f16.tflite',
                'size_kb': _size_kb(dual_tflite_f16_path),
                'quantization': 'float16',
            },
            'dual_eye_int8': {
                'file': 'gaze_inception_lite_dual_int8.tflite',
                'size_kb': _size_kb(dual_tflite_int8_path),
                'quantization': 'int8',
            },
        },
        'evaluation_results': results,
        'references': [
            'AGE Framework - arxiv:2603.26945',
            'Gated Compression Layers - arxiv:2303.08970',
            'iTracker / GazeCapture - arxiv:1606.05814',
            'Coordinate Attention - Hou et al. 2021',
            'MobileNetV2 - arxiv:1801.04381',
        ],
    }

    with open(os.path.join(OUTPUT_DIR, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Save training history
    for name, hist in (('single', single_history), ('dual', dual_history)):
        hist_dict = {k: [float(v) for v in vals] for k, vals in hist.history.items()}
        with open(os.path.join(OUTPUT_DIR, f'{name}_history.json'), 'w') as f:
            json.dump(hist_dict, f, indent=2)

    # Print summary
    print("\n" + "=" * 60)
    print("TRAINING COMPLETE - SUMMARY")
    print("=" * 60)

    print(f"\nSingle-Eye Model:")
    print(f"  Parameters: {single_model.count_params():,}")
    print(f"  F16 TFLite: {_size_kb(single_tflite_f16_path):.1f} KB")
    print(f"  INT8 TFLite: {_size_kb(single_tflite_int8_path):.1f} KB")
    if 'single_int8' in results:
        r = results['single_int8']
        print(f"  Screen error: {r['screen_error_mm']:.1f} mm")
        print(f"  Inference: {r['avg_inference_ms']:.2f} ms ({r['fps']:.0f} FPS)")

    print(f"\nDual-Eye Model:")
    print(f"  Parameters: {dual_model.count_params():,}")
    print(f"  F16 TFLite: {_size_kb(dual_tflite_f16_path):.1f} KB")
    print(f"  INT8 TFLite: {_size_kb(dual_tflite_int8_path):.1f} KB")
    if 'dual_int8' in results:
        r = results['dual_int8']
        print(f"  Screen error: {r['screen_error_mm']:.1f} mm")
        print(f"  Inference: {r['avg_inference_ms']:.2f} ms ({r['fps']:.0f} FPS)")

    print(f"\nRobustness (Dual Eye):")
    for condition in ('dark', 'glasses', 'lazy_eye'):
        key = f'dual_f16_{condition}'
        if key in results:
            r = results[key]
            print(f"  {condition}: {r['screen_error_mm']:.1f} mm error")

    print(f"\nOutput directory: {OUTPUT_DIR}")
    print(f"Files: {os.listdir(OUTPUT_DIR)}")


if __name__ == '__main__':
    main()