dagloop5 committed on
Commit
fbda38d
·
verified ·
1 Parent(s): 080a03d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -18
app.py CHANGED
@@ -183,36 +183,34 @@ print("This may take a few minutes...")
183
 
184
  # TI2VidTwoStagesHQPipeline uses:
185
  # - Builder methods that return models directly when called
186
- # - Context methods that return context managers when called
187
  # We need to call these methods, capture the results, and preserve them
188
 
189
- # 1. Load transformer via _transformer_ctx() (call first, then enter)
190
  print(" Loading stage 1 transformer...")
191
- _ctx = pipeline.stage_1._transformer_ctx() # Get context manager
192
- _ctx.__enter__() # Enter context
193
  _stage_1_transformer = _ctx.__dict__.get('transformer') or \
194
  getattr(pipeline.stage_1, '_transformer', None)
195
- # Replace _transformer_ctx with a lambda that returns cached model
196
- pipeline.stage_1._transformer_ctx = lambda: _ctx
197
  print(f" Captured stage 1 transformer: {type(_stage_1_transformer)}")
198
 
199
  print(" Loading stage 2 transformer...")
200
- _ctx = pipeline.stage_2._transformer_ctx()
201
  _ctx.__enter__()
202
  _stage_2_transformer = _ctx.__dict__.get('transformer') or \
203
  getattr(pipeline.stage_2, '_transformer', None)
204
- pipeline.stage_2._transformer_ctx = lambda: _ctx
205
  print(f" Captured stage 2 transformer: {type(_stage_2_transformer)}")
206
 
207
- # 2. Load text encoder via _text_encoder_ctx() (call first, then enter)
208
  print(" Loading Gemma text encoder...")
209
- _ctx = pipeline.prompt_encoder._text_encoder_ctx()
210
  _ctx.__enter__()
211
  _text_encoder = _ctx.__dict__.get('text_encoder') or \
212
  getattr(pipeline.prompt_encoder, '_text_encoder', None)
213
- # Store as instance attribute and create replacement lambda
214
  pipeline.prompt_encoder._text_encoder = _text_encoder
215
- pipeline.prompt_encoder._text_encoder_ctx = lambda: _ctx
216
  print(f" Captured text encoder: {type(_text_encoder)}")
217
 
218
  # 3. Load video encoder (builder method - returns model directly)
@@ -256,12 +254,7 @@ print(f" Captured spatial upsampler: {type(_spatial_upsampler)}")
256
  print(" Loading image conditioner...")
257
  if hasattr(pipeline, 'image_conditioner'):
258
  if hasattr(pipeline.image_conditioner, 'video_encoder'):
259
- _ic_encoder = pipeline.image_conditioner.video_encoder()
260
- pipeline.image_conditioner.video_encoder = lambda: _ic_encoder
261
-
262
- print(" Models captured and preserved for ZeroGPU tensor packing...")
263
- print("All models preloaded for ZeroGPU tensor packing!")
264
- print("=" * 80)
265
 
266
  # =============================================================================
267
  # Helper Functions
 
183
 
184
  # TI2VidTwoStagesHQPipeline uses:
185
  # - Builder methods that return models directly when called
186
+ # - Context methods that return context managers when called (require streaming_prefetch_count)
187
  # We need to call these methods, capture the results, and preserve them
188
 
189
+ # 1. Load transformer via _transformer_ctx(streaming_prefetch_count=None)
190
  print(" Loading stage 1 transformer...")
191
+ _ctx = pipeline.stage_1._transformer_ctx(streaming_prefetch_count=None)
192
+ _ctx.__enter__()
193
  _stage_1_transformer = _ctx.__dict__.get('transformer') or \
194
  getattr(pipeline.stage_1, '_transformer', None)
195
+ pipeline.stage_1._transformer_ctx = lambda streaming_prefetch_count=None: _ctx
 
196
  print(f" Captured stage 1 transformer: {type(_stage_1_transformer)}")
197
 
198
  print(" Loading stage 2 transformer...")
199
+ _ctx = pipeline.stage_2._transformer_ctx(streaming_prefetch_count=None)
200
  _ctx.__enter__()
201
  _stage_2_transformer = _ctx.__dict__.get('transformer') or \
202
  getattr(pipeline.stage_2, '_transformer', None)
203
+ pipeline.stage_2._transformer_ctx = lambda streaming_prefetch_count=None: _ctx
204
  print(f" Captured stage 2 transformer: {type(_stage_2_transformer)}")
205
 
206
+ # 2. Load text encoder via _text_encoder_ctx(streaming_prefetch_count=None)
207
  print(" Loading Gemma text encoder...")
208
+ _ctx = pipeline.prompt_encoder._text_encoder_ctx(streaming_prefetch_count=None)
209
  _ctx.__enter__()
210
  _text_encoder = _ctx.__dict__.get('text_encoder') or \
211
  getattr(pipeline.prompt_encoder, '_text_encoder', None)
 
212
  pipeline.prompt_encoder._text_encoder = _text_encoder
213
+ pipeline.prompt_encoder._text_encoder_ctx = lambda streaming_prefetch_count=None: _ctx
214
  print(f" Captured text encoder: {type(_text_encoder)}")
215
 
216
  # 3. Load video encoder (builder method - returns model directly)
 
254
  print(" Loading image conditioner...")
255
  if hasattr(pipeline, 'image_conditioner'):
256
  if hasattr(pipeline.image_conditioner, 'video_encoder'):
257
+ _ic_encoder = pipeline.image_conditioner.video_encoder()
 
 
 
 
 
258
 
259
  # =============================================================================
260
  # Helper Functions