"""
GazeInception-Lite: Gated Inception Model for Mobile Eye Gaze Estimation

Architecture:
  - Input: 64x64 RGB eye crop (a two-eye variant can stack the eyes channel-wise,
    giving 6 channels for RGB, or place them side-by-side as a 128x64 image)
  - Gated Inception Blocks: Each inception block has a lightweight gate (squeeze-excitation style)
    that learns to down-weight branches that contribute little; consistently
    low-gate branches can then be pruned after training to save compute
  - Multi-scale feature extraction via inception (1x1, 3x3, 5x5 parallel convolutions)
  - Coordinate Attention for spatial awareness
  - Output: (x, y) screen coordinates normalized to [0, 1]

Design goals:
  - < 500K parameters for fast mobile inference
  - TFLite compatible (no unsupported ops)
  - Works in dark (trained with illumination augmentation)
  - Handles glasses (trained with glasses augmentation)
  - Handles lazy eye / strabismus (trained with per-eye asymmetric augmentation)
"""

import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np


class GatedInceptionBlock(layers.Layer):
    """
    Inception block with gating mechanism.
    
    The gate is a lightweight learned sigmoid that scales each inception branch.
    Branches with gate values near zero contribute almost nothing to the output.
    All four branches are still computed at inference, so real compute savings
    come from pruning consistently low-gate branches after training.
    
    Branches:
      1. 1x1 conv (point features)
      2. 1x1 -> 3x3 depthwise separable conv (local features)
      3. 1x1 -> 5x5 depthwise separable conv (wider context)
      4. 3x3 max pool -> 1x1 conv (pooled features)
    
    Gate: Global Average Pool -> Dense (ReLU) -> Dense (Sigmoid), one value per branch
    """
    
    def __init__(self, filters_1x1, filters_3x3_reduce, filters_3x3,
                 filters_5x5_reduce, filters_5x5, filters_pool, **kwargs):
        super().__init__(**kwargs)
        self.filters_1x1 = filters_1x1
        self.filters_3x3_reduce = filters_3x3_reduce
        self.filters_3x3 = filters_3x3
        self.filters_5x5_reduce = filters_5x5_reduce
        self.filters_5x5 = filters_5x5
        self.filters_pool = filters_pool
        self.num_branches = 4
        
        # Branch 1: 1x1
        self.branch1_conv = layers.Conv2D(filters_1x1, 1, padding='same', use_bias=False)
        self.branch1_bn = layers.BatchNormalization()
        
        # Branch 2: 1x1 -> 3x3 depthwise separable
        self.branch2_reduce = layers.Conv2D(filters_3x3_reduce, 1, padding='same', use_bias=False)
        self.branch2_reduce_bn = layers.BatchNormalization()
        self.branch2_conv = layers.DepthwiseConv2D(3, padding='same', use_bias=False)
        self.branch2_pw = layers.Conv2D(filters_3x3, 1, padding='same', use_bias=False)
        self.branch2_bn = layers.BatchNormalization()
        
        # Branch 3: 1x1 -> 5x5 depthwise separable
        self.branch3_reduce = layers.Conv2D(filters_5x5_reduce, 1, padding='same', use_bias=False)
        self.branch3_reduce_bn = layers.BatchNormalization()
        self.branch3_dw = layers.DepthwiseConv2D(5, padding='same', use_bias=False)
        self.branch3_pw = layers.Conv2D(filters_5x5, 1, padding='same', use_bias=False)
        self.branch3_bn = layers.BatchNormalization()
        
        # Branch 4: MaxPool -> 1x1
        self.branch4_pool = layers.MaxPooling2D(3, strides=1, padding='same')
        self.branch4_conv = layers.Conv2D(filters_pool, 1, padding='same', use_bias=False)
        self.branch4_bn = layers.BatchNormalization()
        
        # Gating mechanism: learns to weight each branch
        self.gate_pool = layers.GlobalAveragePooling2D()
        self.gate_dense1 = layers.Dense(self.num_branches * 4, activation='relu')
        self.gate_dense2 = layers.Dense(self.num_branches, activation='sigmoid')
        
        # Final activation
        self.relu = layers.ReLU()
    
    def call(self, x, training=False):
        # Compute per-branch gate values (soft weights in [0, 1])
        gate_input = self.gate_pool(x)
        gate = self.gate_dense1(gate_input)
        gate = self.gate_dense2(gate)  # [batch, 4] sigmoid values
        
        # Branch 1
        b1 = self.branch1_conv(x)
        b1 = self.branch1_bn(b1, training=training)
        b1 = self.relu(b1)
        
        # Branch 2
        b2 = self.branch2_reduce(x)
        b2 = self.branch2_reduce_bn(b2, training=training)
        b2 = self.relu(b2)
        b2 = self.branch2_conv(b2)
        b2 = self.branch2_pw(b2)
        b2 = self.branch2_bn(b2, training=training)
        b2 = self.relu(b2)
        
        # Branch 3
        b3 = self.branch3_reduce(x)
        b3 = self.branch3_reduce_bn(b3, training=training)
        b3 = self.relu(b3)
        b3 = self.branch3_dw(b3)
        b3 = self.branch3_pw(b3)
        b3 = self.branch3_bn(b3, training=training)
        b3 = self.relu(b3)
        
        # Branch 4
        b4 = self.branch4_pool(x)
        b4 = self.branch4_conv(b4)
        b4 = self.branch4_bn(b4, training=training)
        b4 = self.relu(b4)
        
        # Apply gates: multiply each branch by its gate scalar
        # gate[:, i] is a scalar per sample - reshape for broadcasting
        g1 = tf.reshape(gate[:, 0], [-1, 1, 1, 1])
        g2 = tf.reshape(gate[:, 1], [-1, 1, 1, 1])
        g3 = tf.reshape(gate[:, 2], [-1, 1, 1, 1])
        g4 = tf.reshape(gate[:, 3], [-1, 1, 1, 1])
        
        b1 = b1 * g1
        b2 = b2 * g2
        b3 = b3 * g3
        b4 = b4 * g4
        
        # Concatenate gated branches
        return tf.concat([b1, b2, b3, b4], axis=-1)
    
    def get_config(self):
        config = super().get_config()
        config.update({
            'filters_1x1': self.filters_1x1,
            'filters_3x3_reduce': self.filters_3x3_reduce,
            'filters_3x3': self.filters_3x3,
            'filters_5x5_reduce': self.filters_5x5_reduce,
            'filters_5x5': self.filters_5x5,
            'filters_pool': self.filters_pool,
        })
        return config
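

# Illustrative helper (a sketch, not part of the original pipeline): after
# training, one way to check whether the gates actually learned to suppress
# branches is to read their sigmoid outputs on a batch of real eye crops.
# Branches whose mean gate stays near zero are candidates for manual pruning
# before mobile export.
def mean_gate_values(model, block_name, images):
    """Mean sigmoid gate value per branch for one GatedInceptionBlock.

    `model` is a functional model built below (e.g. by build_gaze_inception_lite),
    `block_name` the name of one of its gated blocks, and `images` a float32
    batch matching the model's input shape.
    """
    block = model.get_layer(block_name)
    # Sub-model that yields the feature map entering the gated block
    feature_model = Model(model.input, block.input)
    features = feature_model(images, training=False)
    # Re-run just the gate path on those features
    gates = block.gate_dense2(block.gate_dense1(block.gate_pool(features)))
    return tf.reduce_mean(gates, axis=0)  # shape: [num_branches]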


class CoordinateAttention(layers.Layer):
    """
    Coordinate Attention module (Hou et al. 2021).
    Encodes spatial position info into channel attention for better localization.
    Critical for gaze estimation where spatial position of iris matters.
    """
    
    def __init__(self, reduction_ratio=4, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio
    
    def build(self, input_shape):
        channels = input_shape[-1]
        reduced_channels = max(channels // self.reduction_ratio, 8)
        
        self.pool_h = layers.Lambda(lambda x: tf.reduce_mean(x, axis=2, keepdims=True))
        self.pool_w = layers.Lambda(lambda x: tf.reduce_mean(x, axis=1, keepdims=True))
        
        self.conv_reduce = layers.Conv2D(reduced_channels, 1, use_bias=False)
        self.bn = layers.BatchNormalization()
        self.relu = layers.ReLU()
        
        self.conv_h = layers.Conv2D(channels, 1, activation='sigmoid')
        self.conv_w = layers.Conv2D(channels, 1, activation='sigmoid')
        
        super().build(input_shape)
    
    def call(self, x, training=False):
        # Pool along width (keep height)
        h_att = self.pool_h(x)  # [B, H, 1, C]
        # Pool along height (keep width)
        w_att = self.pool_w(x)  # [B, 1, W, C]
        
        # Transpose w_att to match h_att shape for concatenation
        w_att_t = tf.transpose(w_att, perm=[0, 2, 1, 3])  # [B, W, 1, C]
        
        # Concatenate and reduce
        combined = tf.concat([h_att, w_att_t], axis=1)  # [B, H+W, 1, C]
        combined = self.conv_reduce(combined)
        combined = self.bn(combined, training=training)
        combined = self.relu(combined)
        
        # Split back
        h_len = tf.shape(h_att)[1]
        w_len = tf.shape(w_att_t)[1]
        
        h_out = combined[:, :h_len, :, :]
        w_out = combined[:, h_len:, :, :]
        
        # Generate attention maps
        h_att_map = self.conv_h(h_out)  # [B, H, 1, C]
        w_att_map = self.conv_w(w_out)  # [B, W, 1, C]
        w_att_map = tf.transpose(w_att_map, perm=[0, 2, 1, 3])  # [B, 1, W, C]
        
        # Apply attention along both axes
        return x * h_att_map * w_att_map
    
    def get_config(self):
        # Serialization support, mirroring GatedInceptionBlock.get_config
        config = super().get_config()
        config.update({'reduction_ratio': self.reduction_ratio})
        return config
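

# Quick usage sketch: CoordinateAttention preserves its input shape. The
# 8x8x128 shape below matches where the module sits in build_gaze_inception_lite.
#   att = CoordinateAttention(reduction_ratio=4)
#   att(tf.zeros([1, 8, 8, 128])).shape  # -> TensorShape([1, 8, 8, 128])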


def build_gaze_inception_lite(input_shape=(64, 64, 3), num_outputs=2):
    """
    Build the GazeInception-Lite model.
    
    Architecture:
      Input (64x64x3) -> Stem -> GatedInception1 -> GatedInception2 -> 
      CoordAttention -> GatedInception3 -> GlobalPool -> Dense -> (x, y)
    
    Total: ~350K parameters
    """
    inputs = layers.Input(shape=input_shape, name='eye_image')
    
    # Stem: lightweight feature extraction
    x = layers.Conv2D(32, 3, strides=2, padding='same', use_bias=False)(inputs)  # 32x32
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(32, 3, padding='same', use_bias=False)(x)  # 32x32
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    
    # Gated Inception Block 1 (32x32 -> 16x16)
    x = GatedInceptionBlock(
        filters_1x1=16, 
        filters_3x3_reduce=16, filters_3x3=24,
        filters_5x5_reduce=8, filters_5x5=12,
        filters_pool=12,
        name='gated_inception_1'
    )(x)  # output: 64 channels
    x = layers.MaxPooling2D(2)(x)  # 16x16
    
    # Gated Inception Block 2 (16x16 -> 8x8)
    x = GatedInceptionBlock(
        filters_1x1=32,
        filters_3x3_reduce=24, filters_3x3=48,
        filters_5x5_reduce=12, filters_5x5=24,
        filters_pool=24,
        name='gated_inception_2'
    )(x)  # output: 128 channels
    x = layers.MaxPooling2D(2)(x)  # 8x8
    
    # Coordinate Attention - encodes spatial position for gaze direction
    x = CoordinateAttention(reduction_ratio=4, name='coord_attention')(x)
    
    # Gated Inception Block 3 (8x8 -> 4x4)
    x = GatedInceptionBlock(
        filters_1x1=48,
        filters_3x3_reduce=32, filters_3x3=64,
        filters_5x5_reduce=16, filters_5x5=32,
        filters_pool=32,
        name='gated_inception_3'
    )(x)  # output: 176 channels
    x = layers.MaxPooling2D(2)(x)  # 4x4
    
    # Global feature aggregation
    x = layers.GlobalAveragePooling2D(name='global_pool')(x)
    
    # Regression head
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    
    # Output: (x, y) screen coordinates in [0, 1]
    outputs = layers.Dense(num_outputs, activation='sigmoid', name='gaze_coords')(x)
    
    model = Model(inputs=inputs, outputs=outputs, name='GazeInceptionLite')
    return model
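

# Hedged usage sketch: this file does not include a training loop, so the Adam
# optimizer, MSE loss, and learning-rate default below are illustrative
# assumptions for the [0, 1] coordinate-regression target, not fixed choices.
def compile_for_training(model, learning_rate=1e-3):
    """Compile a gaze model for screen-coordinate regression (illustrative)."""
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate),
        loss='mse',  # squared error on normalized (x, y) screen coordinates
        metrics=['mae'],  # mean absolute error in normalized screen units
    )
    return model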


def build_dual_eye_model(eye_shape=(64, 64, 3), face_shape=(64, 64, 3), num_outputs=2):
    """
    Full model with dual eye inputs + face context.
    
    This handles lazy eye by processing each eye independently through
    shared-weight gated inception, then combining with face features.
    Each eye gets its own gaze features, and the model learns to handle
    asymmetric eye conditions (strabismus/amblyopia).
    
    Inputs:
      - left_eye: 64x64x3 crop
      - right_eye: 64x64x3 crop  
      - face: 64x64x3 crop (provides head pose context)
    
    Output:
      - (x, y) normalized screen coordinates
    """
    left_eye_input = layers.Input(shape=eye_shape, name='left_eye')
    right_eye_input = layers.Input(shape=eye_shape, name='right_eye')
    face_input = layers.Input(shape=face_shape, name='face')
    
    # Shared eye feature extractor (gated inception backbone)
    eye_backbone = build_gaze_inception_lite(input_shape=eye_shape, num_outputs=2)
    # Reuse the backbone up to its pooled features ('global_pool' in
    # build_gaze_inception_lite), dropping the dense regression head
    eye_feature_extractor = Model(
        inputs=eye_backbone.input,
        outputs=eye_backbone.get_layer('global_pool').output,
        name='eye_feature_extractor'
    )
    
    # Extract features for each eye independently (shared weights)
    left_features = eye_feature_extractor(left_eye_input)   # [B, 176]
    right_features = eye_feature_extractor(right_eye_input)  # [B, 176]
    
    # Lightweight face context extractor (head pose proxy)
    f = layers.Conv2D(16, 3, strides=2, padding='same', activation='relu')(face_input)
    f = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(f)
    f = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(f)
    f = layers.GlobalAveragePooling2D()(f)
    face_features = layers.Dense(64, activation='relu')(f)  # [B, 64]
    
    # Combine: left_eye + right_eye + face
    # The model learns eye asymmetry (lazy eye) because eyes are separate inputs
    combined = layers.Concatenate()([left_features, right_features, face_features])
    
    # Fusion head
    x = layers.Dense(128, activation='relu')(combined)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_outputs, activation='sigmoid', name='gaze_coords')(x)
    
    model = Model(
        inputs=[left_eye_input, right_eye_input, face_input],
        outputs=outputs,
        name='GazeInceptionLite_DualEye'
    )
    return model
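

# Hedged sketch of the TFLite export implied by the design goals above. This is
# the standard tf.lite.TFLiteConverter flow; dynamic-range quantization is one
# common mobile choice, not a requirement of the architecture.
def convert_to_tflite(model, output_path='gaze_model.tflite'):
    """Convert a built Keras gaze model to a TFLite flatbuffer file."""
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Dynamic-range quantization: int8 weights, float activations
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_bytes = converter.convert()
    with open(output_path, 'wb') as f:
        f.write(tflite_bytes)
    return output_path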


if __name__ == '__main__':
    # Test single eye model
    model_single = build_gaze_inception_lite()
    model_single.summary()
    print(f"\nSingle eye model params: {model_single.count_params():,}")
    
    # Test with random input
    test_input = np.random.rand(2, 64, 64, 3).astype(np.float32)
    output = model_single(test_input)
    print(f"Output shape: {output.shape}")
    print(f"Output values: {output.numpy()}")
    
    print("\n" + "="*60)
    
    # Test dual eye model  
    model_dual = build_dual_eye_model()
    model_dual.summary()
    print(f"\nDual eye model params: {model_dual.count_params():,}")
    
    test_left = np.random.rand(2, 64, 64, 3).astype(np.float32)
    test_right = np.random.rand(2, 64, 64, 3).astype(np.float32)
    test_face = np.random.rand(2, 64, 64, 3).astype(np.float32)
    output = model_dual([test_left, test_right, test_face])
    print(f"Output shape: {output.shape}")
    print(f"Output values: {output.numpy()}")