dagloop5 committed on
Commit
4def770
·
verified ·
1 Parent(s): 32bfebd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -67
app.py CHANGED
@@ -96,73 +96,6 @@ try:
96
  except Exception as e:
97
  print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
98
 
99
- # Add this patch after imports in app.py
100
-
101
- def _patch_attention_for_kv_cache():
102
- """Patch Attention.forward to accept pre-projected K/V."""
103
- from ltx_core.model.transformer.attention import Attention, apply_rotary_emb
104
-
105
- _original_forward = Attention.forward
106
-
107
- def patched_forward(
108
- self,
109
- x: torch.Tensor,
110
- context: torch.Tensor | None = None,
111
- mask: torch.Tensor | None = None,
112
- pe: torch.Tensor | None = None,
113
- k_pe: torch.Tensor | None = None,
114
- perturbation_mask: torch.Tensor | None = None,
115
- all_perturbed: bool = False,
116
- # NEW: pre-computed KV for cross-attention
117
- cached_k: torch.Tensor | None = None,
118
- cached_v: torch.Tensor | None = None,
119
- ) -> torch.Tensor:
120
- context = x if context is None else context
121
- use_attention = not all_perturbed
122
-
123
- v = cached_v if cached_v is not None else self.to_v(context)
124
-
125
- if not use_attention:
126
- out = v
127
- else:
128
- if cached_k is not None:
129
- q = self.to_q(x)
130
- q = self.q_norm(q)
131
- k = cached_k
132
- if pe is not None:
133
- q = apply_rotary_emb(q, pe, self.rope_type)
134
- k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type)
135
- else:
136
- q = self.to_q(x)
137
- k = self.to_k(context)
138
-
139
- q = self.q_norm(q)
140
- k = self.k_norm(k)
141
-
142
- if pe is not None:
143
- q = apply_rotary_emb(q, pe, self.rope_type)
144
- k = apply_rotary_emb(k, pe if k_pe is None else k_pe, self.rope_type)
145
-
146
- out = self.attention_function(q, k, v, self.heads, mask)
147
-
148
- if perturbation_mask is not None:
149
- out = out * perturbation_mask + v * (1 - perturbation_mask)
150
-
151
- # Gating logic remains the same
152
- if self.to_gate_logits is not None:
153
- gate_logits = self.to_gate_logits(x)
154
- b, t, _ = out.shape
155
- out = out.view(b, t, self.heads, self.dim_head)
156
- gates = 2.0 * torch.sigmoid(gate_logits)
157
- out = out * gates.unsqueeze(-1)
158
- out = out.view(b, t, self.heads * self.dim_head)
159
-
160
- return self.to_out(out)
161
-
162
- Attention.forward = patched_forward
163
-
164
- _patch_attention_for_kv_cache()
165
-
166
  logging.getLogger().setLevel(logging.INFO)
167
 
168
  MAX_SEED = np.iinfo(np.int32).max
 
96
  except Exception as e:
97
  print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  logging.getLogger().setLevel(logging.INFO)
100
 
101
  MAX_SEED = np.iinfo(np.int32).max