dagloop5 committed on
Commit
080a03d
·
verified ·
1 Parent(s): e2fe30b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -39
app.py CHANGED
@@ -181,65 +181,55 @@ print("=" * 80)
181
  print("Preloading all models for ZeroGPU tensor packing...")
182
  print("This may take a few minutes...")
183
 
184
- # The TI2VidTwoStagesHQPipeline uses context managers for lazy loading.
185
- # We need to enter the contexts, capture the loaded models, AND preserve them
186
- # by replacing the pipeline's internal references with lambdas that hold them.
187
- # This prevents garbage collection and allows ZeroGPU to pack them.
188
 
189
- # 1. Load transformer via _transformer_ctx (enter context to load, store result)
190
  print(" Loading stage 1 transformer...")
191
- pipeline.stage_1._transformer_ctx.__enter__()
192
- # Capture the actual model from the context
193
- _stage_1_transformer = pipeline.stage_1._transformer_ctx.__dict__.get('transformer') or \
194
  getattr(pipeline.stage_1, '_transformer', None)
195
- # Replace _transformer_ctx with lambda that returns the captured model
196
- pipeline.stage_1._transformer_ctx = type('ctx', (), {
197
- '__enter__': lambda s: _stage_1_transformer,
198
- '__exit__': lambda s, *a: None,
199
- '__call__': lambda s, *a, **kw: _stage_1_transformer(*a, **kw)
200
- })()
201
  print(f" Captured stage 1 transformer: {type(_stage_1_transformer)}")
202
 
203
  print(" Loading stage 2 transformer...")
204
- pipeline.stage_2._transformer_ctx.__enter__()
205
- _stage_2_transformer = pipeline.stage_2._transformer_ctx.__dict__.get('transformer') or \
 
206
  getattr(pipeline.stage_2, '_transformer', None)
207
- pipeline.stage_2._transformer_ctx = type('ctx', (), {
208
- '__enter__': lambda s: _stage_2_transformer,
209
- '__exit__': lambda s, *a: None,
210
- '__call__': lambda s, *a, **kw: _stage_2_transformer(*a, **kw)
211
- })()
212
  print(f" Captured stage 2 transformer: {type(_stage_2_transformer)}")
213
 
214
- # 2. Load text encoder via _text_encoder_ctx
215
  print(" Loading Gemma text encoder...")
216
- pipeline.prompt_encoder._text_encoder_ctx.__enter__()
217
- _text_encoder = pipeline.prompt_encoder._text_encoder_ctx.__dict__.get('text_encoder') or \
 
218
  getattr(pipeline.prompt_encoder, '_text_encoder', None)
219
- # Store as instance attribute and create replacement context
220
  pipeline.prompt_encoder._text_encoder = _text_encoder
221
- pipeline.prompt_encoder._text_encoder_ctx = type('ctx', (), {
222
- '__enter__': lambda s: _text_encoder,
223
- '__exit__': lambda s, *a: None
224
- })()
225
  print(f" Captured text encoder: {type(_text_encoder)}")
226
 
227
- # 3. Load video encoder (from prompt_encoder's video_encoder method)
228
  print(" Loading video encoder...")
229
  _video_encoder = pipeline.prompt_encoder.video_encoder()
230
  pipeline.prompt_encoder.video_encoder = lambda: _video_encoder
231
  print(f" Captured video encoder: {type(_video_encoder)}")
232
 
233
- # 4. Load video decoder via _decoder_builder
234
  print(" Loading video decoder...")
235
  _video_decoder = pipeline.video_decoder._decoder_builder()
236
  pipeline.video_decoder._decoder_builder = lambda: _video_decoder
237
- # Also try direct model attribute if exists
238
  if hasattr(pipeline.video_decoder, '_decoder'):
239
  pipeline.video_decoder._decoder = _video_decoder
240
  print(f" Captured video decoder: {type(_video_decoder)}")
241
 
242
- # 5. Load audio decoder via _decoder_builder
243
  print(" Loading audio decoder...")
244
  _audio_decoder = pipeline.audio_decoder._decoder_builder()
245
  pipeline.audio_decoder._decoder_builder = lambda: _audio_decoder
@@ -247,18 +237,17 @@ if hasattr(pipeline.audio_decoder, '_decoder'):
247
  pipeline.audio_decoder._decoder = _audio_decoder
248
  print(f" Captured audio decoder: {type(_audio_decoder)}")
249
 
250
- # 6. Load vocoder (audio decoder has _vocoder_builder)
251
  print(" Loading vocoder...")
252
  if hasattr(pipeline.audio_decoder, '_vocoder_builder'):
253
  _vocoder = pipeline.audio_decoder._vocoder_builder()
254
  pipeline.audio_decoder._vocoder_builder = lambda: _vocoder
255
  print(f" Captured vocoder: {type(_vocoder)}")
256
 
257
- # 7. Load spatial upsampler via _upsampler_builder
258
  print(" Loading spatial upsampler...")
259
  _spatial_upsampler = pipeline.upsampler._upsampler_builder()
260
  pipeline.upsampler._upsampler_builder = lambda: _spatial_upsampler
261
- # Also try _encoder_builder
262
  if hasattr(pipeline.upsampler, '_encoder'):
263
  pipeline.upsampler._encoder = _spatial_upsampler
264
  print(f" Captured spatial upsampler: {type(_spatial_upsampler)}")
@@ -270,10 +259,7 @@ if hasattr(pipeline, 'image_conditioner'):
270
  _ic_encoder = pipeline.image_conditioner.video_encoder()
271
  pipeline.image_conditioner.video_encoder = lambda: _ic_encoder
272
 
273
- # Create references to prevent garbage collection
274
- # At module level, variables are already global - no 'global' keyword needed
275
  print(" Models captured and preserved for ZeroGPU tensor packing...")
276
-
277
  print("All models preloaded for ZeroGPU tensor packing!")
278
  print("=" * 80)
279
 
 
181
  print("Preloading all models for ZeroGPU tensor packing...")
182
  print("This may take a few minutes...")
183
 
184
+ # TI2VidTwoStagesHQPipeline uses:
185
+ # - Builder methods that return models directly when called
186
+ # - Context methods that return context managers when called
187
+ # We need to call these methods, capture the results, and preserve them
188
 
189
+ # 1. Load transformer via _transformer_ctx() (call first, then enter)
190
  print(" Loading stage 1 transformer...")
191
+ _ctx = pipeline.stage_1._transformer_ctx() # Get context manager
192
+ _ctx.__enter__() # Enter context
193
+ _stage_1_transformer = _ctx.__dict__.get('transformer') or \
194
  getattr(pipeline.stage_1, '_transformer', None)
195
+ # Replace _transformer_ctx with a lambda that returns cached model
196
+ pipeline.stage_1._transformer_ctx = lambda: _ctx
 
 
 
 
197
  print(f" Captured stage 1 transformer: {type(_stage_1_transformer)}")
198
 
199
  print(" Loading stage 2 transformer...")
200
+ _ctx = pipeline.stage_2._transformer_ctx()
201
+ _ctx.__enter__()
202
+ _stage_2_transformer = _ctx.__dict__.get('transformer') or \
203
  getattr(pipeline.stage_2, '_transformer', None)
204
+ pipeline.stage_2._transformer_ctx = lambda: _ctx
 
 
 
 
205
  print(f" Captured stage 2 transformer: {type(_stage_2_transformer)}")
206
 
207
+ # 2. Load text encoder via _text_encoder_ctx() (call first, then enter)
208
  print(" Loading Gemma text encoder...")
209
+ _ctx = pipeline.prompt_encoder._text_encoder_ctx()
210
+ _ctx.__enter__()
211
+ _text_encoder = _ctx.__dict__.get('text_encoder') or \
212
  getattr(pipeline.prompt_encoder, '_text_encoder', None)
213
+ # Store as instance attribute and create replacement lambda
214
  pipeline.prompt_encoder._text_encoder = _text_encoder
215
+ pipeline.prompt_encoder._text_encoder_ctx = lambda: _ctx
 
 
 
216
  print(f" Captured text encoder: {type(_text_encoder)}")
217
 
218
+ # 3. Load video encoder (builder method - returns model directly)
219
  print(" Loading video encoder...")
220
  _video_encoder = pipeline.prompt_encoder.video_encoder()
221
  pipeline.prompt_encoder.video_encoder = lambda: _video_encoder
222
  print(f" Captured video encoder: {type(_video_encoder)}")
223
 
224
+ # 4. Load video decoder (builder method)
225
  print(" Loading video decoder...")
226
  _video_decoder = pipeline.video_decoder._decoder_builder()
227
  pipeline.video_decoder._decoder_builder = lambda: _video_decoder
 
228
  if hasattr(pipeline.video_decoder, '_decoder'):
229
  pipeline.video_decoder._decoder = _video_decoder
230
  print(f" Captured video decoder: {type(_video_decoder)}")
231
 
232
+ # 5. Load audio decoder (builder method)
233
  print(" Loading audio decoder...")
234
  _audio_decoder = pipeline.audio_decoder._decoder_builder()
235
  pipeline.audio_decoder._decoder_builder = lambda: _audio_decoder
 
237
  pipeline.audio_decoder._decoder = _audio_decoder
238
  print(f" Captured audio decoder: {type(_audio_decoder)}")
239
 
240
+ # 6. Load vocoder (builder method)
241
  print(" Loading vocoder...")
242
  if hasattr(pipeline.audio_decoder, '_vocoder_builder'):
243
  _vocoder = pipeline.audio_decoder._vocoder_builder()
244
  pipeline.audio_decoder._vocoder_builder = lambda: _vocoder
245
  print(f" Captured vocoder: {type(_vocoder)}")
246
 
247
+ # 7. Load spatial upsampler (builder method)
248
  print(" Loading spatial upsampler...")
249
  _spatial_upsampler = pipeline.upsampler._upsampler_builder()
250
  pipeline.upsampler._upsampler_builder = lambda: _spatial_upsampler
 
251
  if hasattr(pipeline.upsampler, '_encoder'):
252
  pipeline.upsampler._encoder = _spatial_upsampler
253
  print(f" Captured spatial upsampler: {type(_spatial_upsampler)}")
 
259
  _ic_encoder = pipeline.image_conditioner.video_encoder()
260
  pipeline.image_conditioner.video_encoder = lambda: _ic_encoder
261
 
 
 
262
  print(" Models captured and preserved for ZeroGPU tensor packing...")
 
263
  print("All models preloaded for ZeroGPU tensor packing!")
264
  print("=" * 80)
265