Upload external/TripoSG/triposg/models/transformers/__init__.py with huggingface_hub
Browse files
external/TripoSG/triposg/models/transformers/__init__.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Optional
|
| 2 |
+
|
| 3 |
+
from .triposg_transformer import TripoSGDiTModel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def default_set_attn_proc_func(
|
| 7 |
+
name: str,
|
| 8 |
+
hidden_size: int,
|
| 9 |
+
cross_attention_dim: Optional[int],
|
| 10 |
+
ori_attn_proc: object,
|
| 11 |
+
) -> object:
|
| 12 |
+
return ori_attn_proc
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def set_transformer_attn_processor(
|
| 16 |
+
transformer: TripoSGDiTModel,
|
| 17 |
+
set_self_attn_proc_func: Callable = default_set_attn_proc_func,
|
| 18 |
+
set_cross_attn_1_proc_func: Callable = default_set_attn_proc_func,
|
| 19 |
+
set_cross_attn_2_proc_func: Callable = default_set_attn_proc_func,
|
| 20 |
+
set_self_attn_module_names: Optional[list[str]] = None,
|
| 21 |
+
set_cross_attn_1_module_names: Optional[list[str]] = None,
|
| 22 |
+
set_cross_attn_2_module_names: Optional[list[str]] = None,
|
| 23 |
+
) -> None:
|
| 24 |
+
do_set_processor = lambda name, module_names: (
|
| 25 |
+
any([name.startswith(module_name) for module_name in module_names])
|
| 26 |
+
if module_names is not None
|
| 27 |
+
else True
|
| 28 |
+
) # prefix match
|
| 29 |
+
|
| 30 |
+
attn_procs = {}
|
| 31 |
+
for name, attn_processor in transformer.attn_processors.items():
|
| 32 |
+
hidden_size = transformer.config.width
|
| 33 |
+
if name.endswith("attn1.processor"):
|
| 34 |
+
# self attention
|
| 35 |
+
attn_procs[name] = (
|
| 36 |
+
set_self_attn_proc_func(name, hidden_size, None, attn_processor)
|
| 37 |
+
if do_set_processor(name, set_self_attn_module_names)
|
| 38 |
+
else attn_processor
|
| 39 |
+
)
|
| 40 |
+
elif name.endswith("attn2.processor"):
|
| 41 |
+
# cross attention
|
| 42 |
+
cross_attention_dim = transformer.config.cross_attention_dim
|
| 43 |
+
attn_procs[name] = (
|
| 44 |
+
set_cross_attn_1_proc_func(
|
| 45 |
+
name, hidden_size, cross_attention_dim, attn_processor
|
| 46 |
+
)
|
| 47 |
+
if do_set_processor(name, set_cross_attn_1_module_names)
|
| 48 |
+
else attn_processor
|
| 49 |
+
)
|
| 50 |
+
elif name.endswith("attn2_2.processor"):
|
| 51 |
+
# cross attention 2
|
| 52 |
+
cross_attention_dim = transformer.config.cross_attention_2_dim
|
| 53 |
+
attn_procs[name] = (
|
| 54 |
+
set_cross_attn_2_proc_func(
|
| 55 |
+
name, hidden_size, cross_attention_dim, attn_processor
|
| 56 |
+
)
|
| 57 |
+
if do_set_processor(name, set_cross_attn_2_module_names)
|
| 58 |
+
else attn_processor
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
transformer.set_attn_processor(attn_procs)
|