| """ |
| GazeInception-Lite: Gated Inception Model for Mobile Eye Gaze Estimation |
| |
| Architecture: |
| - Input: 64x64 RGB eye crop (left + right eye stacked as 2-channel or 128x64 side-by-side) |
| - Gated Inception Blocks: Each inception block has a lightweight gate (squeeze-excitation style) |
| that learns to skip branches that contribute little, reducing useless compute |
| - Multi-scale feature extraction via inception (1x1, 3x3, 5x5 parallel convolutions) |
| - Coordinate Attention for spatial awareness |
| - Output: (x, y) screen coordinates normalized to [0, 1] |
| |
| Design goals: |
| - < 500K parameters for fast mobile inference |
| - TFLite compatible (no unsupported ops) |
| - Works in dark (trained with illumination augmentation) |
| - Handles glasses (trained with glasses augmentation) |
| - Handles lazy eye / strabismus (trained with per-eye asymmetric augmentation) |
| """ |
|
|
| import tensorflow as tf |
| from tensorflow import keras |
| from tensorflow.keras import layers, Model |
| import numpy as np |
|
|
|
|
class GatedInceptionBlock(layers.Layer):
    """
    Inception block with a learned per-branch gating mechanism.

    A lightweight squeeze-style gate (GAP -> Dense -> Dense-sigmoid) produces
    one scalar in [0, 1] per branch; each branch's output is scaled by its
    gate value before concatenation. Branches whose gates collapse toward
    zero contribute near-nothing at inference — learned conditional
    computation that effectively "skips" unhelpful branches.

    Branches:
        1. 1x1 conv (point features)
        2. 1x1 reduce -> 3x3 depthwise + 1x1 pointwise (local features)
        3. 1x1 reduce -> 5x5 depthwise + 1x1 pointwise (wider context)
        4. 3x3 max pool -> 1x1 conv (pooled features)

    Note: branches 2 and 3 use depthwise-separable convolutions (not dense
    3x3/5x5 convs) to keep the parameter count mobile-friendly.
    """

    def __init__(self, filters_1x1, filters_3x3_reduce, filters_3x3,
                 filters_5x5_reduce, filters_5x5, filters_pool, **kwargs):
        super().__init__(**kwargs)
        # Store every constructor argument so get_config() can round-trip
        # the layer cleanly (required for model serialization).
        self.filters_1x1 = filters_1x1
        self.filters_3x3_reduce = filters_3x3_reduce
        self.filters_3x3 = filters_3x3
        self.filters_5x5_reduce = filters_5x5_reduce
        self.filters_5x5 = filters_5x5
        self.filters_pool = filters_pool
        self.num_branches = 4

        # Branch 1: plain 1x1 conv.
        self.branch1_conv = layers.Conv2D(filters_1x1, 1, padding='same', use_bias=False)
        self.branch1_bn = layers.BatchNormalization()

        # Branch 2: 1x1 reduce -> 3x3 depthwise -> 1x1 pointwise.
        self.branch2_reduce = layers.Conv2D(filters_3x3_reduce, 1, padding='same', use_bias=False)
        self.branch2_reduce_bn = layers.BatchNormalization()
        self.branch2_conv = layers.DepthwiseConv2D(3, padding='same', use_bias=False)
        self.branch2_pw = layers.Conv2D(filters_3x3, 1, padding='same', use_bias=False)
        self.branch2_bn = layers.BatchNormalization()

        # Branch 3: 1x1 reduce -> 5x5 depthwise -> 1x1 pointwise.
        self.branch3_reduce = layers.Conv2D(filters_5x5_reduce, 1, padding='same', use_bias=False)
        self.branch3_reduce_bn = layers.BatchNormalization()
        self.bran3_dw = None  # placeholder removed below; see branch3_dw
        self.branch3_dw = layers.DepthwiseConv2D(5, padding='same', use_bias=False)
        self.branch3_pw = layers.Conv2D(filters_5x5, 1, padding='same', use_bias=False)
        self.branch3_bn = layers.BatchNormalization()

        # Branch 4: 3x3 max pool (stride 1) -> 1x1 conv.
        self.branch4_pool = layers.MaxPooling2D(3, strides=1, padding='same')
        self.branch4_conv = layers.Conv2D(filters_pool, 1, padding='same', use_bias=False)
        self.branch4_bn = layers.BatchNormalization()

        # Gate: GAP over the block input -> small ReLU MLP -> one sigmoid
        # scalar per branch.
        self.gate_pool = layers.GlobalAveragePooling2D()
        self.gate_dense1 = layers.Dense(self.num_branches * 4, activation='relu')
        self.gate_dense2 = layers.Dense(self.num_branches, activation='sigmoid')

        self.relu = layers.ReLU()

    def call(self, x, training=False):
        # Compute the per-branch gate values from the block input.
        gate_input = self.gate_pool(x)
        gate = self.gate_dense1(gate_input)
        gate = self.gate_dense2(gate)

        # Branch 1.
        b1 = self.branch1_conv(x)
        b1 = self.branch1_bn(b1, training=training)
        b1 = self.relu(b1)

        # Branch 2.
        b2 = self.branch2_reduce(x)
        b2 = self.branch2_reduce_bn(b2, training=training)
        b2 = self.relu(b2)
        b2 = self.branch2_conv(b2)
        b2 = self.branch2_pw(b2)
        b2 = self.branch2_bn(b2, training=training)
        b2 = self.relu(b2)

        # Branch 3.
        b3 = self.branch3_reduce(x)
        b3 = self.branch3_reduce_bn(b3, training=training)
        b3 = self.relu(b3)
        b3 = self.branch3_dw(b3)
        b3 = self.branch3_pw(b3)
        b3 = self.branch3_bn(b3, training=training)
        b3 = self.relu(b3)

        # Branch 4.
        b4 = self.branch4_pool(x)
        b4 = self.branch4_conv(b4)
        b4 = self.branch4_bn(b4, training=training)
        b4 = self.relu(b4)

        # Scale each branch by its gate scalar (broadcast over H, W, C).
        g1 = tf.reshape(gate[:, 0], [-1, 1, 1, 1])
        g2 = tf.reshape(gate[:, 1], [-1, 1, 1, 1])
        g3 = tf.reshape(gate[:, 2], [-1, 1, 1, 1])
        g4 = tf.reshape(gate[:, 3], [-1, 1, 1, 1])

        b1 = b1 * g1
        b2 = b2 * g2
        b3 = b3 * g3
        b4 = b4 * g4

        # Channel-wise concatenation of all gated branches.
        return tf.concat([b1, b2, b3, b4], axis=-1)

    def get_config(self):
        # All values were stored in __init__, so no introspection hacks.
        config = super().get_config()
        config.update({
            'filters_1x1': self.filters_1x1,
            'filters_3x3_reduce': self.filters_3x3_reduce,
            'filters_3x3': self.filters_3x3,
            'filters_5x5_reduce': self.filters_5x5_reduce,
            'filters_5x5': self.filters_5x5,
            'filters_pool': self.filters_pool,
        })
        return config
|
|
|
|
class CoordinateAttention(layers.Layer):
    """
    Coordinate Attention module (Hou et al., CVPR 2021).

    Pools the feature map separately along height and width, encodes the two
    pooled strips jointly through a shared 1x1 bottleneck, then emits per-row
    and per-column sigmoid attention maps that rescale the input. Preserves
    precise positional information, which matters for gaze estimation where
    the spatial position of the iris carries the signal.

    Fixes over the earlier revision:
    - Pooling is done with direct tf.reduce_mean calls instead of Lambda
      layers (Lambda layers complicate serialization and TFLite export).
    - get_config() is implemented so models containing this layer can be
      saved and reloaded.
    """

    def __init__(self, reduction_ratio=4, **kwargs):
        super().__init__(**kwargs)
        # Channel-reduction factor for the shared bottleneck conv.
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        channels = input_shape[-1]
        # Bottleneck width, floored at 8 channels so tiny maps keep capacity.
        reduced_channels = max(channels // self.reduction_ratio, 8)

        self.conv_reduce = layers.Conv2D(reduced_channels, 1, use_bias=False)
        self.bn = layers.BatchNormalization()
        self.relu = layers.ReLU()

        # Per-direction attention heads (sigmoid -> values in [0, 1]).
        self.conv_h = layers.Conv2D(channels, 1, activation='sigmoid')
        self.conv_w = layers.Conv2D(channels, 1, activation='sigmoid')

        super().build(input_shape)

    def call(self, x, training=False):
        # Directional pooling: average over width -> (B, H, 1, C),
        # average over height -> (B, 1, W, C).
        h_att = tf.reduce_mean(x, axis=2, keepdims=True)
        w_att = tf.reduce_mean(x, axis=1, keepdims=True)

        # Rotate the width strip so both strips share axis-1 layout and can
        # be processed by a single shared conv.
        w_att_t = tf.transpose(w_att, perm=[0, 2, 1, 3])

        # Joint encoding through the shared bottleneck.
        combined = tf.concat([h_att, w_att_t], axis=1)
        combined = self.conv_reduce(combined)
        combined = self.bn(combined, training=training)
        combined = self.relu(combined)

        # Split back into the height part and the width part.
        h_len = tf.shape(h_att)[1]
        h_out = combined[:, :h_len, :, :]
        w_out = combined[:, h_len:, :, :]

        # Per-direction attention maps; rotate the width map back to
        # (B, 1, W, C) before broadcasting.
        h_att_map = self.conv_h(h_out)
        w_att_map = self.conv_w(w_out)
        w_att_map = tf.transpose(w_att_map, perm=[0, 2, 1, 3])

        # Broadcast-multiply: each position scaled by its row and column gates.
        return x * h_att_map * w_att_map

    def get_config(self):
        config = super().get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config
|
|
|
|
def build_gaze_inception_lite(input_shape=(64, 64, 3), num_outputs=2):
    """
    Construct the single-input GazeInception-Lite model.

    Pipeline:
        eye image -> conv stem -> gated inception x2 -> coordinate attention
        -> gated inception -> global average pool -> dense head -> sigmoid

    Args:
        input_shape: shape of the eye-crop tensor; defaults to (64, 64, 3).
        num_outputs: size of the regression head; defaults to 2 for (x, y).

    Returns:
        A keras Model named 'GazeInceptionLite' whose outputs lie in [0, 1]
        per coordinate (sigmoid head). Roughly ~350K parameters.
    """
    inputs = layers.Input(shape=input_shape, name='eye_image')

    # Stem: one strided conv to halve resolution, one to refine.
    net = inputs
    for stride in (2, 1):
        net = layers.Conv2D(32, 3, strides=stride, padding='same', use_bias=False)(net)
        net = layers.BatchNormalization()(net)
        net = layers.ReLU()(net)

    # Stage 1.
    net = GatedInceptionBlock(
        filters_1x1=16,
        filters_3x3_reduce=16, filters_3x3=24,
        filters_5x5_reduce=8, filters_5x5=12,
        filters_pool=12,
        name='gated_inception_1',
    )(net)
    net = layers.MaxPooling2D(2)(net)

    # Stage 2.
    net = GatedInceptionBlock(
        filters_1x1=32,
        filters_3x3_reduce=24, filters_3x3=48,
        filters_5x5_reduce=12, filters_5x5=24,
        filters_pool=24,
        name='gated_inception_2',
    )(net)
    net = layers.MaxPooling2D(2)(net)

    # Spatial attention before the deepest inception stage.
    net = CoordinateAttention(reduction_ratio=4, name='coord_attention')(net)

    # Stage 3.
    net = GatedInceptionBlock(
        filters_1x1=48,
        filters_3x3_reduce=32, filters_3x3=64,
        filters_5x5_reduce=16, filters_5x5=32,
        filters_pool=32,
        name='gated_inception_3',
    )(net)
    net = layers.MaxPooling2D(2)(net)

    net = layers.GlobalAveragePooling2D()(net)

    # Regression head: two ReLU layers with dropout regularization.
    for units, drop_rate in ((128, 0.3), (64, 0.2)):
        net = layers.Dense(units, activation='relu')(net)
        net = layers.Dropout(drop_rate)(net)

    outputs = layers.Dense(num_outputs, activation='sigmoid', name='gaze_coords')(net)

    return Model(inputs=inputs, outputs=outputs, name='GazeInceptionLite')
|
|
|
|
def build_dual_eye_model(eye_shape=(64, 64, 3), face_shape=(64, 64, 3), num_outputs=2):
    """
    Full model with dual eye inputs + face context.

    This handles lazy eye by processing each eye independently through a
    shared-weight gated inception backbone, then combining with face
    features. Each eye gets its own gaze features, and the model learns to
    handle asymmetric eye conditions (strabismus/amblyopia).

    Inputs:
        - left_eye: eye_shape crop (default 64x64x3)
        - right_eye: eye_shape crop (default 64x64x3)
        - face: face_shape crop (default 64x64x3, provides head pose context)

    Args:
        eye_shape: input shape of each eye crop.
        face_shape: input shape of the face crop.
        num_outputs: size of the regression head (2 for normalized (x, y)).

    Returns:
        A keras Model mapping [left_eye, right_eye, face] -> gaze_coords.

    Raises:
        ValueError: if the eye backbone unexpectedly contains no
            GlobalAveragePooling2D layer to tap features from.
    """
    left_eye_input = layers.Input(shape=eye_shape, name='left_eye')
    right_eye_input = layers.Input(shape=eye_shape, name='right_eye')
    face_input = layers.Input(shape=face_shape, name='face')

    # Shared-weight eye backbone; only the trunk up to the global pool is
    # reused below — the single-eye dense head is discarded.
    eye_backbone = build_gaze_inception_lite(input_shape=eye_shape, num_outputs=2)

    # Locate the GAP layer to tap pooled features. Fail loudly with a clear
    # message instead of an opaque AttributeError on None further down.
    gap_layer = next(
        (layer for layer in eye_backbone.layers
         if isinstance(layer, layers.GlobalAveragePooling2D)),
        None,
    )
    if gap_layer is None:
        raise ValueError(
            'eye backbone contains no GlobalAveragePooling2D layer; '
            'cannot build the eye feature extractor'
        )
    eye_feature_extractor = Model(
        inputs=eye_backbone.input,
        outputs=gap_layer.output,
        name='eye_feature_extractor'
    )

    # Same extractor (shared weights) applied to both eyes.
    left_features = eye_feature_extractor(left_eye_input)
    right_features = eye_feature_extractor(right_eye_input)

    # Small face branch for head-pose context.
    f = layers.Conv2D(16, 3, strides=2, padding='same', activation='relu')(face_input)
    f = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(f)
    f = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(f)
    f = layers.GlobalAveragePooling2D()(f)
    face_features = layers.Dense(64, activation='relu')(f)

    # Fuse both eye feature vectors with the face context.
    combined = layers.Concatenate()([left_features, right_features, face_features])

    # Regression head with dropout; sigmoid keeps coordinates in [0, 1].
    x = layers.Dense(128, activation='relu')(combined)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_outputs, activation='sigmoid', name='gaze_coords')(x)

    model = Model(
        inputs=[left_eye_input, right_eye_input, face_input],
        outputs=outputs,
        name='GazeInceptionLite_DualEye'
    )
    return model
|
|
|
|
if __name__ == '__main__':
    # Smoke-test the single-eye model: build, summarize, run a tiny batch.
    model_single = build_gaze_inception_lite()
    model_single.summary()
    print(f"\nSingle eye model params: {model_single.count_params():,}")

    batch = np.random.rand(2, 64, 64, 3).astype(np.float32)
    output = model_single(batch)
    print(f"Output shape: {output.shape}")
    print(f"Output values: {output.numpy()}")

    print("\n" + "=" * 60)

    # Smoke-test the dual-eye + face model the same way.
    model_dual = build_dual_eye_model()
    model_dual.summary()
    print(f"\nDual eye model params: {model_dual.count_params():,}")

    left_batch, right_batch, face_batch = (
        np.random.rand(2, 64, 64, 3).astype(np.float32) for _ in range(3)
    )
    output = model_dual([left_batch, right_batch, face_batch])
    print(f"Output shape: {output.shape}")
    print(f"Output values: {output.numpy()}")
|
|