GazeInceptionLite / src /model.py
BcantCode's picture
Upload src/model.py with huggingface_hub
687b215 verified
"""
GazeInception-Lite: Gated Inception Model for Mobile Eye Gaze Estimation
Architecture:
- Input: 64x64 RGB eye crop (single-eye model), or separate 64x64 RGB left-eye, right-eye and face crops (dual-eye model)
- Gated Inception Blocks: Each inception block has a lightweight gate (squeeze-excitation style)
that learns to down-weight branches that contribute little (all branches are still computed; the gate only rescales their outputs)
- Multi-scale feature extraction via inception (1x1, 3x3, 5x5 parallel convolutions)
- Coordinate Attention for spatial awareness
- Output: (x, y) screen coordinates normalized to [0, 1]
Design goals:
- < 500K parameters for fast mobile inference
- TFLite compatible (no unsupported ops)
- Works in dark (trained with illumination augmentation)
- Handles glasses (trained with glasses augmentation)
- Handles lazy eye / strabismus (trained with per-eye asymmetric augmentation)
"""
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
import numpy as np
class GatedInceptionBlock(layers.Layer):
    """
    Inception block whose four branches are rescaled by a learned per-sample gate.

    A lightweight squeeze-excitation-style gate produces one sigmoid scalar per
    branch per sample, and each branch's output is multiplied by its scalar
    before concatenation. Branches whose gate values approach zero contribute
    little to the output. Every branch is still computed on every call: the
    gate only rescales outputs and does not by itself reduce inference compute.

    Branches (each reduction and each branch output goes through BatchNorm + ReLU):
      1. 1x1 conv (point features)
      2. 1x1 reduce -> 3x3 depthwise -> 1x1 pointwise (local features)
      3. 1x1 reduce -> 5x5 depthwise -> 1x1 pointwise (wider context)
      4. 3x3 max pool (stride 1) -> 1x1 conv (pooled features)

    Gate: GlobalAveragePooling2D -> Dense(4 * num_branches, relu)
          -> Dense(num_branches, sigmoid)

    All convolutions and the pool use stride 1 with 'same' padding, so the
    output keeps the input's spatial size. It has
    filters_1x1 + filters_3x3 + filters_5x5 + filters_pool channels.

    Args:
        filters_1x1: Output channels of branch 1.
        filters_3x3_reduce: Channels of the 1x1 reduction in branch 2.
        filters_3x3: Output channels of branch 2.
        filters_5x5_reduce: Channels of the 1x1 reduction in branch 3.
        filters_5x5: Output channels of branch 3.
        filters_pool: Output channels of branch 4.
        **kwargs: Forwarded to keras.layers.Layer (e.g. ``name``).
    """

    def __init__(self, filters_1x1, filters_3x3_reduce, filters_3x3,
                 filters_5x5_reduce, filters_5x5, filters_pool, **kwargs):
        super().__init__(**kwargs)
        # Keep every constructor argument so get_config can round-trip them.
        self.filters_1x1 = filters_1x1
        self.filters_3x3_reduce = filters_3x3_reduce
        self.filters_3x3 = filters_3x3
        self.filters_5x5_reduce = filters_5x5_reduce
        self.filters_5x5 = filters_5x5
        self.filters_pool = filters_pool
        self.num_branches = 4

        # NOTE: sublayer attribute names and creation order are unchanged from
        # the original implementation so existing saved weights still load.

        # Branch 1: 1x1
        self.branch1_conv = layers.Conv2D(filters_1x1, 1, padding='same', use_bias=False)
        self.branch1_bn = layers.BatchNormalization()

        # Branch 2: 1x1 reduce -> 3x3 depthwise -> 1x1 pointwise
        self.branch2_reduce = layers.Conv2D(filters_3x3_reduce, 1, padding='same', use_bias=False)
        self.branch2_reduce_bn = layers.BatchNormalization()
        self.branch2_conv = layers.DepthwiseConv2D(3, padding='same', use_bias=False)
        self.branch2_pw = layers.Conv2D(filters_3x3, 1, padding='same', use_bias=False)
        self.branch2_bn = layers.BatchNormalization()

        # Branch 3: 1x1 reduce -> 5x5 depthwise -> 1x1 pointwise
        self.branch3_reduce = layers.Conv2D(filters_5x5_reduce, 1, padding='same', use_bias=False)
        self.branch3_reduce_bn = layers.BatchNormalization()
        self.branch3_dw = layers.DepthwiseConv2D(5, padding='same', use_bias=False)
        self.branch3_pw = layers.Conv2D(filters_5x5, 1, padding='same', use_bias=False)
        self.branch3_bn = layers.BatchNormalization()

        # Branch 4: 3x3 max pool (stride 1) -> 1x1
        self.branch4_pool = layers.MaxPooling2D(3, strides=1, padding='same')
        self.branch4_conv = layers.Conv2D(filters_pool, 1, padding='same', use_bias=False)
        self.branch4_bn = layers.BatchNormalization()

        # Gate: one sigmoid score per branch, computed from the block input.
        self.gate_pool = layers.GlobalAveragePooling2D()
        self.gate_dense1 = layers.Dense(self.num_branches * 4, activation='relu')
        self.gate_dense2 = layers.Dense(self.num_branches, activation='sigmoid')

        # Shared (stateless) activation reused after every BatchNorm.
        self.relu = layers.ReLU()

    def call(self, x, training=False):
        # Per-sample gate scores, one sigmoid value per branch: [batch, 4].
        gate = self.gate_dense2(self.gate_dense1(self.gate_pool(x)))

        # Branch 1
        b1 = self.relu(self.branch1_bn(self.branch1_conv(x), training=training))

        # Branch 2
        b2 = self.relu(self.branch2_reduce_bn(self.branch2_reduce(x), training=training))
        b2 = self.branch2_pw(self.branch2_conv(b2))
        b2 = self.relu(self.branch2_bn(b2, training=training))

        # Branch 3
        b3 = self.relu(self.branch3_reduce_bn(self.branch3_reduce(x), training=training))
        b3 = self.branch3_pw(self.branch3_dw(b3))
        b3 = self.relu(self.branch3_bn(b3, training=training))

        # Branch 4
        b4 = self.branch4_conv(self.branch4_pool(x))
        b4 = self.relu(self.branch4_bn(b4, training=training))

        # Scale each branch by its own per-sample gate scalar. The scalar is
        # reshaped to [batch, 1, 1, 1] so it broadcasts over H, W and the
        # branch's channels (branches have different channel counts, so a
        # single [batch, 1, 1, 4] broadcast would not line up).
        gated = [
            branch * tf.reshape(gate[:, i], [-1, 1, 1, 1])
            for i, branch in enumerate((b1, b2, b3, b4))
        ]
        return tf.concat(gated, axis=-1)

    def get_config(self):
        """Return constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update({
            'filters_1x1': self.filters_1x1,
            'filters_3x3_reduce': self.filters_3x3_reduce,
            'filters_3x3': self.filters_3x3,
            'filters_5x5_reduce': self.filters_5x5_reduce,
            'filters_5x5': self.filters_5x5,
            'filters_pool': self.filters_pool,
        })
        return config
class CoordinateAttention(layers.Layer):
    """
    Coordinate Attention module (Hou et al., 2021).

    The input is average-pooled separately along width and along height. The
    two direction-aware descriptors go through one shared 1x1 conv bottleneck
    (conv -> BatchNorm -> ReLU) and are then split apart again. Each part is
    turned into a sigmoid attention map by its own 1x1 conv. The output is
    ``x * a_h * a_w``, where ``a_h`` ([B, H, 1, C]) broadcasts over width and
    ``a_w`` ([B, 1, W, C]) broadcasts over height, so the output shape equals
    the input shape.

    Intended to encode spatial position (e.g. where the iris sits) into the
    channel attention for gaze estimation.

    Args:
        reduction_ratio: Channel reduction factor of the shared bottleneck.
            The bottleneck has ``max(channels // reduction_ratio, 8)`` channels.
        **kwargs: Forwarded to keras.layers.Layer (e.g. ``name``).
    """

    def __init__(self, reduction_ratio=4, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        # Channels are read from the last axis, i.e. channels-last (NHWC) input
        # is assumed and the channel dimension must be statically known.
        channels = input_shape[-1]
        reduced_channels = max(channels // self.reduction_ratio, 8)
        # Average over width (axis 2) -> [B, H, 1, C].
        self.pool_h = layers.Lambda(lambda x: tf.reduce_mean(x, axis=2, keepdims=True))
        # Average over height (axis 1) -> [B, 1, W, C].
        self.pool_w = layers.Lambda(lambda x: tf.reduce_mean(x, axis=1, keepdims=True))
        self.conv_reduce = layers.Conv2D(reduced_channels, 1, use_bias=False)
        self.bn = layers.BatchNormalization()
        self.relu = layers.ReLU()
        # Project back to the full channel count; sigmoid yields (0, 1) weights.
        self.conv_h = layers.Conv2D(channels, 1, activation='sigmoid')
        self.conv_w = layers.Conv2D(channels, 1, activation='sigmoid')
        super().build(input_shape)

    def call(self, x, training=False):
        h_att = self.pool_h(x)  # [B, H, 1, C]
        w_att = self.pool_w(x)  # [B, 1, W, C]

        # Move the width axis to axis 1 so both descriptors can be concatenated
        # and share a single bottleneck conv.
        w_att_t = tf.transpose(w_att, perm=[0, 2, 1, 3])  # [B, W, 1, C]
        combined = tf.concat([h_att, w_att_t], axis=1)  # [B, H+W, 1, C]
        combined = self.conv_reduce(combined)  # [B, H+W, 1, C_reduced]
        combined = self.bn(combined, training=training)
        combined = self.relu(combined)

        # Split back into the height part (first H rows) and the width part.
        h_len = tf.shape(h_att)[1]
        h_out = combined[:, :h_len, :, :]
        w_out = combined[:, h_len:, :, :]

        h_att_map = self.conv_h(h_out)  # [B, H, 1, C]
        w_att_map = tf.transpose(self.conv_w(w_out), perm=[0, 2, 1, 3])  # [B, 1, W, C]

        return x * h_att_map * w_att_map

    def get_config(self):
        """Return constructor arguments so the layer can be serialized and re-created."""
        config = super().get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config
def build_gaze_inception_lite(input_shape=(64, 64, 3), num_outputs=2):
    """
    Build the single-eye GazeInception-Lite model.

    Architecture (spatial sizes for the default 64x64 input):
        Input (64x64x3)
        -> Stem: 3x3 conv stride 2 + 3x3 conv, 32 ch       (32x32)
        -> GatedInception1 (64 ch)  -> 2x2 MaxPool         (16x16)
        -> GatedInception2 (128 ch) -> 2x2 MaxPool         (8x8)
        -> CoordinateAttention
        -> GatedInception3 (176 ch) -> 2x2 MaxPool         (4x4)
        -> GlobalAveragePooling2D
        -> Dense(128, relu) -> Dropout(0.3) -> Dense(64, relu) -> Dropout(0.2)
        -> Dense(num_outputs, sigmoid)

    Parameter count: roughly 90K with the default arguments, from a hand count
    of the layers below (confirm with ``model.count_params()``).

    Args:
        input_shape: (H, W, C) of one eye crop.
        num_outputs: Width of the sigmoid output layer; 2 for normalized (x, y)
            screen coordinates in (0, 1).

    Returns:
        An uncompiled ``keras.Model`` named 'GazeInceptionLite'.
    """
    # NOTE: layer construction order is significant. Unnamed layers receive
    # auto-generated names in creation order, which saved weights rely on.
    inputs = layers.Input(shape=input_shape, name='eye_image')

    # Stem: lightweight feature extraction
    x = layers.Conv2D(32, 3, strides=2, padding='same', use_bias=False)(inputs)  # 32x32
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(32, 3, padding='same', use_bias=False)(x)  # 32x32
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Gated Inception Block 1 at 32x32 (resolution-preserving), then pool to 16x16
    x = GatedInceptionBlock(
        filters_1x1=16,
        filters_3x3_reduce=16, filters_3x3=24,
        filters_5x5_reduce=8, filters_5x5=12,
        filters_pool=12,
        name='gated_inception_1'
    )(x)  # output: 16 + 24 + 12 + 12 = 64 channels
    x = layers.MaxPooling2D(2)(x)  # 16x16

    # Gated Inception Block 2 at 16x16, then pool to 8x8
    x = GatedInceptionBlock(
        filters_1x1=32,
        filters_3x3_reduce=24, filters_3x3=48,
        filters_5x5_reduce=12, filters_5x5=24,
        filters_pool=24,
        name='gated_inception_2'
    )(x)  # output: 32 + 48 + 24 + 24 = 128 channels
    x = layers.MaxPooling2D(2)(x)  # 8x8

    # Coordinate Attention - encodes spatial position for gaze direction
    x = CoordinateAttention(reduction_ratio=4, name='coord_attention')(x)

    # Gated Inception Block 3 at 8x8, then pool to 4x4
    x = GatedInceptionBlock(
        filters_1x1=48,
        filters_3x3_reduce=32, filters_3x3=64,
        filters_5x5_reduce=16, filters_5x5=32,
        filters_pool=32,
        name='gated_inception_3'
    )(x)  # output: 48 + 64 + 32 + 32 = 176 channels
    x = layers.MaxPooling2D(2)(x)  # 4x4

    # Global feature aggregation -> [B, 176]. build_dual_eye_model cuts the
    # backbone at this layer.
    x = layers.GlobalAveragePooling2D()(x)

    # Regression head
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.2)(x)

    # Output: (x, y) screen coordinates squashed into (0, 1) by the sigmoid
    outputs = layers.Dense(num_outputs, activation='sigmoid', name='gaze_coords')(x)

    model = Model(inputs=inputs, outputs=outputs, name='GazeInceptionLite')
    return model
def build_dual_eye_model(eye_shape=(64, 64, 3), face_shape=(64, 64, 3), num_outputs=2):
    """
    Build the dual-eye + face GazeInception-Lite model.

    Both eye crops pass through ONE shared feature extractor: the single-eye
    backbone truncated at its GlobalAveragePooling2D output. Left and right
    eyes therefore use identical weights. Their feature vectors are kept
    separate and concatenated with a small face-context embedding (intended
    as a head-pose proxy). Because the fusion head sees each eye's features
    separately, it can weight the two eyes differently. This is the mechanism
    intended to cope with asymmetric eye conditions (strabismus/amblyopia).

    Inputs:
        - left_eye:  crop of shape ``eye_shape`` (default 64x64x3)
        - right_eye: crop of shape ``eye_shape`` (default 64x64x3)
        - face:      crop of shape ``face_shape`` (default 64x64x3)

    Args:
        eye_shape: (H, W, C) of each eye crop.
        face_shape: (H, W, C) of the face crop.
        num_outputs: Width of the sigmoid output layer; 2 for normalized (x, y).

    Returns:
        An uncompiled ``keras.Model`` named 'GazeInceptionLite_DualEye' with
        inputs ``[left_eye, right_eye, face]``.

    Raises:
        ValueError: If the eye backbone has no top-level GlobalAveragePooling2D
            layer to cut at.
    """
    left_eye_input = layers.Input(shape=eye_shape, name='left_eye')
    right_eye_input = layers.Input(shape=eye_shape, name='right_eye')
    face_input = layers.Input(shape=face_shape, name='face')

    # Shared eye feature extractor (gated inception backbone). Only the layers
    # up to the GlobalAveragePooling2D are kept; the backbone's own dense
    # regression head is not part of the returned model.
    eye_backbone = build_gaze_inception_lite(input_shape=eye_shape, num_outputs=2)

    # Use the last top-level GlobalAveragePooling2D layer (the one feeding the
    # dense head). Fail with a clear error instead of an AttributeError on None.
    gap_layer = next(
        (layer for layer in reversed(eye_backbone.layers)
         if isinstance(layer, layers.GlobalAveragePooling2D)),
        None,
    )
    if gap_layer is None:
        raise ValueError(
            "Eye backbone has no GlobalAveragePooling2D layer to extract features from."
        )

    eye_feature_extractor = Model(
        inputs=eye_backbone.input,
        outputs=gap_layer.output,
        name='eye_feature_extractor'
    )

    # Extract features for each eye independently (shared weights)
    left_features = eye_feature_extractor(left_eye_input)  # [B, 176]
    right_features = eye_feature_extractor(right_eye_input)  # [B, 176]

    # Lightweight face context extractor: three stride-2 convs + GAP
    f = layers.Conv2D(16, 3, strides=2, padding='same', activation='relu')(face_input)
    f = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(f)
    f = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(f)
    f = layers.GlobalAveragePooling2D()(f)
    face_features = layers.Dense(64, activation='relu')(f)  # [B, 64]

    # Combine: left_eye + right_eye + face -> [B, 176 + 176 + 64]
    combined = layers.Concatenate()([left_features, right_features, face_features])

    # Fusion head
    x = layers.Dense(128, activation='relu')(combined)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_outputs, activation='sigmoid', name='gaze_coords')(x)

    model = Model(
        inputs=[left_eye_input, right_eye_input, face_input],
        outputs=outputs,
        name='GazeInceptionLite_DualEye'
    )
    return model
if __name__ == '__main__':
    # Smoke test 1: single-eye model -- summary, size, and one forward pass
    # on a random batch of two 64x64 RGB crops.
    single = build_gaze_inception_lite()
    single.summary()
    print(f"\nSingle eye model params: {single.count_params():,}")
    eye_batch = np.random.rand(2, 64, 64, 3).astype(np.float32)
    pred = single(eye_batch)
    print(f"Output shape: {pred.shape}")
    print(f"Output values: {pred.numpy()}")

    print("\n" + "=" * 60)

    # Smoke test 2: dual-eye + face model. The three random crops are drawn
    # in the order left, right, face.
    dual = build_dual_eye_model()
    dual.summary()
    print(f"\nDual eye model params: {dual.count_params():,}")
    left, right, face = (
        np.random.rand(2, 64, 64, 3).astype(np.float32) for _ in range(3)
    )
    pred = dual([left, right, face])
    print(f"Output shape: {pred.shape}")
    print(f"Output values: {pred.numpy()}")