dagloop5 committed on
Commit
7e969f2
·
verified ·
1 Parent(s): fbda38d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -69
app.py CHANGED
@@ -178,83 +178,86 @@ print("=" * 80)
178
  # =============================================================================
179
  # ZeroGPU Tensor Preloading
180
  # =============================================================================
181
- print("Preloading all models for ZeroGPU tensor packing...")
 
 
 
 
 
182
  print("This may take a few minutes...")
183
 
184
- # TI2VidTwoStagesHQPipeline uses:
185
- # - Builder methods that return models directly when called
186
- # - Context methods that return context managers when called (require streaming_prefetch_count)
187
- # We need to call these methods, capture the results, and preserve them
188
-
189
- # 1. Load transformer via _transformer_ctx(streaming_prefetch_count=None)
190
- print(" Loading stage 1 transformer...")
191
- _ctx = pipeline.stage_1._transformer_ctx(streaming_prefetch_count=None)
192
- _ctx.__enter__()
193
- _stage_1_transformer = _ctx.__dict__.get('transformer') or \
194
- getattr(pipeline.stage_1, '_transformer', None)
195
- pipeline.stage_1._transformer_ctx = lambda streaming_prefetch_count=None: _ctx
196
- print(f" Captured stage 1 transformer: {type(_stage_1_transformer)}")
197
-
198
- print(" Loading stage 2 transformer...")
199
- _ctx = pipeline.stage_2._transformer_ctx(streaming_prefetch_count=None)
200
- _ctx.__enter__()
201
- _stage_2_transformer = _ctx.__dict__.get('transformer') or \
202
- getattr(pipeline.stage_2, '_transformer', None)
203
- pipeline.stage_2._transformer_ctx = lambda streaming_prefetch_count=None: _ctx
204
- print(f" Captured stage 2 transformer: {type(_stage_2_transformer)}")
205
-
206
- # 2. Load text encoder via _text_encoder_ctx(streaming_prefetch_count=None)
207
- print(" Loading Gemma text encoder...")
208
- _ctx = pipeline.prompt_encoder._text_encoder_ctx(streaming_prefetch_count=None)
209
- _ctx.__enter__()
210
- _text_encoder = _ctx.__dict__.get('text_encoder') or \
211
- getattr(pipeline.prompt_encoder, '_text_encoder', None)
212
- pipeline.prompt_encoder._text_encoder = _text_encoder
213
- pipeline.prompt_encoder._text_encoder_ctx = lambda streaming_prefetch_count=None: _ctx
214
- print(f" Captured text encoder: {type(_text_encoder)}")
215
-
216
- # 3. Load video encoder (builder method - returns model directly)
217
  print(" Loading video encoder...")
218
- _video_encoder = pipeline.prompt_encoder.video_encoder()
219
- pipeline.prompt_encoder.video_encoder = lambda: _video_encoder
220
- print(f" Captured video encoder: {type(_video_encoder)}")
221
-
222
- # 4. Load video decoder (builder method)
 
 
 
223
  print(" Loading video decoder...")
224
- _video_decoder = pipeline.video_decoder._decoder_builder()
225
- pipeline.video_decoder._decoder_builder = lambda: _video_decoder
226
- if hasattr(pipeline.video_decoder, '_decoder'):
227
- pipeline.video_decoder._decoder = _video_decoder
228
- print(f" Captured video decoder: {type(_video_decoder)}")
229
-
230
- # 5. Load audio decoder (builder method)
 
 
 
231
  print(" Loading audio decoder...")
232
- _audio_decoder = pipeline.audio_decoder._decoder_builder()
233
- pipeline.audio_decoder._decoder_builder = lambda: _audio_decoder
234
- if hasattr(pipeline.audio_decoder, '_decoder'):
235
- pipeline.audio_decoder._decoder = _audio_decoder
236
- print(f" Captured audio decoder: {type(_audio_decoder)}")
237
-
238
- # 6. Load vocoder (builder method)
 
 
 
239
  print(" Loading vocoder...")
240
- if hasattr(pipeline.audio_decoder, '_vocoder_builder'):
241
- _vocoder = pipeline.audio_decoder._vocoder_builder()
242
- pipeline.audio_decoder._vocoder_builder = lambda: _vocoder
243
- print(f" Captured vocoder: {type(_vocoder)}")
244
-
245
- # 7. Load spatial upsampler (builder method)
 
 
 
246
  print(" Loading spatial upsampler...")
247
- _spatial_upsampler = pipeline.upsampler._upsampler_builder()
248
- pipeline.upsampler._upsampler_builder = lambda: _spatial_upsampler
249
- if hasattr(pipeline.upsampler, '_encoder'):
250
- pipeline.upsampler._encoder = _spatial_upsampler
251
- print(f" Captured spatial upsampler: {type(_spatial_upsampler)}")
252
-
253
- # 8. Load image conditioner
 
 
 
254
  print(" Loading image conditioner...")
255
- if hasattr(pipeline, 'image_conditioner'):
256
- if hasattr(pipeline.image_conditioner, 'video_encoder'):
257
- _ic_encoder = pipeline.image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  # =============================================================================
260
  # Helper Functions
 
178
  # =============================================================================
179
  # ZeroGPU Tensor Preloading
180
  # =============================================================================
181
+ # NOTE: At Space startup, no GPU is available (ZeroGPU assigns it at runtime).
182
+ # We can only preload components that don't require CUDA.
183
+ # The transformer (and other GPU-heavy components) will load during generation
184
+ # when ZeroGPU provides the GPU. ZeroGPU should capture them then.
185
+
186
+ print("Preloading non-CUDA components for ZeroGPU tensor packing...")
187
  print("This may take a few minutes...")
188
 
189
+ # 1. Try loading video encoder (may work without GPU if just file loading)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  print(" Loading video encoder...")
191
+ try:
192
+ _video_encoder = pipeline.prompt_encoder.video_encoder()
193
+ pipeline.prompt_encoder.video_encoder = lambda: _video_encoder
194
+ print(f" Loaded video encoder: {type(_video_encoder)}")
195
+ except Exception as e:
196
+ print(f" Video encoder preload skipped: {e}")
197
+
198
+ # 2. Try loading video decoder (VAE - may work without GPU)
199
  print(" Loading video decoder...")
200
+ try:
201
+ _video_decoder = pipeline.video_decoder._decoder_builder()
202
+ pipeline.video_decoder._decoder_builder = lambda: _video_decoder
203
+ if hasattr(pipeline.video_decoder, '_decoder'):
204
+ pipeline.video_decoder._decoder = _video_decoder
205
+ print(f" Loaded video decoder: {type(_video_decoder)}")
206
+ except Exception as e:
207
+ print(f" Video decoder preload skipped: {e}")
208
+
209
+ # 3. Try loading audio decoder (VAE - may work without GPU)
210
  print(" Loading audio decoder...")
211
+ try:
212
+ _audio_decoder = pipeline.audio_decoder._decoder_builder()
213
+ pipeline.audio_decoder._decoder_builder = lambda: _audio_decoder
214
+ if hasattr(pipeline.audio_decoder, '_decoder'):
215
+ pipeline.audio_decoder._decoder = _audio_decoder
216
+ print(f" Loaded audio decoder: {type(_audio_decoder)}")
217
+ except Exception as e:
218
+ print(f" Audio decoder preload skipped: {e}")
219
+
220
+ # 4. Try loading vocoder
221
  print(" Loading vocoder...")
222
+ try:
223
+ if hasattr(pipeline.audio_decoder, '_vocoder_builder'):
224
+ _vocoder = pipeline.audio_decoder._vocoder_builder()
225
+ pipeline.audio_decoder._vocoder_builder = lambda: _vocoder
226
+ print(f" Loaded vocoder: {type(_vocoder)}")
227
+ except Exception as e:
228
+ print(f" Vocoder preload skipped: {e}")
229
+
230
+ # 5. Try loading spatial upsampler
231
  print(" Loading spatial upsampler...")
232
+ try:
233
+ _spatial_upsampler = pipeline.upsampler._upsampler_builder()
234
+ pipeline.upsampler._upsampler_builder = lambda: _spatial_upsampler
235
+ if hasattr(pipeline.upsampler, '_encoder'):
236
+ pipeline.upsampler._encoder = _spatial_upsampler
237
+ print(f" Loaded spatial upsampler: {type(_spatial_upsampler)}")
238
+ except Exception as e:
239
+ print(f" Spatial upsampler preload skipped: {e}")
240
+
241
+ # 6. Load image conditioner
242
  print(" Loading image conditioner...")
243
+ try:
244
+ if hasattr(pipeline, 'image_conditioner'):
245
+ if hasattr(pipeline.image_conditioner, 'video_encoder'):
246
+ _ic_encoder = pipeline.image_conditioner.video_encoder()
247
+ pipeline.image_conditioner.video_encoder = lambda: _ic_encoder
248
+ print(f" Loaded image conditioner encoder")
249
+ except Exception as e:
250
+ print(f" Image conditioner preload skipped: {e}")
251
+
252
+ # 7. NOTE: Transformer loading is intentionally skipped here
253
+ # The transformer requires CUDA (LoRA fusion uses triton kernels)
254
+ # It will load during generate_video() when ZeroGPU provides a GPU
255
+ # ZeroGPU should capture it then
256
+ print(" Transformer: Will load during generation (requires GPU)")
257
+ print(" Text encoder: Will load during generation (requires GPU)")
258
+
259
+ print("Non-CUDA components preloaded!")
260
+ print("=" * 80)
261
 
262
  # =============================================================================
263
  # Helper Functions