# GazeInceptionLite / src/train.py
# Uploaded by BcantCode with huggingface_hub (commit 65a793e).
"""
Training script for GazeInception-Lite model.
Trains both:
1. Single-eye model (89K params) - ultralight for constrained devices
2. Dual-eye model (137K params) - better accuracy with face context + lazy eye handling
Features:
- Gated Inception blocks with learned branch gating
- Coordinate Attention for spatial gaze awareness
- Synthetic data with: dark conditions, glasses, lazy eye, sensor noise
- TFLite conversion with full integer quantization
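- Writes artifacts (TFLite models, metadata.json, training histories) ready for publishing to the Hugging Face Hub (the upload itself is not performed by this script)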
Based on:
- AGE framework (arxiv:2603.26945) - augmentation pipeline, multi-task approach
- Gated Compression Layers (arxiv:2303.08970) - gating mechanism
- iTracker (arxiv:1606.05814) - dual-eye + face architecture
"""
import os
import json
import time
import numpy as np
import tensorflow as tf
from tensorflow import keras
from pathlib import Path
# Import our modules
from model import build_gaze_inception_lite, build_dual_eye_model
from data_generator import SyntheticGazeDataGenerator, create_tf_dataset, create_single_eye_dataset
def euclidean_distance_metric(y_true, y_pred):
    """Euclidean distance in normalized [0,1] coordinates.

    Keras-compatible metric: takes the L2 norm of the prediction error along
    the last axis for every sample, then averages those norms over all
    remaining elements (the batch, for ``(batch, 2)`` inputs).
    """
    error = y_true - y_pred
    per_sample_distance = tf.sqrt(tf.reduce_sum(tf.square(error), axis=-1))
    return tf.reduce_mean(per_sample_distance)
def screen_error_mm(y_true, y_pred, screen_w_mm=65.0, screen_h_mm=140.0):
    """Error in mm assuming typical phone screen (65mm x 140mm).

    The last axis of ``y_true``/``y_pred`` holds normalized (x, y) screen
    coordinates. They are scaled to millimetres by (screen_w_mm, screen_h_mm).
    The per-sample Euclidean distance is then averaged.

    Args:
        y_true: Ground-truth normalized gaze points.
        y_pred: Predicted normalized gaze points (same shape and dtype as y_true).
        screen_w_mm: Physical screen width in millimetres.
        screen_h_mm: Physical screen height in millimetres.

    Returns:
        Scalar tensor: mean on-screen error in millimetres.
    """
    diff = y_true - y_pred
    # Create the scale in diff's dtype. A default float32 constant would clash
    # with float64 or float16 (mixed-precision) tensors, since TF does not
    # auto-promote dtypes.
    scale_mm = tf.constant([screen_w_mm, screen_h_mm], dtype=diff.dtype)
    diff_mm = diff * scale_mm
    return tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(diff_mm), axis=-1)))
def train_single_eye_model(train_data, val_data, epochs=100, batch_size=128,
                           output_dir='models/single_eye'):
    """Train the lightweight single-eye model.

    Args:
        train_data: Training split in the format returned by
            ``SyntheticGazeDataGenerator.generate_dataset``. The length of its
            ``'gaze'`` entry determines the length of the learning-rate schedule.
        val_data: Validation split in the same format.
        epochs: Maximum number of epochs. EarlyStopping may end training sooner.
        batch_size: Mini-batch size passed to ``create_single_eye_dataset``.
        output_dir: Directory that receives ``best_model.keras`` (lowest
            validation Euclidean error) and ``final_model.keras``.

    Returns:
        Tuple ``(model, history)``, where ``history`` is the object returned by
        ``model.fit``.
    """
    print("\n" + "="*60)
    print("Training Single-Eye GazeInception-Lite Model")
    print("="*60)
    os.makedirs(output_dir, exist_ok=True)
    model = build_gaze_inception_lite(input_shape=(64, 64, 3), num_outputs=2)

    # NOTE(review): the factor 2 assumes create_single_eye_dataset yields one
    # training sample per eye (left + right) for every record. Confirm in
    # data_generator.
    # max(1, ...) keeps CosineDecay from dividing by zero on tiny datasets.
    decay_steps = max(1, epochs * (len(train_data['gaze']) * 2 // batch_size))
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-3,
        decay_steps=decay_steps,
        # alpha is a *fraction* of the initial LR, so the LR floor here is 1e-8.
        alpha=1e-5,
    )
    optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)
    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=[euclidean_distance_metric],
    )

    # Create datasets
    train_ds = create_single_eye_dataset(train_data, batch_size=batch_size, shuffle=True)
    val_ds = create_single_eye_dataset(val_data, batch_size=batch_size, shuffle=False)

    # ReduceLROnPlateau is deliberately absent. The optimizer's learning rate
    # is a LearningRateSchedule, which Keras refuses to overwrite, so that
    # callback would raise the first time val_loss plateaued. The cosine
    # schedule already anneals the LR.
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            os.path.join(output_dir, 'best_model.keras'),
            monitor='val_euclidean_distance_metric',
            save_best_only=True, mode='min', verbose=1
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_euclidean_distance_metric', patience=20, mode='min',
            restore_best_weights=True, verbose=1
        ),
    ]
    history = model.fit(
        train_ds, validation_data=val_ds,
        epochs=epochs, callbacks=callbacks, verbose=1
    )
    # Save final model
    model.save(os.path.join(output_dir, 'final_model.keras'))
    return model, history
def train_dual_eye_model(train_data, val_data, epochs=100, batch_size=64,
                         output_dir='models/dual_eye'):
    """Train the dual-eye model with face context.

    Args:
        train_data: Training split in the format returned by
            ``SyntheticGazeDataGenerator.generate_dataset``. The length of its
            ``'gaze'`` entry determines the length of the learning-rate schedule.
        val_data: Validation split in the same format.
        epochs: Maximum number of epochs. EarlyStopping may end training sooner.
        batch_size: Mini-batch size passed to ``create_tf_dataset``.
        output_dir: Directory that receives ``best_model.keras`` (lowest
            validation Euclidean error) and ``final_model.keras``.

    Returns:
        Tuple ``(model, history)``, where ``history`` is the object returned by
        ``model.fit``.
    """
    print("\n" + "="*60)
    print("Training Dual-Eye GazeInception-Lite Model")
    print("="*60)
    os.makedirs(output_dir, exist_ok=True)
    model = build_dual_eye_model(
        eye_shape=(64, 64, 3), face_shape=(64, 64, 3), num_outputs=2
    )

    # One sample per record (both eyes + face together).
    # max(1, ...) keeps CosineDecay from dividing by zero on tiny datasets.
    decay_steps = max(1, epochs * (len(train_data['gaze']) // batch_size))
    lr_schedule = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=1e-3,
        decay_steps=decay_steps,
        # alpha is a *fraction* of the initial LR, so the LR floor here is 1e-8.
        alpha=1e-5,
    )
    optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)
    model.compile(
        optimizer=optimizer,
        loss='mse',
        metrics=[euclidean_distance_metric],
    )

    # Create datasets
    train_ds = create_tf_dataset(train_data, batch_size=batch_size, shuffle=True)
    val_ds = create_tf_dataset(val_data, batch_size=batch_size, shuffle=False)

    # ReduceLROnPlateau is deliberately absent. The optimizer's learning rate
    # is a LearningRateSchedule, which Keras refuses to overwrite, so that
    # callback would raise the first time val_loss plateaued. The cosine
    # schedule already anneals the LR.
    callbacks = [
        keras.callbacks.ModelCheckpoint(
            os.path.join(output_dir, 'best_model.keras'),
            monitor='val_euclidean_distance_metric',
            save_best_only=True, mode='min', verbose=1
        ),
        keras.callbacks.EarlyStopping(
            monitor='val_euclidean_distance_metric', patience=20, mode='min',
            restore_best_weights=True, verbose=1
        ),
    ]
    history = model.fit(
        train_ds, validation_data=val_ds,
        epochs=epochs, callbacks=callbacks, verbose=1
    )
    model.save(os.path.join(output_dir, 'final_model.keras'))
    return model, history
def convert_to_tflite(keras_model, output_path, quantize=True, test_data=None,
                      num_calibration_samples=200):
    """Convert Keras model to TFLite with optional quantization.

    Conversion modes:
      * ``quantize=True`` with non-empty ``test_data``: full-integer INT8.
        Uses uint8 input, float32 output, calibrated on up to
        ``num_calibration_samples`` rows of ``test_data``.
      * ``quantize=True`` without data: dynamic-range quantization
        (int8 weights, float activations).
      * ``quantize=False``: float16 weight quantization, which is what the
        ``*_f16.tflite`` artifacts are meant to be.

    Args:
        keras_model: Single-input Keras model to convert.
        output_path: Destination ``.tflite`` file. Missing parent directories
            are created.
        quantize: Whether to apply integer quantization (see modes above).
        test_data: Array of model inputs used as the representative dataset.
        num_calibration_samples: Maximum number of rows used for calibration.

    Returns:
        The serialized TFLite model as ``bytes``.
    """
    print(f"\nConverting to TFLite: {output_path}")
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    # Optimization settings for mobile
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if quantize and test_data is not None and len(test_data) > 0:
        # Full integer quantization (fastest on mobile).
        # Spread calibration rows over the whole array. main() passes left
        # eyes followed by right eyes, so the first N rows would all be left eyes.
        count = min(num_calibration_samples, len(test_data))
        calib_indices = np.linspace(0, len(test_data) - 1, num=count).round().astype(int)

        def representative_dataset_gen():
            for i in calib_indices:
                yield [test_data[i:i + 1].astype(np.float32)]

        converter.representative_dataset = representative_dataset_gen
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8
        ]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.float32
        print(" Using INT8 quantization for maximum mobile speed")
    elif quantize:
        # Optimize.DEFAULT without a representative dataset = dynamic range.
        print(" No representative data: using dynamic-range quantization")
    else:
        # Without supported_types, Optimize.DEFAULT alone would produce
        # dynamic-range (int8-weight) models, not float16 ones.
        converter.target_spec.supported_types = [tf.float16]
        print(" Using float16 weight quantization")
    tflite_model = converter.convert()

    out_dir = os.path.dirname(output_path)
    if out_dir:  # os.makedirs('') raises for bare filenames
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(tflite_model)
    size_kb = len(tflite_model) / 1024
    print(f" TFLite model size: {size_kb:.1f} KB")
    return tflite_model
def convert_dual_eye_to_tflite(keras_model, output_path, quantize=True, test_data=None,
                               num_calibration_samples=200):
    """Convert dual-eye model to TFLite.

    Conversion modes:
      * ``quantize=True`` with non-empty ``test_data``: INT8 quantization
        (uint8 inputs, float32 output). Float builtin ops are allowed as a
        fallback for ops without an int8 kernel.
      * ``quantize=True`` without data: dynamic-range quantization.
      * ``quantize=False``: float16 weight quantization, which is what the
        ``*_f16.tflite`` artifacts are meant to be.

    Args:
        keras_model: Three-input (left eye, right eye, face) Keras model.
        output_path: Destination ``.tflite`` file. Missing parent directories
            are created.
        quantize: Whether to apply integer quantization (see modes above).
        test_data: Dict with ``'left_eye'``, ``'right_eye'`` and ``'face'``
            arrays used as the representative dataset.
        num_calibration_samples: Maximum number of samples used for calibration.

    Returns:
        The serialized TFLite model as ``bytes``.
    """
    print(f"\nConverting dual-eye model to TFLite: {output_path}")
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    if quantize and test_data is not None and len(test_data['left_eye']) > 0:
        num_available = len(test_data['left_eye'])
        count = min(num_calibration_samples, num_available)
        # Evenly spaced samples cover the whole test split, not just its head.
        calib_indices = np.linspace(0, num_available - 1, num=count).round().astype(int)

        def representative_dataset_gen():
            # NOTE(review): this list order must match the Keras model's input
            # order (left_eye, right_eye, face). Confirm against model.py.
            for i in calib_indices:
                yield [
                    test_data['left_eye'][i:i + 1].astype(np.float32),
                    test_data['right_eye'][i:i + 1].astype(np.float32),
                    test_data['face'][i:i + 1].astype(np.float32),
                ]

        converter.representative_dataset = representative_dataset_gen
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS_INT8,
            tf.lite.OpsSet.TFLITE_BUILTINS,  # fallback for unsupported ops
        ]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.float32
        print(" Using INT8 quantization for maximum mobile speed")
    elif quantize:
        # Optimize.DEFAULT without a representative dataset = dynamic range.
        print(" No representative data: using dynamic-range quantization")
    else:
        # Without supported_types, Optimize.DEFAULT alone would produce
        # dynamic-range (int8-weight) models, not float16 ones.
        converter.target_spec.supported_types = [tf.float16]
        print(" Using float16 weight quantization")
    tflite_model = converter.convert()

    out_dir = os.path.dirname(output_path)
    if out_dir:  # os.makedirs('') raises for bare filenames
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(tflite_model)
    size_kb = len(tflite_model) / 1024
    print(f" TFLite model size: {size_kb:.1f} KB")
    return tflite_model
def _to_input_tensor(detail, batch):
    """Convert a float batch into the dtype expected by one TFLite input tensor.

    Integer (uint8/int8) inputs are quantized with the tensor's own
    ``(scale, zero_point)``, so values match the converter's calibration. If
    the tensor records no scale, inputs are assumed to lie in [0, 1] and are
    mapped to [0, 255]. Float inputs are simply cast.
    """
    dtype = np.dtype(detail['dtype'])
    if np.issubdtype(dtype, np.integer):
        scale, zero_point = detail['quantization']
        values = np.asarray(batch, dtype=np.float32)
        if scale:
            quantized = np.round(values / scale + zero_point)
        else:
            quantized = np.round(values * 255.0)
        limits = np.iinfo(dtype)
        # Clip before casting so out-of-range values saturate instead of wrapping.
        return np.clip(quantized, limits.min, limits.max).astype(dtype)
    return np.asarray(batch).astype(dtype)


def _from_output_tensor(detail, raw):
    """Dequantize an integer TFLite output tensor. Float outputs pass through unchanged."""
    if np.issubdtype(np.dtype(detail['dtype']), np.integer):
        scale, zero_point = detail['quantization']
        if scale:
            return (raw.astype(np.float32) - zero_point) * scale
    return raw


def _feed_inputs(interpreter, input_details, test_inputs, index, is_dual):
    """Copy sample ``index`` of ``test_inputs`` into the interpreter's input tensor(s)."""
    sources = test_inputs if is_dual else [test_inputs]
    for detail, source in zip(input_details, sources):
        interpreter.set_tensor(
            detail['index'], _to_input_tensor(detail, source[index:index + 1])
        )


def evaluate_tflite(tflite_path, test_inputs, test_labels, is_dual=False,
                    max_samples=500, warmup_runs=5, screen_mm=(65.0, 140.0)):
    """Evaluate TFLite model accuracy and inference speed.

    Args:
        tflite_path: Path to the ``.tflite`` file.
        test_inputs: For single-input models, one array of float inputs.
            When ``is_dual``, a list with one array per model input.
            NOTE(review): the list is matched to the interpreter's inputs
            positionally; TFLite does not guarantee Keras input order. Verify
            the printed input names.
        test_labels: Ground-truth normalized gaze points, one row per sample.
        is_dual: Whether the model takes multiple inputs.
        max_samples: Upper bound on the number of samples evaluated.
        warmup_runs: Untimed invocations before the benchmark starts.
        screen_mm: Physical (width, height) in mm used to express the error.

    Returns:
        Dict with ``euclidean_error``, ``screen_error_mm``, ``screen_error_cm``,
        ``avg_inference_ms``, ``fps`` and ``num_test_samples``. Timings include
        input conversion and output fetch.

    Raises:
        ValueError: If there are no samples, or the number of input arrays
            does not match the model's inputs.
    """
    print(f"\nEvaluating TFLite model: {tflite_path}")
    num_samples = min(max_samples, len(test_labels))
    if not is_dual:
        num_samples = min(num_samples, len(test_inputs))
    if num_samples <= 0:
        raise ValueError("evaluate_tflite needs at least one test sample")

    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    print(f" Input details: {[(d['name'], d['shape'], d['dtype']) for d in input_details]}")
    print(f" Output details: {[(d['name'], d['shape'], d['dtype']) for d in output_details]}")
    if is_dual and len(test_inputs) != len(input_details):
        raise ValueError(
            f"Model has {len(input_details)} inputs but {len(test_inputs)} "
            "input arrays were given"
        )

    # Warmup
    for _ in range(warmup_runs):
        _feed_inputs(interpreter, input_details, test_inputs, 0, is_dual)
        interpreter.invoke()

    # Benchmark
    out_detail = output_details[0]
    predictions = []
    start_time = time.perf_counter()
    for i in range(num_samples):
        _feed_inputs(interpreter, input_details, test_inputs, i, is_dual)
        interpreter.invoke()
        raw = interpreter.get_tensor(out_detail['index'])[0]
        predictions.append(_from_output_tensor(out_detail, raw))
    elapsed = time.perf_counter() - start_time

    predictions = np.array(predictions)
    labels = test_labels[:num_samples]
    errors = predictions - labels
    # Metrics
    eucl_error = np.mean(np.sqrt(np.sum(errors ** 2, axis=-1)))
    diff_mm = errors * np.asarray(screen_mm)
    screen_error = np.mean(np.sqrt(np.sum(diff_mm ** 2, axis=-1)))
    screen_error_cm = screen_error / 10.0
    avg_inference_ms = (elapsed / num_samples) * 1000
    fps = 1000 / avg_inference_ms if avg_inference_ms > 0 else float('inf')
    print(f" Euclidean error (normalized): {eucl_error:.4f}")
    print(f" Screen error: {screen_error:.1f} mm ({screen_error_cm:.2f} cm)")
    print(f" Average inference time (CPU): {avg_inference_ms:.2f} ms")
    print(f" FPS (CPU): {fps:.1f}")
    return {
        'euclidean_error': float(eucl_error),
        'screen_error_mm': float(screen_error),
        'screen_error_cm': float(screen_error_cm),
        'avg_inference_ms': float(avg_inference_ms),
        'fps': float(fps),
        'num_test_samples': num_samples,
    }
def main():
    """Run the end-to-end pipeline: synthetic data, training, TFLite export, evaluation.

    Writes Keras checkpoints, four ``.tflite`` files, ``metadata.json`` and
    per-model training histories to OUTPUT_DIR. OUTPUT_DIR is ``/app/output``
    unless the ``GAZE_OUTPUT_DIR`` environment variable overrides it.
    """
    print("="*60)
    print("GazeInception-Lite: Mobile Eye Gaze Estimation")
    print("="*60)
    # Configuration
    NUM_TRAIN = 50000
    NUM_VAL = 5000
    NUM_TEST = 3000
    NUM_CONDITION_TEST = 1000  # size of each single-condition robustness set
    EPOCHS_SINGLE = 80
    EPOCHS_DUAL = 80
    BATCH_SIZE = 128  # single-eye model
    BATCH_SIZE_DUAL = 64  # dual-eye model
    # Condition probabilities shared by the mixed train/val/test splits.
    DARK_PROB = 0.30
    GLASSES_PROB = 0.25
    LAZY_EYE_PROB = 0.15
    OUTPUT_DIR = os.environ.get('GAZE_OUTPUT_DIR', '/app/output')
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    def fmt_thousands(n):
        """Render a sample count in model-card style, e.g. 50000 -> '50K'."""
        return f'{n / 1000:g}K'

    def file_kb(path):
        """Size of a file on disk in KiB."""
        return os.path.getsize(path) / 1024

    mixed_aug = dict(dark_prob=DARK_PROB, with_glasses_prob=GLASSES_PROB,
                     lazy_eye_prob=LAZY_EYE_PROB)

    # ==========================================
    # Generate synthetic training data
    # ==========================================
    print("\n[1/6] Generating synthetic training data...")
    print(f" Train: {NUM_TRAIN}, Val: {NUM_VAL}, Test: {NUM_TEST}")
    print(f" Augmentations: dark({DARK_PROB:.0%}), glasses({GLASSES_PROB:.0%}), "
          f"lazy_eye({LAZY_EYE_PROB:.0%}), noise(50%)")
    gen = SyntheticGazeDataGenerator(seed=42)
    t0 = time.time()
    train_data = gen.generate_dataset(NUM_TRAIN, **mixed_aug)
    print(f" Train data generated in {time.time()-t0:.1f}s")
    val_data = SyntheticGazeDataGenerator(seed=123).generate_dataset(NUM_VAL, **mixed_aug)
    test_data = SyntheticGazeDataGenerator(seed=456).generate_dataset(NUM_TEST, **mixed_aug)
    # Also generate condition-specific test sets for robustness evaluation
    test_dark = SyntheticGazeDataGenerator(seed=789).generate_dataset(
        NUM_CONDITION_TEST, dark_prob=1.0, with_glasses_prob=0.0, lazy_eye_prob=0.0)
    test_glasses = SyntheticGazeDataGenerator(seed=101).generate_dataset(
        NUM_CONDITION_TEST, dark_prob=0.0, with_glasses_prob=1.0, lazy_eye_prob=0.0)
    test_lazy = SyntheticGazeDataGenerator(seed=202).generate_dataset(
        NUM_CONDITION_TEST, dark_prob=0.0, with_glasses_prob=0.0, lazy_eye_prob=1.0)
    print(f" All data generated. Total time: {time.time()-t0:.1f}s")

    # ==========================================
    # Train Single-Eye Model
    # ==========================================
    print("\n[2/6] Training single-eye model...")
    single_model, single_history = train_single_eye_model(
        train_data, val_data,
        epochs=EPOCHS_SINGLE, batch_size=BATCH_SIZE,
        output_dir=os.path.join(OUTPUT_DIR, 'single_eye')
    )

    # ==========================================
    # Train Dual-Eye Model
    # ==========================================
    print("\n[3/6] Training dual-eye model...")
    dual_model, dual_history = train_dual_eye_model(
        train_data, val_data,
        epochs=EPOCHS_DUAL, batch_size=BATCH_SIZE_DUAL,
        output_dir=os.path.join(OUTPUT_DIR, 'dual_eye')
    )

    # ==========================================
    # Convert to TFLite
    # ==========================================
    print("\n[4/6] Converting models to TFLite...")
    single_tflite_f16_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_single_f16.tflite')
    convert_to_tflite(single_model, single_tflite_f16_path, quantize=False)
    # INT8 calibration sees both eyes' crops.
    single_tflite_int8_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_single_int8.tflite')
    all_eyes = np.concatenate([test_data['left_eye'], test_data['right_eye']], axis=0)
    convert_to_tflite(single_model, single_tflite_int8_path, quantize=True, test_data=all_eyes)
    dual_tflite_f16_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_dual_f16.tflite')
    convert_dual_eye_to_tflite(dual_model, dual_tflite_f16_path, quantize=False)
    dual_tflite_int8_path = os.path.join(OUTPUT_DIR, 'gaze_inception_lite_dual_int8.tflite')
    convert_dual_eye_to_tflite(dual_model, dual_tflite_int8_path, quantize=True, test_data=test_data)

    # ==========================================
    # Evaluate TFLite models
    # ==========================================
    print("\n[5/6] Evaluating TFLite models...")
    results = {}
    # The single-eye model is scored on left-eye crops, which pair
    # one-to-one with the gaze labels.
    print("\n--- Single Eye Model (Float16) ---")
    results['single_f16'] = evaluate_tflite(
        single_tflite_f16_path, test_data['left_eye'], test_data['gaze']
    )
    print("\n--- Single Eye Model (INT8) ---")
    results['single_int8'] = evaluate_tflite(
        single_tflite_int8_path, test_data['left_eye'], test_data['gaze']
    )

    dual_inputs = [test_data['left_eye'], test_data['right_eye'], test_data['face']]
    print("\n--- Dual Eye Model (Float16) ---")
    results['dual_f16'] = evaluate_tflite(
        dual_tflite_f16_path, dual_inputs, test_data['gaze'], is_dual=True
    )
    print("\n--- Dual Eye Model (INT8) ---")
    results['dual_int8'] = evaluate_tflite(
        dual_tflite_int8_path, dual_inputs, test_data['gaze'], is_dual=True
    )

    # Condition-specific evaluation (dual model, float16)
    print("\n--- Robustness Evaluation (Dual Eye, Float16) ---")
    robustness_sets = [
        ('dark', 'Dark conditions', test_dark),
        ('glasses', 'With glasses', test_glasses),
        ('lazy_eye', 'Lazy eye / strabismus', test_lazy),
    ]
    for condition, label, data in robustness_sets:
        print(f"\n [{label}]")
        results[f'dual_f16_{condition}'] = evaluate_tflite(
            dual_tflite_f16_path,
            [data['left_eye'], data['right_eye'], data['face']],
            data['gaze'], is_dual=True
        )

    # ==========================================
    # Save results and metadata
    # ==========================================
    print("\n[6/6] Saving results...")
    metadata = {
        'model_name': 'GazeInception-Lite',
        'task': 'eye-gaze-estimation',
        'description': 'Lightweight TFLite model for mobile eye gaze estimation on phone screens',
        'architecture': {
            'type': 'Gated Inception Network with Coordinate Attention',
            'single_eye_params': int(single_model.count_params()),
            'dual_eye_params': int(dual_model.count_params()),
            'input_size': '64x64x3',
            'features': [
                'Gated Inception blocks (learned branch gating to skip useless compute)',
                'Coordinate Attention for spatial gaze awareness',
                'Depthwise separable convolutions for efficiency',
                'Dual-eye processing with shared weights (handles lazy eye)',
                'Face context branch (head pose proxy)'
            ]
        },
        'training': {
            # Derived from the configuration so the model card cannot drift.
            'dataset': (f'Synthetic ({fmt_thousands(NUM_TRAIN)} train, '
                        f'{fmt_thousands(NUM_VAL)} val, {fmt_thousands(NUM_TEST)} test)'),
            'augmentations': [
                f'Dark/low-light conditions ({DARK_PROB:.0%} probability, 15-50% brightness)',
                f'Glasses overlay synthesis ({GLASSES_PROB:.0%} probability, 10 frame styles)',
                f'Lazy eye/strabismus simulation ({LAZY_EYE_PROB:.0%} probability)',
                'CMOS sensor noise (50% probability)',
                'Illumination perturbation (directional light gradients)',
                'Diverse skin tones (12 variations)',
                'Diverse eye colors (7 variations)'
            ],
            'optimizer': 'Adam with Cosine Decay LR',
            'initial_lr': 1e-3,
            'loss': 'MSE',
            'epochs': f'{EPOCHS_SINGLE} (single) / {EPOCHS_DUAL} (dual)',
            'batch_size': f'{BATCH_SIZE} (single) / {BATCH_SIZE_DUAL} (dual)',
        },
        'tflite_models': {
            'single_eye_f16': {
                'file': 'gaze_inception_lite_single_f16.tflite',
                'size_kb': file_kb(single_tflite_f16_path),
                'quantization': 'float16',
            },
            'single_eye_int8': {
                'file': 'gaze_inception_lite_single_int8.tflite',
                'size_kb': file_kb(single_tflite_int8_path),
                'quantization': 'int8',
            },
            'dual_eye_f16': {
                'file': 'gaze_inception_lite_dual_f16.tflite',
                'size_kb': file_kb(dual_tflite_f16_path),
                'quantization': 'float16',
            },
            'dual_eye_int8': {
                'file': 'gaze_inception_lite_dual_int8.tflite',
                'size_kb': file_kb(dual_tflite_int8_path),
                'quantization': 'int8',
            },
        },
        'evaluation_results': results,
        'references': [
            'AGE Framework - arxiv:2603.26945',
            'Gated Compression Layers - arxiv:2303.08970',
            'iTracker / GazeCapture - arxiv:1606.05814',
            'Coordinate Attention - Hou et al. 2021',
            'MobileNetV2 - arxiv:1801.04381',
        ]
    }
    with open(os.path.join(OUTPUT_DIR, 'metadata.json'), 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)

    # Save training history
    for name, hist in [('single', single_history), ('dual', dual_history)]:
        hist_dict = {k: [float(v) for v in vals] for k, vals in hist.history.items()}
        with open(os.path.join(OUTPUT_DIR, f'{name}_history.json'), 'w', encoding='utf-8') as f:
            json.dump(hist_dict, f, indent=2)

    # Print summary
    print("\n" + "="*60)
    print("TRAINING COMPLETE - SUMMARY")
    print("="*60)
    print("\nSingle-Eye Model:")
    print(f" Parameters: {single_model.count_params():,}")
    print(f" F16 TFLite: {file_kb(single_tflite_f16_path):.1f} KB")
    print(f" INT8 TFLite: {file_kb(single_tflite_int8_path):.1f} KB")
    if 'single_int8' in results:
        r = results['single_int8']
        print(f" Screen error: {r['screen_error_mm']:.1f} mm")
        print(f" Inference: {r['avg_inference_ms']:.2f} ms ({r['fps']:.0f} FPS)")
    print("\nDual-Eye Model:")
    print(f" Parameters: {dual_model.count_params():,}")
    print(f" F16 TFLite: {file_kb(dual_tflite_f16_path):.1f} KB")
    print(f" INT8 TFLite: {file_kb(dual_tflite_int8_path):.1f} KB")
    if 'dual_int8' in results:
        r = results['dual_int8']
        print(f" Screen error: {r['screen_error_mm']:.1f} mm")
        print(f" Inference: {r['avg_inference_ms']:.2f} ms ({r['fps']:.0f} FPS)")
    print("\nRobustness (Dual Eye):")
    for condition, _, _ in robustness_sets:
        key = f'dual_f16_{condition}'
        if key in results:
            r = results[key]
            print(f" {condition}: {r['screen_error_mm']:.1f} mm error")
    print(f"\nOutput directory: {OUTPUT_DIR}")
    print(f"Files: {sorted(os.listdir(OUTPUT_DIR))}")
if __name__ == "__main__":
    # Script entry point: run the full pipeline only when executed directly,
    # not when this module is imported.
    main()