| """ |
| Training script for GazeInception-Lite model. |
| |
| Trains both: |
| 1. Single-eye model (89K params) - ultralight for constrained devices |
| 2. Dual-eye model (137K params) - better accuracy with face context + lazy eye handling |
| |
| Features: |
| - Gated Inception blocks with learned branch gating |
| - Coordinate Attention for spatial gaze awareness |
| - Synthetic data with: dark conditions, glasses, lazy eye, sensor noise |
| - TFLite conversion with full integer quantization |
| - Push to Hugging Face Hub |
| |
| Based on: |
| - AGE framework (arxiv:2603.26945) - augmentation pipeline, multi-task approach |
| - Gated Compression Layers (arxiv:2303.08970) - gating mechanism |
| - iTracker (arxiv:1606.05814) - dual-eye + face architecture |
| """ |
|
|
| import os |
| import json |
| import time |
| import numpy as np |
| import tensorflow as tf |
| from tensorflow import keras |
| from pathlib import Path |
|
|
| |
| from model import build_gaze_inception_lite, build_dual_eye_model |
| from data_generator import SyntheticGazeDataGenerator, create_tf_dataset, create_single_eye_dataset |
|
|
|
|
def euclidean_distance_metric(y_true, y_pred):
    """Mean Euclidean distance between targets and predictions, in
    normalized [0,1] screen coordinates (used as a Keras metric)."""
    # tf.norm(..., axis=-1) == sqrt(sum(square(diff))) per sample.
    per_sample = tf.norm(y_true - y_pred, axis=-1)
    return tf.reduce_mean(per_sample)
|
|
|
|
def screen_error_mm(y_true, y_pred, screen_w_mm=65.0, screen_h_mm=140.0):
    """Mean gaze error in millimetres, assuming a typical phone screen.

    The normalized (x, y) error is scaled by the physical screen size
    (default 65mm x 140mm) before taking the Euclidean norm.
    """
    scale = tf.constant([screen_w_mm, screen_h_mm])
    error_mm = tf.norm((y_true - y_pred) * scale, axis=-1)
    return tf.reduce_mean(error_mm)
|
|
|
|
def train_single_eye_model(train_data, val_data, epochs=100, batch_size=128,
                           output_dir='models/single_eye'):
    """Train the lightweight single-eye model.

    Args:
        train_data: dict from SyntheticGazeDataGenerator; 'gaze' holds the
            (x, y) labels (schema assumed from data_generator -- TODO confirm).
        val_data: same structure as train_data, used for validation.
        epochs: maximum number of epochs (early stopping may end sooner).
        batch_size: minibatch size for training and validation.
        output_dir: directory for best/final .keras checkpoints.

    Returns:
        (model, history): the trained Keras model (best weights restored by
        EarlyStopping when it triggers) and the History from fit().
    """
    print("\n" + "=" * 60)
    print("Training Single-Eye GazeInception-Lite Model")
    print("=" * 60)

    os.makedirs(output_dir, exist_ok=True)

    model = build_gaze_inception_lite(input_shape=(64, 64, 3), num_outputs=2)

    # The single-eye dataset yields two eye crops per sample, hence the *2
    # when estimating steps-per-epoch for the cosine schedule.
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-3,
        decay_steps=epochs * (len(train_data['gaze']) * 2 // batch_size),
        alpha=1e-5
    )

    optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=[euclidean_distance_metric]
    )

    train_ds = create_single_eye_dataset(train_data, batch_size=batch_size, shuffle=True)
    val_ds = create_single_eye_dataset(val_data, batch_size=batch_size, shuffle=False)

    # FIX: ReduceLROnPlateau was removed. It is incompatible with an optimizer
    # whose learning rate is a LearningRateSchedule (CosineDecay): the callback
    # tries to read/assign a scalar LR every epoch and raises. The cosine
    # schedule already provides monotonic LR decay.
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            os.path.join(output_dir, 'best_model.keras'),
            monitor='val_euclidean_distance_metric',
            save_best_only=True, mode='min', verbose=1
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_euclidean_distance_metric', patience=20,
            restore_best_weights=True, verbose=1
        ),
    ]

    history = model.fit(
        train_ds, validation_data=val_ds,
        epochs=epochs, callbacks=callbacks, verbose=1
    )

    # Save last-epoch weights as well; 'best_model.keras' holds the best ones.
    model.save(os.path.join(output_dir, 'final_model.keras'))

    return model, history
|
|
|
|
def train_dual_eye_model(train_data, val_data, epochs=100, batch_size=64,
                         output_dir='models/dual_eye'):
    """Train the dual-eye model with face context.

    Args:
        train_data: dict from SyntheticGazeDataGenerator with left_eye,
            right_eye, face and 'gaze' labels (schema assumed -- TODO confirm).
        val_data: same structure, used for validation.
        epochs: maximum number of epochs (early stopping may end sooner).
        batch_size: minibatch size for training and validation.
        output_dir: directory for best/final .keras checkpoints.

    Returns:
        (model, history): the trained Keras model (best weights restored by
        EarlyStopping when it triggers) and the History from fit().
    """
    print("\n" + "=" * 60)
    print("Training Dual-Eye GazeInception-Lite Model")
    print("=" * 60)

    os.makedirs(output_dir, exist_ok=True)

    model = build_dual_eye_model(
        eye_shape=(64, 64, 3), face_shape=(64, 64, 3), num_outputs=2
    )

    # One dataset element per sample here (no eye doubling), so steps-per-epoch
    # is just len / batch_size.
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-3,
        decay_steps=epochs * (len(train_data['gaze']) // batch_size),
        alpha=1e-5
    )

    optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=[euclidean_distance_metric]
    )

    train_ds = create_tf_dataset(train_data, batch_size=batch_size, shuffle=True)
    val_ds = create_tf_dataset(val_data, batch_size=batch_size, shuffle=False)

    # FIX: ReduceLROnPlateau was removed. It is incompatible with an optimizer
    # whose learning rate is a LearningRateSchedule (CosineDecay): the callback
    # tries to read/assign a scalar LR every epoch and raises. The cosine
    # schedule already provides monotonic LR decay.
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            os.path.join(output_dir, 'best_model.keras'),
            monitor='val_euclidean_distance_metric',
            save_best_only=True, mode='min', verbose=1
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_euclidean_distance_metric', patience=20,
            restore_best_weights=True, verbose=1
        ),
    ]

    history = model.fit(
        train_ds, validation_data=val_ds,
        epochs=epochs, callbacks=callbacks, verbose=1
    )

    # Save last-epoch weights as well; 'best_model.keras' holds the best ones.
    model.save(os.path.join(output_dir, 'final_model.keras'))

    return model, history
|
|
|
|
def convert_to_tflite(keras_model, output_path, quantize=True, test_data=None):
    """Convert a Keras model to TFLite with optional full-integer quantization.

    Args:
        keras_model: trained Keras model to convert.
        output_path: destination .tflite path (parent dirs are created).
        quantize: if True (and test_data is given), emit a full INT8 model
            with uint8 inputs; otherwise emit a float16-weight model.
        test_data: float array in the model's native input range, used as the
            representative dataset for INT8 activation calibration.

    Returns:
        The serialized TFLite flatbuffer (bytes).
    """
    print(f"\nConverting to TFLite: {output_path}")

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if quantize and test_data is not None:
        # Calibrate activation ranges on up to 200 representative samples.
        def representative_dataset_gen():
            for i in range(min(200, len(test_data))):
                yield [test_data[i:i + 1].astype(np.float32)]

        converter.representative_dataset = representative_dataset_gen
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.float32
        print("  Using INT8 quantization for maximum mobile speed")
    else:
        # FIX: Optimize.DEFAULT alone (no representative dataset) produces
        # dynamic-range int8-weight quantization, not the float16 model the
        # callers' "*_f16" filenames advertise. Restricting supported_types
        # to float16 yields genuine float16 weight quantization.
        converter.target_spec.supported_types = [tf.float16]
        print("  Using float16 weight quantization")

    tflite_model = converter.convert()

    # Guard: os.makedirs('') raises when output_path has no directory part.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(tflite_model)

    size_kb = len(tflite_model) / 1024
    print(f"  TFLite model size: {size_kb:.1f} KB")

    return tflite_model
|
|
|
|
def convert_dual_eye_to_tflite(keras_model, output_path, quantize=True, test_data=None):
    """Convert the three-input dual-eye model to TFLite.

    Args:
        keras_model: trained dual-eye Keras model (left_eye, right_eye, face).
        output_path: destination .tflite path (parent dirs are created).
        quantize: if True (and test_data is given), emit a full INT8 model;
            otherwise emit a float16-weight model.
        test_data: dict with 'left_eye', 'right_eye', 'face' float arrays used
            for INT8 calibration. NOTE(review): the yield order must match the
            model's declared input order -- confirm against build_dual_eye_model.

    Returns:
        The serialized TFLite flatbuffer (bytes).
    """
    print(f"\nConverting dual-eye model to TFLite: {output_path}")

    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if quantize and test_data is not None:
        # Calibrate activation ranges on up to 200 representative samples.
        def representative_dataset_gen():
            for i in range(min(200, len(test_data['left_eye']))):
                yield [
                    test_data['left_eye'][i:i + 1].astype(np.float32),
                    test_data['right_eye'][i:i + 1].astype(np.float32),
                    test_data['face'][i:i + 1].astype(np.float32),
                ]

        converter.representative_dataset = representative_dataset_gen
        # TFLITE_BUILTINS kept as fallback for any op without an int8 kernel.
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
            tf.lite.OpsSet.TFLITE_BUILTINS,
        ]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.float32
        print("  Using INT8 quantization for maximum mobile speed")
    else:
        # FIX: Optimize.DEFAULT alone (no representative dataset) produces
        # dynamic-range int8-weight quantization, not the float16 model the
        # callers' "*_f16" filenames advertise. Restricting supported_types
        # to float16 yields genuine float16 weight quantization.
        converter.target_spec.supported_types = [tf.float16]
        print("  Using float16 weight quantization")

    tflite_model = converter.convert()

    # Guard: os.makedirs('') raises when output_path has no directory part.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(tflite_model)

    size_kb = len(tflite_model) / 1024
    print(f"  TFLite model size: {size_kb:.1f} KB")

    return tflite_model
|
|
|
|
def _prepare_tflite_input(detail, x):
    """Cast/quantize a float batch to what a TFLite input tensor expects.

    FIX: uses the tensor's own (scale, zero_point) quantization parameters
    instead of a hard-coded *255 mapping, which is only correct when the
    calibrated input scale happens to be exactly 1/255 with zero_point 0.
    """
    if detail['dtype'] == np.uint8:
        scale, zero_point = detail['quantization']
        if scale == 0:
            # No quantization params recorded; fall back to the naive mapping.
            return (x * 255).astype(np.uint8)
        return np.clip(np.round(x / scale + zero_point), 0, 255).astype(np.uint8)
    return x.astype(np.float32)


def evaluate_tflite(tflite_path, test_inputs, test_labels, is_dual=False):
    """Evaluate a TFLite model's accuracy and CPU inference speed.

    Args:
        tflite_path: path to the .tflite file.
        test_inputs: a single float array (single-eye model), or a list of
            three arrays [left_eye, right_eye, face] when is_dual=True.
            NOTE(review): the dual list is matched positionally against the
            interpreter's input order -- confirm the orders agree.
        test_labels: (N, 2) normalized gaze targets in [0, 1].
        is_dual: whether the model takes the three dual-eye inputs.

    Returns:
        dict with normalized Euclidean error, screen error in mm/cm
        (assuming a 65mm x 140mm phone screen), per-sample CPU latency,
        FPS, and the number of evaluated samples.
    """
    print(f"\nEvaluating TFLite model: {tflite_path}")

    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print(f"  Input details: {[(d['name'], d['shape'], d['dtype']) for d in input_details]}")
    print(f"  Output details: {[(d['name'], d['shape'], d['dtype']) for d in output_details]}")

    def set_inputs(i):
        # Feed sample i to every input tensor, quantizing as required.
        if is_dual:
            for idx, detail in enumerate(input_details):
                interpreter.set_tensor(
                    detail['index'],
                    _prepare_tflite_input(detail, test_inputs[idx][i:i + 1])
                )
        else:
            interpreter.set_tensor(
                input_details[0]['index'],
                _prepare_tflite_input(input_details[0], test_inputs[i:i + 1])
            )

    predictions = []
    num_samples = min(500, len(test_labels))

    # Warmup runs so one-time allocation/JIT cost doesn't skew the timing.
    for _ in range(5):
        set_inputs(0)
        interpreter.invoke()

    # Timed loop (input prep is intentionally included, as on-device).
    start_time = time.time()
    for i in range(num_samples):
        set_inputs(i)
        interpreter.invoke()
        predictions.append(interpreter.get_tensor(output_details[0]['index'])[0])
    elapsed = time.time() - start_time

    predictions = np.array(predictions)
    labels = test_labels[:num_samples]

    # Mean Euclidean error in normalized [0,1] screen coordinates.
    eucl_error = np.mean(np.sqrt(np.sum((predictions - labels) ** 2, axis=-1)))

    # Same error projected onto a typical 65mm x 140mm phone screen.
    diff_mm = (predictions - labels) * np.array([65.0, 140.0])
    screen_error = np.mean(np.sqrt(np.sum(diff_mm ** 2, axis=-1)))
    screen_error_cm = screen_error / 10.0

    avg_inference_ms = (elapsed / num_samples) * 1000

    print(f"  Euclidean error (normalized): {eucl_error:.4f}")
    print(f"  Screen error: {screen_error:.1f} mm ({screen_error_cm:.2f} cm)")
    print(f"  Average inference time (CPU): {avg_inference_ms:.2f} ms")
    print(f"  FPS (CPU): {1000 / avg_inference_ms:.1f}")

    return {
        'euclidean_error': float(eucl_error),
        'screen_error_mm': float(screen_error),
        'screen_error_cm': float(screen_error_cm),
        'avg_inference_ms': float(avg_inference_ms),
        'fps': float(1000 / avg_inference_ms),
        'num_test_samples': num_samples,
    }
|
|
|
|
def main():
    """End-to-end pipeline: generate synthetic data, train both models,
    convert to TFLite, evaluate (including robustness splits), and write
    metadata + training histories to the output directory."""
    print("="*60)
    print("GazeInception-Lite: Mobile Eye Gaze Estimation")
    print("="*60)

    # ---- Configuration ---------------------------------------------------
    NUM_TRAIN = 50000
    NUM_VAL = 5000
    NUM_TEST = 3000
    EPOCHS_SINGLE = 80
    EPOCHS_DUAL = 80
    BATCH_SIZE = 128

    # All artifacts (models, tflite files, metadata, histories) go here.
    OUTPUT_DIR = '/app/output'
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # ---- [1/6] Synthetic data -------------------------------------------
    print("\n[1/6] Generating synthetic training data...")
    print(f"  Train: {NUM_TRAIN}, Val: {NUM_VAL}, Test: {NUM_TEST}")
    print(f"  Augmentations: dark(30%), glasses(25%), lazy_eye(15%), noise(50%)")

    gen = SyntheticGazeDataGenerator(seed=42)

    t0 = time.time()
    train_data = gen.generate_dataset(NUM_TRAIN, dark_prob=0.30, with_glasses_prob=0.25, lazy_eye_prob=0.15)
    print(f"  Train data generated in {time.time()-t0:.1f}s")

    # Distinct seeds keep val/test draws independent of the training draw.
    gen_val = SyntheticGazeDataGenerator(seed=123)
    val_data = gen_val.generate_dataset(NUM_VAL, dark_prob=0.30, with_glasses_prob=0.25, lazy_eye_prob=0.15)

    gen_test = SyntheticGazeDataGenerator(seed=456)
    test_data = gen_test.generate_dataset(NUM_TEST, dark_prob=0.30, with_glasses_prob=0.25, lazy_eye_prob=0.15)

    # Condition-specific splits (1000 samples each) for robustness evaluation:
    # each forces exactly one augmentation to 100% and disables the others.
    gen_dark = SyntheticGazeDataGenerator(seed=789)
    test_dark = gen_dark.generate_dataset(1000, dark_prob=1.0, with_glasses_prob=0.0, lazy_eye_prob=0.0)

    gen_glasses = SyntheticGazeDataGenerator(seed=101)
    test_glasses = gen_glasses.generate_dataset(1000, dark_prob=0.0, with_glasses_prob=1.0, lazy_eye_prob=0.0)

    gen_lazy = SyntheticGazeDataGenerator(seed=202)
    test_lazy = gen_lazy.generate_dataset(1000, dark_prob=0.0, with_glasses_prob=0.0, lazy_eye_prob=1.0)

    print(f"  All data generated. Total time: {time.time()-t0:.1f}s")

    # ---- [2/6] Train single-eye model -----------------------------------
    print("\n[2/6] Training single-eye model...")
    single_model, single_history = train_single_eye_model(
        train_data, val_data,
        epochs=EPOCHS_SINGLE, batch_size=BATCH_SIZE,
        output_dir=os.path.join(OUTPUT_DIR, 'single_eye')
    )

    # ---- [3/6] Train dual-eye model -------------------------------------
    # Smaller batch size (64) than the single-eye model; presumably for
    # memory headroom with three input streams -- TODO confirm.
    print("\n[3/6] Training dual-eye model...")
    dual_model, dual_history = train_dual_eye_model(
        train_data, val_data,
        epochs=EPOCHS_DUAL, batch_size=64,
        output_dir=os.path.join(OUTPUT_DIR, 'dual_eye')
    )

    # ---- [4/6] TFLite conversion ----------------------------------------
    print("\n[4/6] Converting models to TFLite...")

    # Single-eye, no INT8 calibration (file labeled "f16").
    single_tflite_f16_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_single_f16.tflite')
    convert_to_tflite(single_model, single_tflite_f16_path, quantize=False)

    # Single-eye INT8: calibrate on both eyes concatenated (left then right).
    single_tflite_int8_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_single_int8.tflite')
    all_eyes = np.concatenate([test_data['left_eye'], test_data['right_eye']], axis=0)
    convert_to_tflite(single_model, single_tflite_int8_path, quantize=True, test_data=all_eyes)

    # Dual-eye, no INT8 calibration (file labeled "f16").
    dual_tflite_f16_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_dual_f16.tflite')
    convert_dual_eye_to_tflite(dual_model, dual_tflite_f16_path, quantize=False)

    # Dual-eye INT8, calibrated on the full test dict.
    dual_tflite_int8_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_dual_int8.tflite')
    convert_dual_eye_to_tflite(dual_model, dual_tflite_int8_path, quantize=True, test_data=test_data)

    # ---- [5/6] Evaluation ------------------------------------------------
    print("\n[5/6] Evaluating TFLite models...")

    results = {}

    # all_eyes[:3000] is exactly the left-eye half of the concatenation, so
    # it stays aligned with test_data['gaze'] (3000 labels).
    print("\n--- Single Eye Model (Float16) ---")
    results['single_f16'] = evaluate_tflite(
        single_tflite_f16_path, all_eyes[:3000], test_data['gaze']
    )

    print("\n--- Single Eye Model (INT8) ---")
    results['single_int8'] = evaluate_tflite(
        single_tflite_int8_path, all_eyes[:3000], test_data['gaze']
    )

    # Dual-eye inputs are passed positionally: [left_eye, right_eye, face].
    print("\n--- Dual Eye Model (Float16) ---")
    dual_inputs = [test_data['left_eye'], test_data['right_eye'], test_data['face']]
    results['dual_f16'] = evaluate_tflite(
        dual_tflite_f16_path, dual_inputs, test_data['gaze'], is_dual=True
    )

    print("\n--- Dual Eye Model (INT8) ---")
    results['dual_int8'] = evaluate_tflite(
        dual_tflite_int8_path, dual_inputs, test_data['gaze'], is_dual=True
    )

    # Robustness: run the dual-eye f16 model on each single-condition split.
    print("\n--- Robustness Evaluation (Dual Eye, Float16) ---")
    print("\n  [Dark conditions]")
    dark_inputs = [test_dark['left_eye'], test_dark['right_eye'], test_dark['face']]
    results['dual_f16_dark'] = evaluate_tflite(
        dual_tflite_f16_path, dark_inputs, test_dark['gaze'], is_dual=True
    )

    print("\n  [With glasses]")
    glasses_inputs = [test_glasses['left_eye'], test_glasses['right_eye'], test_glasses['face']]
    results['dual_f16_glasses'] = evaluate_tflite(
        dual_tflite_f16_path, glasses_inputs, test_glasses['gaze'], is_dual=True
    )

    print("\n  [Lazy eye / strabismus]")
    lazy_inputs = [test_lazy['left_eye'], test_lazy['right_eye'], test_lazy['face']]
    results['dual_f16_lazy_eye'] = evaluate_tflite(
        dual_tflite_f16_path, lazy_inputs, test_lazy['gaze'], is_dual=True
    )

    # ---- [6/6] Persist metadata and histories ----------------------------
    print("\n[6/6] Saving results...")

    # Model-card style metadata; mirrors the run configuration above.
    metadata = {
        'model_name': 'GazeInception-Lite',
        'task': 'eye-gaze-estimation',
        'description': 'Lightweight TFLite model for mobile eye gaze estimation on phone screens',
        'architecture': {
            'type': 'Gated Inception Network with Coordinate Attention',
            'single_eye_params': int(single_model.count_params()),
            'dual_eye_params': int(dual_model.count_params()),
            'input_size': '64x64x3',
            'features': [
                'Gated Inception blocks (learned branch gating to skip useless compute)',
                'Coordinate Attention for spatial gaze awareness',
                'Depthwise separable convolutions for efficiency',
                'Dual-eye processing with shared weights (handles lazy eye)',
                'Face context branch (head pose proxy)'
            ]
        },
        'training': {
            'dataset': 'Synthetic (50K train, 5K val, 3K test)',
            'augmentations': [
                'Dark/low-light conditions (30% probability, 15-50% brightness)',
                'Glasses overlay synthesis (25% probability, 10 frame styles)',
                'Lazy eye/strabismus simulation (15% probability)',
                'CMOS sensor noise (50% probability)',
                'Illumination perturbation (directional light gradients)',
                'Diverse skin tones (12 variations)',
                'Diverse eye colors (7 variations)'
            ],
            'optimizer': 'Adam with Cosine Decay LR',
            'initial_lr': 1e-3,
            'loss': 'MSE',
            'epochs': f'{EPOCHS_SINGLE} (single) / {EPOCHS_DUAL} (dual)',
        },
        'tflite_models': {
            'single_eye_f16': {
                'file': 'gaze_inception_lite_single_f16.tflite',
                'size_kb': os.path.getsize(single_tflite_f16_path) / 1024,
                'quantization': 'float16',
            },
            'single_eye_int8': {
                'file': 'gaze_inception_lite_single_int8.tflite',
                'size_kb': os.path.getsize(single_tflite_int8_path) / 1024,
                'quantization': 'int8',
            },
            'dual_eye_f16': {
                'file': 'gaze_inception_lite_dual_f16.tflite',
                'size_kb': os.path.getsize(dual_tflite_f16_path) / 1024,
                'quantization': 'float16',
            },
            'dual_eye_int8': {
                'file': 'gaze_inception_lite_dual_int8.tflite',
                'size_kb': os.path.getsize(dual_tflite_int8_path) / 1024,
                'quantization': 'int8',
            },
        },
        'evaluation_results': results,
        'references': [
            'AGE Framework - arxiv:2603.26945',
            'Gated Compression Layers - arxiv:2303.08970',
            'iTracker / GazeCapture - arxiv:1606.05814',
            'Coordinate Attention - Hou et al. 2021',
            'MobileNetV2 - arxiv:1801.04381',
        ]
    }

    with open(os.path.join(OUTPUT_DIR, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    # Dump per-epoch metric curves; cast to float since History may hold
    # numpy scalars that json can't serialize.
    for name, hist in [('single', single_history), ('dual', dual_history)]:
        hist_dict = {k: [float(v) for v in vals] for k, vals in hist.history.items()}
        with open(os.path.join(OUTPUT_DIR, f'{name}_history.json'), 'w') as f:
            json.dump(hist_dict, f, indent=2)

    # ---- Console summary -------------------------------------------------
    print("\n" + "="*60)
    print("TRAINING COMPLETE - SUMMARY")
    print("="*60)

    print(f"\nSingle-Eye Model:")
    print(f"  Parameters: {single_model.count_params():,}")
    print(f"  F16 TFLite: {os.path.getsize(single_tflite_f16_path)/1024:.1f} KB")
    print(f"  INT8 TFLite: {os.path.getsize(single_tflite_int8_path)/1024:.1f} KB")
    if 'single_int8' in results:
        r = results['single_int8']
        print(f"  Screen error: {r['screen_error_mm']:.1f} mm")
        print(f"  Inference: {r['avg_inference_ms']:.2f} ms ({r['fps']:.0f} FPS)")

    print(f"\nDual-Eye Model:")
    print(f"  Parameters: {dual_model.count_params():,}")
    print(f"  F16 TFLite: {os.path.getsize(dual_tflite_f16_path)/1024:.1f} KB")
    print(f"  INT8 TFLite: {os.path.getsize(dual_tflite_int8_path)/1024:.1f} KB")
    if 'dual_int8' in results:
        r = results['dual_int8']
        print(f"  Screen error: {r['screen_error_mm']:.1f} mm")
        print(f"  Inference: {r['avg_inference_ms']:.2f} ms ({r['fps']:.0f} FPS)")

    print(f"\nRobustness (Dual Eye):")
    for condition in ['dark', 'glasses', 'lazy_eye']:
        key = f'dual_f16_{condition}'
        if key in results:
            r = results[key]
            print(f"  {condition}: {r['screen_error_mm']:.1f} mm error")

    print(f"\nOutput directory: {OUTPUT_DIR}")
    print(f"Files: {os.listdir(OUTPUT_DIR)}")
|
|
|
|
# Script entry point: run the full train/convert/evaluate pipeline.
if __name__ == '__main__':
    main()
|
|