dagloop5 committed on
Commit
e2569a7
·
verified ·
1 Parent(s): f638ada

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -2,10 +2,6 @@ import os
2
  import subprocess
3
  import sys
4
 
5
- # Disable torch.compile / dynamo before any torch import
6
- os.environ["TORCH_COMPILE_DISABLE"] = "1"
7
- os.environ["TORCHDYNAMO_DISABLE"] = "1"
8
-
9
  # Clone LTX-2 repo and install packages
10
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
11
  LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
@@ -36,8 +32,8 @@ import gc
36
  import hashlib
37
 
38
  import torch
39
- torch._dynamo.config.suppress_errors = True
40
- torch._dynamo.config.disable = True
41
 
42
  import spaces
43
  import gradio as gr
@@ -700,6 +696,18 @@ print("=" * 80)
700
  print("Pipeline ready!")
701
  print("=" * 80)
702
 
 
 
 
 
 
 
 
 
 
 
 
 
703
  def log_memory(tag: str):
704
  if torch.cuda.is_available():
705
  allocated = torch.cuda.memory_allocated() / 1024**3
 
2
  import subprocess
3
  import sys
4
 
 
 
 
 
5
  # Clone LTX-2 repo and install packages
6
  LTX_REPO_URL = "https://github.com/Lightricks/LTX-2.git"
7
  LTX_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "LTX-2")
 
32
  import hashlib
33
 
34
  import torch
35
+ torch._dynamo.config.suppress_errors = False
36
+ torch._dynamo.config.disable = False
37
 
38
  import spaces
39
  import gradio as gr
 
696
  print("Pipeline ready!")
697
  print("=" * 80)
698
 
699
+ # AFTER your preload block, compile the transformer
700
+ print("Compiling transformer with torch.compile...")
701
+
702
+ # Regional compilation on transformer_blocks is best for DiT models:
703
+ # it compiles the repeated BasicAVTransformerBlock layers and reuses the graph. NOTE: the call below compiles the whole _transformer object, not the individual blocks.
704
+ _transformer = torch.compile(
705
+ _transformer,
706
+ mode="max-autotune",
707
+ fullgraph=False, # safer; set True if no graph breaks
708
+ dynamic=False, # critical: must be False for fixed shapes per generation
709
+ )
710
+
711
  def log_memory(tag: str):
712
  if torch.cuda.is_available():
713
  allocated = torch.cuda.memory_allocated() / 1024**3