Daankular commited on
Commit
167d39f
·
verified ·
1 Parent(s): 13bd15c

Upload external/TripoSG/triposg/models/transformers/__init__.py with huggingface_hub

Browse files
external/TripoSG/triposg/models/transformers/__init__.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Optional
2
+
3
+ from .triposg_transformer import TripoSGDiTModel
4
+
5
+
6
+ def default_set_attn_proc_func(
7
+ name: str,
8
+ hidden_size: int,
9
+ cross_attention_dim: Optional[int],
10
+ ori_attn_proc: object,
11
+ ) -> object:
12
+ return ori_attn_proc
13
+
14
+
15
+ def set_transformer_attn_processor(
16
+ transformer: TripoSGDiTModel,
17
+ set_self_attn_proc_func: Callable = default_set_attn_proc_func,
18
+ set_cross_attn_1_proc_func: Callable = default_set_attn_proc_func,
19
+ set_cross_attn_2_proc_func: Callable = default_set_attn_proc_func,
20
+ set_self_attn_module_names: Optional[list[str]] = None,
21
+ set_cross_attn_1_module_names: Optional[list[str]] = None,
22
+ set_cross_attn_2_module_names: Optional[list[str]] = None,
23
+ ) -> None:
24
+ do_set_processor = lambda name, module_names: (
25
+ any([name.startswith(module_name) for module_name in module_names])
26
+ if module_names is not None
27
+ else True
28
+ ) # prefix match
29
+
30
+ attn_procs = {}
31
+ for name, attn_processor in transformer.attn_processors.items():
32
+ hidden_size = transformer.config.width
33
+ if name.endswith("attn1.processor"):
34
+ # self attention
35
+ attn_procs[name] = (
36
+ set_self_attn_proc_func(name, hidden_size, None, attn_processor)
37
+ if do_set_processor(name, set_self_attn_module_names)
38
+ else attn_processor
39
+ )
40
+ elif name.endswith("attn2.processor"):
41
+ # cross attention
42
+ cross_attention_dim = transformer.config.cross_attention_dim
43
+ attn_procs[name] = (
44
+ set_cross_attn_1_proc_func(
45
+ name, hidden_size, cross_attention_dim, attn_processor
46
+ )
47
+ if do_set_processor(name, set_cross_attn_1_module_names)
48
+ else attn_processor
49
+ )
50
+ elif name.endswith("attn2_2.processor"):
51
+ # cross attention 2
52
+ cross_attention_dim = transformer.config.cross_attention_2_dim
53
+ attn_procs[name] = (
54
+ set_cross_attn_2_proc_func(
55
+ name, hidden_size, cross_attention_dim, attn_processor
56
+ )
57
+ if do_set_processor(name, set_cross_attn_2_module_names)
58
+ else attn_processor
59
+ )
60
+
61
+ transformer.set_attn_processor(attn_procs)