File size: 6,023 Bytes
b055442 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | from typing import Optional
import torch
import torch.nn.functional as F
import torch.distributed as dist
from diffusers.models.attention import Attention
from diffusers.models.embeddings import apply_rotary_emb
try:
import xfuser
from xfuser.core.distributed import (get_sequence_parallel_rank,
get_sequence_parallel_world_size,
get_sp_group, get_world_group,
init_distributed_environment,
initialize_model_parallel)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
except Exception as ex:
get_sequence_parallel_world_size = None
get_sequence_parallel_rank = None
xFuserLongContextAttention = None
get_sp_group = None
get_world_group = None
init_distributed_environment = None
initialize_model_parallel = None
def set_multi_gpus_devices(ulysses_degree, ring_degree):
    """Prepare sequence-parallel inference and return the device to run on.

    When neither degree is greater than 1, nothing is initialised and the
    plain string ``"cuda"`` is returned. Otherwise the default
    ``torch.distributed`` process group is created with the NCCL backend
    (unless one already exists), xfuser's distributed environment and model
    parallel groups are initialised, and a ``torch.device`` built from the
    xfuser world group's ``local_rank`` is returned.

    Args:
        ulysses_degree: Ulysses sequence-parallel degree.
        ring_degree: Ring sequence-parallel degree.

    Returns:
        ``"cuda"`` (a ``str``) when parallelism is disabled, otherwise a
        ``torch.device`` of the form ``cuda:<local_rank>``.

    Raises:
        RuntimeError: Parallelism was requested but xfuser could not be
            imported.
        AssertionError: The world size is not equal to
            ``ring_degree * ulysses_degree``.
    """
    if not (ulysses_degree > 1 or ring_degree > 1):
        return "cuda"

    if get_sp_group is None:
        raise RuntimeError("xfuser is not installed.")

    # Creating the default process group twice raises, so only create it
    # when the caller (or launcher integration) has not already done so.
    if not dist.is_initialized():
        dist.init_process_group("nccl")

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    print('parallel inference enabled: ulysses_degree=%d ring_degree=%d rank=%d world_size=%d' % (
        ulysses_degree, ring_degree, rank, world_size))

    # Raised explicitly instead of using `assert`, so the check is not
    # stripped under `python -O`; the exception type is unchanged.
    if world_size != ring_degree * ulysses_degree:
        raise AssertionError(
            "number of GPUs(%d) should be equal to ring_degree * ulysses_degree." % world_size
        )

    init_distributed_environment(rank=rank, world_size=world_size)
    initialize_model_parallel(
        sequence_parallel_degree=world_size,
        ring_degree=ring_degree,
        ulysses_degree=ulysses_degree,
    )

    # The device index comes from the world group's local rank rather than
    # the global rank (presumably the per-node GPU index - confirm against
    # xfuser's group coordinator).
    world_group = get_world_group()
    device = torch.device(f"cuda:{world_group.local_rank}")
    print('rank=%d device=%s' % (world_group.rank, str(device)))
    return device
class CogVideoXMultiGPUsAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model, with an optional
    sequence-parallel path. It applies a rotary embedding on query and key vectors, but does not include spatial
    normalization.

    Text (encoder) tokens are concatenated in front of the other tokens along dim 1. Attention is computed with
    `F.scaled_dot_product_attention`, or with an `xFuserLongContextAttention` instance when one could be created.
    """

    def __init__(self):
        # Check the hard requirement first so nothing is constructed when it is not met.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        # `None` selects the plain single-device attention path in `__call__`.
        self.hybrid_seq_parallel_attn = None
        if xFuserLongContextAttention is not None:
            try:
                self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
            except Exception:
                # Deliberate best effort: if the xfuser attention object cannot be built (presumably because
                # the sequence-parallel environment has not been initialised - confirm), fall back silently.
                self.hybrid_seq_parallel_attn = None

    def __call__(
        self,
        attn: "Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        r"""
        Run joint attention over the concatenation of `encoder_hidden_states` and `hidden_states`.

        Args:
            attn: Attention module providing `to_q`/`to_k`/`to_v`, `norm_q`/`norm_k`, `heads`, `to_out`,
                `is_cross_attention` and `prepare_attention_mask`.
            hidden_states: 3-D tensor, concatenated after `encoder_hidden_states` along dim 1.
            encoder_hidden_states: 3-D text tensor; required (it is dereferenced unconditionally).
            attention_mask: Optional mask, used only on the `F.scaled_dot_product_attention` path.
            image_rotary_emb: Optional rotary embedding applied to the non-text part of query (and of key,
                unless `attn.is_cross_attention`).

        Returns:
            `(hidden_states, encoder_hidden_states)`: the projected attention output split back along dim 1
            into the non-text part and the leading text part.
        """
        text_seq_length = encoder_hidden_states.size(1)
        # `sequence_length` is the text length taken from `encoder_hidden_states`, exactly as before; the old
        # `is None` fallback was unreachable because the tensor is already dereferenced on the line above.
        batch_size, sequence_length, _ = encoder_hidden_states.shape

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed; only the tokens after the text prefix are rotated (written in place).
        if image_rotary_emb is not None:
            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        if self.hybrid_seq_parallel_attn is None:
            # Output is (batch, heads, seq, head_dim); move heads next to head_dim before flattening.
            attn_output = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            ).transpose(1, 2)
        else:
            # The sequence-parallel attention receives (batch, seq, heads, head_dim) tensors, with the text
            # part passed separately as the "joint" tensors placed at the front.
            # NOTE(review): `attention_mask` is not forwarded on this path - confirm that is intended.
            txt_q, img_q = query[:, :, :text_seq_length].transpose(1, 2), query[:, :, text_seq_length:].transpose(1, 2)
            txt_k, img_k = key[:, :, :text_seq_length].transpose(1, 2), key[:, :, text_seq_length:].transpose(1, 2)
            txt_v, img_v = value[:, :, :text_seq_length].transpose(1, 2), value[:, :, text_seq_length:].transpose(1, 2)

            # The result is used as returned: the original transposed it with `.transpose(1, 2)` and then
            # immediately transposed it back, which cancels out.
            attn_output = self.hybrid_seq_parallel_attn(
                None,
                img_q, img_k, img_v, dropout_p=0.0, causal=False,
                joint_tensor_query=txt_q,
                joint_tensor_key=txt_k,
                joint_tensor_value=txt_v,
                joint_strategy='front',
            )

        hidden_states = attn_output.reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Undo the concatenation: the leading `text_seq_length` tokens are the text part.
        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
|