From e040c35d34131c7ffcc55bd9c1ec8e21d5f48025 Mon Sep 17 00:00:00 2001
From: Mihail Dumitrescu
Date: Wed, 14 Sep 2022 20:25:56 +0300
Subject: [PATCH] Refactor attention.CrossAttention to remove duplicate code
 and apply optimizations

Apply a ~6% speedup by moving the * self.scale multiplication earlier, onto a smaller tensor.
When there is enough VRAM, don't allocate a useless zeros tensor.
Switch between cuda/mps/cpu based on q.device.type to allow cleaner per-architecture optimizations in the future.
For cuda and cpu, keep VRAM usage and the faster slicing consistent.
For cpu, use smaller slices; tested ~20% faster on an i7, 9.8 to 7.7 s/it.
Fix the = typo in the self.mem_total >= 8 check in einsum_op_mps_v2, as per the #582 discussion.
---
 ldm/modules/attention.py | 138 +++++++++++++++++----------------------
 1 file changed, 61 insertions(+), 77 deletions(-)

diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index ec96230b464..ef9c2d3e653 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -90,7 +90,7 @@ def forward(self, x):
         b, c, h, w = x.shape
         qkv = self.to_qkv(x)
         q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
-        k = k.softmax(dim=-1)
+        k = k.softmax(dim=-1)
         context = torch.einsum('bhdn,bhen->bhde', k, v)
         out = torch.einsum('bhde,bhdn->bhen', context, q)
         out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@@ -167,101 +167,85 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
             nn.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
-
-        if torch.cuda.is_available():
-            self.einsum_op = self.einsum_op_cuda
-        else:
-            self.mem_total = psutil.virtual_memory().total / (1024**3)
-            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
-
-    def einsum_op_compvis(self, q, k, v, r1):
-        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
-        s2 = s1.softmax(dim=-1, dtype=q.dtype)
-        del s1
-        r1 = einsum('b i j, b j d -> b i d', s2, v)
-        del s2
-        return r1
-
-    def einsum_op_mps_v1(self, q, k, v, r1):
+
+        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
+
+    def einsum_op_compvis(self, q, k, v):
+        s = einsum('b i d, b j d -> b i j', q, k)
+        s = s.softmax(dim=-1, dtype=s.dtype)
+        return einsum('b i j, b j d -> b i d', s, v)
+
+    def einsum_op_slice_0(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], slice_size):
+            end = i + slice_size
+            r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
+        return r
+
+    def einsum_op_slice_1(self, q, k, v, slice_size):
+        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
+        return r
+
+    def einsum_op_mps_v1(self, q, k, v):
         if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
-            r1 = self.einsum_op_compvis(q, k, v, r1)
+            return self.einsum_op_compvis(q, k, v)
         else:
             slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-            for i in range(0, q.shape[1], slice_size):
-                end = i + slice_size
-                s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-                del s2
-        return r1
-
-    def einsum_op_mps_v2(self, q, k, v, r1):
-        if self.mem_total >= 8 and q.shape[1] <= 4096:
-            r1 = self.einsum_op_compvis(q, k, v, r1)
+            return self.einsum_op_slice_1(q, k, v, slice_size)
+
+    def einsum_op_mps_v2(self, q, k, v):
+        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
+            return self.einsum_op_compvis(q, k, v)
         else:
-            slice_size = 1
-            for i in range(0, q.shape[0], slice_size):
-                end = min(q.shape[0], i + slice_size)
-                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
-                s1 *= self.scale
-                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-                del s1
-                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
-                del s2
-        return r1
-
-    def einsum_op_cuda(self, q, k, v, r1):
+            return self.einsum_op_slice_0(q, k, v, 1)
+
+    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
+        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
+        if size_mb <= max_tensor_mb:
+            return self.einsum_op_compvis(q, k, v)
+        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
+        if div <= q.shape[0]:
+            return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
+        return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
+
+    def einsum_op_cuda(self, q, k, v):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
-        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
         mem_free_torch = mem_reserved - mem_active
         mem_free_total = mem_free_cuda + mem_free_torch
+        # Divide factor of safety as there's copying and fragmentation
+        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
 
-        gb = 1024 ** 3
-        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
-        mem_required = tensor_size * 2.5
-        steps = 1
+    def einsum_op(self, q, k, v):
+        if q.device.type == 'cuda':
+            return self.einsum_op_cuda(q, k, v)
 
-        if mem_required > mem_free_total:
-            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+        if q.device.type == 'mps':
+            if self.mem_total_gb >= 32:
+                return self.einsum_op_mps_v1(q, k, v)
+            return self.einsum_op_mps_v2(q, k, v)
 
-        if steps > 64:
-            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
-            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                               f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
-
-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-        for i in range(0, q.shape[1], slice_size):
-            end = min(q.shape[1], i + slice_size)
-            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
-        return r1
+        # Smaller slices are faster due to L2/L3/SLC caches.
+        # Tested on i7 with 8MB L3 cache.
+        return self.einsum_op_tensor_mem(q, k, v, 32)
 
     def forward(self, x, context=None, mask=None):
         h = self.heads
 
-        q_in = self.to_q(x)
+        q = self.to_q(x)
         context = default(context, x)
-        k_in = self.to_k(context)
-        v_in = self.to_v(context)
-        device_type = 'mps' if x.device.type == 'mps' else 'cuda'
+        k = self.to_k(context) * self.scale
+        v = self.to_v(context)
         del context, x
 
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-        del q_in, k_in, v_in
-        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-        r1 = self.einsum_op(q, k, v, r1)
-        del q, k, v
-
-        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
-        del r1
-
-        return self.to_out(r2)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        r = self.einsum_op(q, k, v)
+        return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
 
 class BasicTransformerBlock(nn.Module):
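
A note on the "* self.scale" move claimed in the commit message: scaling k once touches a (batch*heads, n_ctx, dim_head) tensor, while the old einsum_op_compvis scaled every (batch*heads, n_q, n_ctx) score matrix, which is much larger at typical resolutions. A minimal standalone sketch, separate from the diff above and with illustrative shapes, showing the two orderings produce the same scores:

    import torch

    b_h, n_q, n_ctx, d = 16, 1024, 77, 40      # batch*heads, query tokens, context tokens, head dim
    scale = d ** -0.5
    q = torch.randn(b_h, n_q, d)
    k = torch.randn(b_h, n_ctx, d)

    s_old = torch.einsum('b i d, b j d -> b i j', q, k) * scale   # scales the (b_h, n_q, n_ctx) scores
    s_new = torch.einsum('b i d, b j d -> b i j', q, k * scale)   # scales the smaller k instead
    print(torch.allclose(s_old, s_new, atol=1e-5))                # True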
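
A note on the slicing heuristic: einsum_op_tensor_mem estimates the size of the full attention matrix and, if it exceeds the budget, picks the smallest power-of-two divisor that brings each slice under the budget, slicing along dim 0 when the divisor does not exceed the number of batch*heads rows and along dim 1 otherwise. A rough sketch of the same arithmetic with made-up numbers; pick_slicing is a hypothetical helper, separate from the diff above:

    def pick_slicing(size_mb, max_tensor_mb, dim0, dim1):
        # Mirrors the arithmetic in einsum_op_tensor_mem, but returns a
        # description instead of calling the slice helpers.
        if size_mb <= max_tensor_mb:
            return 'no slicing needed'
        # Smallest power of two div such that size_mb / div <= max_tensor_mb.
        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
        if div <= dim0:
            return f'slice dim 0 into chunks of {dim0 // div}'
        return f'slice dim 1 into chunks of {max(dim1 // div, 1)}'

    # A 900 MB score matrix with a 32 MB budget needs div = 32 (900/16 is still > 32 MB),
    # and with only 16 batch*heads rows the slicing falls back to dim 1: 4096 // 32 = 128.
    print(pick_slicing(900, 32, dim0=16, dim1=4096))   # slice dim 1 into chunks of 128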