From 0e94c12893378b59eb1102045b09e9f761241a25 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 13:23:29 -0700 Subject: [PATCH 01/55] add tcg --- scripts/incantation_base.py | 2 ++ scripts/tcg.py | 41 +++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 scripts/tcg.py diff --git a/scripts/incantation_base.py b/scripts/incantation_base.py index 48c4fa1..fa6dfa6 100644 --- a/scripts/incantation_base.py +++ b/scripts/incantation_base.py @@ -9,6 +9,7 @@ from modules.processing import StableDiffusionProcessing from scripts.ui_wrapper import UIWrapper from scripts.incant import IncantExtensionScript +from scripts.tcg import TCGExtensionScript from scripts.t2i_zero import T2I0ExtensionScript from scripts.scfg import SCFGExtensionScript from scripts.pag import PAGExtensionScript @@ -34,6 +35,7 @@ def __init__(self, module: UIWrapper, module_idx = 0, num_args = -1, arg_idx = - # main scripts submodules: list[SubmoduleInfo] = [ + SubmoduleInfo(module=TCGExtensionScript()), SubmoduleInfo(module=SCFGExtensionScript()), SubmoduleInfo(module=PAGExtensionScript()), SubmoduleInfo(module=T2I0ExtensionScript()), diff --git a/scripts/tcg.py b/scripts/tcg.py new file mode 100644 index 0000000..0b7663e --- /dev/null +++ b/scripts/tcg.py @@ -0,0 +1,41 @@ +import gradio as gr +from scripts.ui_wrapper import UIWrapper +from scripts.incant_utils import module_hooks + +class TCGExtensionScript(UIWrapper): + def __init__(self): + self.infotext_fields: list = [] + self.paste_field_names: list = [] + + def title(self) -> str: + raise 'TCG [arXiv:2404.11824]' + + def setup_ui(self, is_img2img) -> list: + active = gr.Checkbox("Active", value=True) + active.do_not_save_to_config = True + return [active] + + def before_process(self, p, *args, **kwargs): + pass + + def process(self, p, *args, **kwargs): + pass + + def before_process_batch(self, p, *args, **kwargs): + pass + + def process_batch(self, p, *args, **kwargs): + pass + + def postprocess_batch(self, p, *args, **kwargs): + pass + + def unhook_callbacks(self) -> None: + pass + + def get_xyz_axis_options(self) -> dict: + return {} + +def arg(p, field_name: str, variable_name:str, default=None, **kwargs): + """ Get argument from field_name or variable_name, or default if not found """ + return getattr(p, field_name, kwargs.get(variable_name, None)) From ccad86a7a00f96f10a2a861b840481ba401c4cfb Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 13:26:17 -0700 Subject: [PATCH 02/55] fix ui --- scripts/tcg.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 0b7663e..030f291 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -2,6 +2,12 @@ from scripts.ui_wrapper import UIWrapper from scripts.incant_utils import module_hooks +""" +WIP Implementation of https://arxiv.org/abs/2404.11824 +Author: v0xie +GitHub URL: https://github.com/v0xie/sd-webui-incantations + +""" class TCGExtensionScript(UIWrapper): def __init__(self): self.infotext_fields: list = [] @@ -11,9 +17,12 @@ def title(self) -> str: raise 'TCG [arXiv:2404.11824]' def setup_ui(self, is_img2img) -> list: - active = gr.Checkbox("Active", value=True) - active.do_not_save_to_config = True - return [active] + with gr.Accordion('TCG', open=True): + active = gr.Checkbox(label="Active", value=True) + opts = [active] + for opt in opts: + opt.do_not_save_to_config = True + return opts def before_process(self, p, *args, **kwargs): pass From 4bf6d3eaee349abe9ee830066a87b9d0fab66a12 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 13:34:01 -0700 Subject: [PATCH 03/55] add scaffolding --- scripts/tcg.py | 72 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 65 insertions(+), 7 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 030f291..1f1a20d 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -1,4 +1,6 @@ +import logging import gradio as gr +import torch from scripts.ui_wrapper import UIWrapper from scripts.incant_utils import module_hooks @@ -8,6 +10,9 @@ GitHub URL: https://github.com/v0xie/sd-webui-incantations """ + +logger = logging.getLogger(__name__) + class TCGExtensionScript(UIWrapper): def __init__(self): self.infotext_fields: list = [] @@ -23,6 +28,26 @@ def setup_ui(self, is_img2img) -> list: for opt in opts: opt.do_not_save_to_config = True return opts + + def get_modules(self): + return module_hooks.get_modules( module_name_filter='CrossAttention') + + def hook_modules(self): + def tcg_to_q_hook(module, input, kwargs, output): + setattr(module.tcg_parent_module[0], 'tcg_to_q_map', output) + + def tcg_to_k_hook(module, input, kwargs, output): + setattr(module.tcg_parent_module[0], 'tcg_to_k_map', output) + + for module in self.get_modules(): + if not module.network_layer_name.endswith('attn2'): + continue + module_hooks.modules_add_field(module, 'tcg_to_q_map', None) + module_hooks.modules_add_field(module, 'tcg_to_k_map', None) + module_hooks.modules_add_field(module.to_q, 'tcg_parent_module', [module]) + module_hooks.modules_add_field(module.to_k, 'tcg_parent_module', [module]) + module_hooks.module_add_forward_hook(module.to_q, tcg_to_q_hook, with_kwargs=True) + module_hooks.module_add_forward_hook(module.to_k, tcg_to_k_hook, with_kwargs=True) def before_process(self, p, *args, **kwargs): pass @@ -30,21 +55,54 @@ def before_process(self, p, *args, **kwargs): def process(self, p, *args, **kwargs): pass - def before_process_batch(self, p, *args, **kwargs): - pass + def before_process_batch(self, p, active, *args, **kwargs): + active = getattr(p, 'tcg_active', active) + if not active: + return def process_batch(self, p, *args, **kwargs): pass def postprocess_batch(self, p, *args, **kwargs): - pass + self.unhook_callbacks() def unhook_callbacks(self) -> None: - pass + for module in self.get_modules(): + module_hooks.remove_module_forward_hook(module.to_q, 'tcg_to_q_hook') + module_hooks.remove_module_forward_hook(module.to_k, 'tcg_to_k_hook') + module_hooks.modules_remove_field(module, 'tcg_to_q_map') + module_hooks.modules_remove_field(module, 'tcg_to_k_map') + module_hooks.modules_remove_field(module.to_q, 'tcg_parent_module') + module_hooks.modules_remove_field(module.to_k, 'tcg_parent_module') def get_xyz_axis_options(self) -> dict: return {} -def arg(p, field_name: str, variable_name:str, default=None, **kwargs): - """ Get argument from field_name or variable_name, or default if not found """ - return getattr(p, field_name, kwargs.get(variable_name, None)) + +def get_attention_scores(to_q_map, to_k_map, dtype): + """ Calculate the attention scores for the given query and key maps + Arguments: + to_q_map: torch.Tensor - query map + to_k_map: torch.Tensor - key map + dtype: torch.dtype - data type of the tensor + Returns: + torch.Tensor - attention scores + """ + # based on diffusers models/attention.py "get_attention_scores" + # use in place operations vs. softmax to save memory: https://stackoverflow.com/questions/53732209/torch-in-place-operations-to-save-memory-softmax + # 512x: 2.65G -> 2.47G + # attn_probs = attn_scores.softmax(dim=-1).to(device=shared.device, dtype=to_q_map.dtype) + + attn_probs = to_q_map @ to_k_map.transpose(-1, -2) + + # avoid nan by converting to float32 and subtracting max + attn_probs = attn_probs.to(dtype=torch.float32) # + attn_probs -= torch.max(attn_probs) + + torch.exp(attn_probs, out = attn_probs) + summed = attn_probs.sum(dim=-1, keepdim=True, dtype=torch.float32) + attn_probs /= summed + + attn_probs = attn_probs.to(dtype=dtype) + + return attn_probs From c510affb2b69fa288e06c559bc4dccec5bc4e3c1 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 13:50:40 -0700 Subject: [PATCH 04/55] allow debugging without launching webui --- scripts/tcg.py | 94 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 79 insertions(+), 15 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 1f1a20d..c630e85 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -1,6 +1,11 @@ +import os, sys import logging import gradio as gr import torch + +if __name__ == '__main__' and os.environ.get('INCANT_DEBUG', None): + sys.path.append(f'{os.getcwd()}') + sys.path.append(f'{os.getcwd()}/extensions/sd-webui-incantations') from scripts.ui_wrapper import UIWrapper from scripts.incant_utils import module_hooks @@ -32,7 +37,15 @@ def setup_ui(self, is_img2img) -> list: def get_modules(self): return module_hooks.get_modules( module_name_filter='CrossAttention') - def hook_modules(self): + def before_process_batch(self, p, active, *args, **kwargs): + self.unhook_callbacks() + active = getattr(p, 'tcg_active', active) + if not active: + return + + def tcg_forward_hook(module, input, kwargs, output): + pass + def tcg_to_q_hook(module, input, kwargs, output): setattr(module.tcg_parent_module[0], 'tcg_to_q_map', output) @@ -48,20 +61,7 @@ def tcg_to_k_hook(module, input, kwargs, output): module_hooks.modules_add_field(module.to_k, 'tcg_parent_module', [module]) module_hooks.module_add_forward_hook(module.to_q, tcg_to_q_hook, with_kwargs=True) module_hooks.module_add_forward_hook(module.to_k, tcg_to_k_hook, with_kwargs=True) - - def before_process(self, p, *args, **kwargs): - pass - - def process(self, p, *args, **kwargs): - pass - - def before_process_batch(self, p, active, *args, **kwargs): - active = getattr(p, 'tcg_active', active) - if not active: - return - - def process_batch(self, p, *args, **kwargs): - pass + module_hooks.module_add_forward_hook(module, tcg_forward_hook, with_kwargs=True) def postprocess_batch(self, p, *args, **kwargs): self.unhook_callbacks() @@ -75,8 +75,56 @@ def unhook_callbacks(self) -> None: module_hooks.modules_remove_field(module.to_q, 'tcg_parent_module') module_hooks.modules_remove_field(module.to_k, 'tcg_parent_module') + def before_process(self, p, *args, **kwargs): + pass + + def process(self, p, *args, **kwargs): + pass + + def process_batch(self, p, *args, **kwargs): + pass + def get_xyz_axis_options(self) -> dict: return {} + + +def calculate_centroid(attention_map): + """ Calculate the centroid of the attention map + Arguments: + attention_map: torch.Tensor - The attention map to calculate the centroid. Shape: (batch_size, height, width, channels) + Returns: + torch.Tensor - The centroid of the attention map. Shape: (batch_size, 2, channels) + """ + + # Get the height and width + batch_size, height, width, channels = attention_map.shape + + # Create a mesh grid of height and width coordinates + h_coords = torch.arange(height).unsqueeze(1).expand(height, width).to(attention_map.device) + w_coords = torch.arange(width).unsqueeze(0).expand(height, width).to(attention_map.device) + + # Flatten the coordinates to apply the sum + h_coords = h_coords.reshape(-1) + w_coords = w_coords.reshape(-1) + + # Flatten the attention_map for easier manipulation + attention_map_flat = attention_map.view(batch_size, -1, channels) + + # Sum of attention scores for each channel + attention_sum = attention_map_flat.sum(dim=1, keepdim=True) + 1e-10 # Add small value to avoid division by zero + + # Weighted sum of the coordinates + h_weighted_sum = (h_coords.unsqueeze(0) * attention_map_flat).sum(dim=1) + w_weighted_sum = (w_coords.unsqueeze(0) * attention_map_flat).sum(dim=1) + + # Calculate the centroids + centroid_h = h_weighted_sum / attention_sum + centroid_w = w_weighted_sum / attention_sum + + # Combine the centroids into a single tensor of shape (batch_size, 2, channels) + centroids = torch.stack([centroid_h, centroid_w], dim=1) + + return centroids def get_attention_scores(to_q_map, to_k_map, dtype): @@ -106,3 +154,19 @@ def get_attention_scores(to_q_map, to_k_map, dtype): attn_probs = attn_probs.to(dtype=dtype) return attn_probs + + +if __name__ == '__main__': + # Create a simple attention map with known values + attention_map = torch.zeros((1, 5, 5, 1)) # Shape (batch_size, height, width, channels) + attention_map[0, 2, 2, 0] = 1 # Put all attention on the center + + # Calculate centroids + centroids = calculate_centroid(attention_map) + + # Expected centroid is the center of the attention map (2, 2) + expected_centroid = torch.tensor([[[2.0], [2.0]]]) + + # Check if the calculated centroid matches the expected centroid + assert torch.allclose(centroids, expected_centroid), f"Expected {expected_centroid}, but got {centroids}" + print("Sanity check passed!") \ No newline at end of file From 2112488d3dfa8c66a48e4b789da5cc28ccf2b0ac Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 14:13:16 -0700 Subject: [PATCH 05/55] add tcgscript --- scripts/incantation_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/incantation_base.py b/scripts/incantation_base.py index fa6dfa6..b4aba1b 100644 --- a/scripts/incantation_base.py +++ b/scripts/incantation_base.py @@ -69,7 +69,7 @@ def show(self, is_img2img): def ui(self, is_img2img): # setup UI out = [] - with gr.Accordion('Incantations', open=False): + with gr.Accordion('Incantations', open=True): for idx, module_info in enumerate(submodules): module_info.module_idx = idx module = module_info.module From dbe86859375da733410d07bbfef79d916c59d8eb Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 14:13:29 -0700 Subject: [PATCH 06/55] add centroid calculation fn --- scripts/tcg.py | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index c630e85..6db6dd1 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -91,38 +91,29 @@ def get_xyz_axis_options(self) -> dict: def calculate_centroid(attention_map): """ Calculate the centroid of the attention map Arguments: - attention_map: torch.Tensor - The attention map to calculate the centroid. Shape: (batch_size, height, width, channels) + attention_map: torch.Tensor - The attention map to calculate the centroid. Shape: (B, H, W, C) Returns: - torch.Tensor - The centroid of the attention map. Shape: (batch_size, 2, channels) + torch.Tensor - The centroid of the attention map. Shape: (B, C, 2) """ # Get the height and width - batch_size, height, width, channels = attention_map.shape - - # Create a mesh grid of height and width coordinates - h_coords = torch.arange(height).unsqueeze(1).expand(height, width).to(attention_map.device) - w_coords = torch.arange(width).unsqueeze(0).expand(height, width).to(attention_map.device) - - # Flatten the coordinates to apply the sum - h_coords = h_coords.reshape(-1) - w_coords = w_coords.reshape(-1) - - # Flatten the attention_map for easier manipulation - attention_map_flat = attention_map.view(batch_size, -1, channels) + B, H, W, C = attention_map.shape + + h_coords = torch.arange(H).view(1, H, 1, 1).to(attention_map.device) + w_coords = torch.arange(W).view(1, 1, W, 1).to(attention_map.device) # Sum of attention scores for each channel - attention_sum = attention_map_flat.sum(dim=1, keepdim=True) + 1e-10 # Add small value to avoid division by zero + attention_sum = torch.sum(attention_map, dim=(1, 2)) # shape: (B, C) # Weighted sum of the coordinates - h_weighted_sum = (h_coords.unsqueeze(0) * attention_map_flat).sum(dim=1) - w_weighted_sum = (w_coords.unsqueeze(0) * attention_map_flat).sum(dim=1) + h_weighted_sum = torch.sum(h_coords * attention_map, dim=(1,2)) # (B, C) + w_weighted_sum = torch.sum(w_coords * attention_map, dim=(1,2)) # (B, C) # Calculate the centroids centroid_h = h_weighted_sum / attention_sum centroid_w = w_weighted_sum / attention_sum - # Combine the centroids into a single tensor of shape (batch_size, 2, channels) - centroids = torch.stack([centroid_h, centroid_w], dim=1) + centroids = torch.stack([centroid_h, centroid_w], dim=-1) # (B, C, 2) return centroids @@ -158,8 +149,9 @@ def get_attention_scores(to_q_map, to_k_map, dtype): if __name__ == '__main__': # Create a simple attention map with known values - attention_map = torch.zeros((1, 5, 5, 1)) # Shape (batch_size, height, width, channels) + attention_map = torch.zeros((2, 5, 5, 1)) # Shape (batch_size, height, width, channels) attention_map[0, 2, 2, 0] = 1 # Put all attention on the center + attention_map[1, 2, 2, 0] = 1 # Put all attention on the center # Calculate centroids centroids = calculate_centroid(attention_map) @@ -168,5 +160,4 @@ def get_attention_scores(to_q_map, to_k_map, dtype): expected_centroid = torch.tensor([[[2.0], [2.0]]]) # Check if the calculated centroid matches the expected centroid - assert torch.allclose(centroids, expected_centroid), f"Expected {expected_centroid}, but got {centroids}" - print("Sanity check passed!") \ No newline at end of file + assert torch.allclose(centroids, expected_centroid), f"Expected {expected_centroid}, but got {centroids}" \ No newline at end of file From ee5798e2788a34795090f21b7b76d7f633d33dac Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 14:29:57 -0700 Subject: [PATCH 07/55] add conflict detection fn --- scripts/tcg.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/scripts/tcg.py b/scripts/tcg.py index 6db6dd1..7a1924d 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -118,6 +118,28 @@ def calculate_centroid(attention_map): return centroids +def detect_conflict(attention_map, region, theta): + """ + Detect conflict in an attention map with respect to a designated region in PyTorch. + Parameters: + attention_map (torch.Tensor): Attention map of shape (B, H, W, K). + region (torch.Tensor): Binary mask of shape (B, H, W, 1) indicating the region of interest. + theta (float): Threshold value. + Returns: + torch.Tensor: Conflict detection result of shape (B, K), with values 0 or 1 indicating conflict between tokens and the region. + """ + # Ensure region is the same shape as the spatial dimensions of attention_map + assert region.shape[1:] == attention_map.shape[1:3], "Region mask must match spatial dimensions of attention map" + # Calculate the mean attention within the region + region = region.unsqueeze(-1) # Add channel dimension: (B, H, W) -> (B, H, W, 1) + attention_in_region = attention_map * region # Element-wise multiplication + mean_attention_in_region = torch.sum(attention_in_region, dim=(1, 2)) / torch.sum(region, dim=(1, 2)) # Mean over (H, W) + # Compare with threshold theta + conflict = (mean_attention_in_region > theta).float() # Convert boolean to float (0 or 1) + return conflict + + + def get_attention_scores(to_q_map, to_k_map, dtype): """ Calculate the attention scores for the given query and key maps Arguments: @@ -148,6 +170,15 @@ def get_attention_scores(to_q_map, to_k_map, dtype): if __name__ == '__main__': + B, H, W, C = 1, 64, 64, 10 + attention_map = torch.ones(B, H, W, C).to('cuda') # B H W C + region = torch.zeros((B, H, W), dtype=torch.float16, device='cuda') # B H W C + # set the left half of region to 1 + region[:, :, :W//2] = 1 + theta = 0.5 # Example threshold + conflict_detection = detect_conflict(attention_map, region, theta) + print(conflict_detection) + # Create a simple attention map with known values attention_map = torch.zeros((2, 5, 5, 1)) # Shape (batch_size, height, width, channels) attention_map[0, 2, 2, 0] = 1 # Put all attention on the center From 5609dd0c47e64bace3d800ce7233c834ecbd8355 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 14:37:56 -0700 Subject: [PATCH 08/55] modularize test --- scripts/tcg.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 7a1924d..30fb9cb 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -170,7 +170,8 @@ def get_attention_scores(to_q_map, to_k_map, dtype): if __name__ == '__main__': - B, H, W, C = 1, 64, 64, 10 + # conflict detection + B, H, W, C = 1, 64, 64, 1 attention_map = torch.ones(B, H, W, C).to('cuda') # B H W C region = torch.zeros((B, H, W), dtype=torch.float16, device='cuda') # B H W C # set the left half of region to 1 @@ -180,15 +181,14 @@ def get_attention_scores(to_q_map, to_k_map, dtype): print(conflict_detection) # Create a simple attention map with known values - attention_map = torch.zeros((2, 5, 5, 1)) # Shape (batch_size, height, width, channels) - attention_map[0, 2, 2, 0] = 1 # Put all attention on the center - attention_map[1, 2, 2, 0] = 1 # Put all attention on the center + attention_map = torch.zeros((B, H, W, C), device='cuda') # Shape (batch_size, height, width, channels) + attention_map[0, H//2, W//2, 0] = 1.0 # Put all attention on the center # Calculate centroids - centroids = calculate_centroid(attention_map) + centroids = calculate_centroid(attention_map) # (B, C, 2) # Expected centroid is the center of the attention map (2, 2) - expected_centroid = torch.tensor([[[2.0], [2.0]]]) + expected_centroid = torch.tensor([[[H/2, W/2]]], device='cuda') # Check if the calculated centroid matches the expected centroid assert torch.allclose(centroids, expected_centroid), f"Expected {expected_centroid}, but got {centroids}" \ No newline at end of file From e8c302ea242e74b32f53a53d5c7eff459c8d5f3d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 15:01:42 -0700 Subject: [PATCH 09/55] add fns for margin force and distance to edge --- scripts/tcg.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/scripts/tcg.py b/scripts/tcg.py index 30fb9cb..a51b5b6 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -86,6 +86,64 @@ def process_batch(self, p, *args, **kwargs): def get_xyz_axis_options(self) -> dict: return {} + + +def min_distance_to_nearest_edge(verts, h, w): + """ Calculate the distances of the vertices from the nearest edge given the height and width of the image + Arguments: + verts: torch.Tensor - The vertices of the attention map. Shape: (B, C, 2) + h: int - The height of the image + w: int - The width of the image + """ + x_coords, y_coords = verts[:, :, 0], verts[:, :, 1] # coordinates + distances_to_edges = torch.stack([y_coords, h - y_coords, x_coords, w - x_coords], dim=-1) # (B, C, 4) + min_distances = torch.min(distances_to_edges, dim=-1).values # (B, C) + return min_distances + + +def margin_force(attention_map, m, verts): + """ Margin force calculation + Arguments: + attention_map: torch.Tensor - The attention map to calculate the force. Shape: (B, H, W, C) + m: float - The margin force coefficient + verts: torch.Tensor - The vertices of the attention map. Shape: (B, C, 2) + Returns: + torch.Tensor - The direction of force for each vertex. Shape: (B, C, 2) + """ + B, H, W, C = attention_map.shape + min_distances = min_distance_to_nearest_edge(verts, H, W) # (B, C) + force = -m / (min_distances ** 2) + return force + +def repulsive_force(attention_map, xi, pos_vertex, pos_target): + """ Repulsive force + Arguments: + attention_map: torch.Tensor - The attention map to calculate the force. Shape: (B, H, W, C) + xi: float - The global force coefficient + pos_vertex: torch.Tensor - The position of the vertex. Shape: (B, C, 2) + pos_target: torch.Tensor - The position of the target. Shape: (B, C, 2) + Returns: + torch.Tensor - The multi-target force. Shape: (B, C) + """ + force = (-xi) ** 2 + norm_pos = (pos_vertex - pos_target).norm(dim=-1) + return force / norm_pos + + +def multi_target_force(attention_map, omega, xi, pos_vertex, pos_target): + """ Multi-target force calculation + Arguments: + attention_map: torch.Tensor - The attention map to calculate the force. Shape: (B, H, W, C) + omega: torch.tensor - Coefficients for balancing forces amongst j targets + xi: float - The global force coefficient + pos_vertex: torch.Tensor - The position of the vertex. Shape: (B, C, 2) + pos_target: torch.Tensor - The position of the target. Shape: (B, C, 2) + Returns: + torch.Tensor - The multi-target force. Shape: (B, C, 2) + """ + force = -xi ** 2 + pass + def calculate_centroid(attention_map): From 4db2f41d538cfef9547d67f712eeea3fe49839e8 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 15:29:40 -0700 Subject: [PATCH 10/55] implementing forces --- scripts/tcg.py | 55 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index a51b5b6..98e1ddf 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -88,6 +88,20 @@ def get_xyz_axis_options(self) -> dict: return {} +def displacement_force(attention_map, verts, f_rep_strength, f_margin_strength): + """ Given a set of vertices, calculate the displacement force given by the sum of margin force and repulsive force. + Arguments: + attention_map: torch.Tensor - The attention map to calculate the force. Shape: (B, H, W, C) + verts: torch.Tensor - The vertices of the attention map. Shape: (B, C, 2) + f_rep_strength: float - The strength of the repulsive force + f_margin_strength: float - The strength of the margin force + """ + B, H, W, C = attention_map.shape + f_rep = repulsive_force(f_rep_strength, verts, calculate_centroid(attention_map)) + f_margin = margin_force(f_margin_strength, H, W, verts) + return f_rep + f_margin + + def min_distance_to_nearest_edge(verts, h, w): """ Calculate the distances of the vertices from the nearest edge given the height and width of the image Arguments: @@ -101,33 +115,36 @@ def min_distance_to_nearest_edge(verts, h, w): return min_distances -def margin_force(attention_map, m, verts): +def margin_force(strength, H, W, verts): """ Margin force calculation Arguments: - attention_map: torch.Tensor - The attention map to calculate the force. Shape: (B, H, W, C) - m: float - The margin force coefficient + strength: float - The margin force coefficient + H: float - The height of the image + W: float - The width of the image verts: torch.Tensor - The vertices of the attention map. Shape: (B, C, 2) Returns: - torch.Tensor - The direction of force for each vertex. Shape: (B, C, 2) + torch.Tensor - The force for each vertex. Shape: (B, C, 2) """ - B, H, W, C = attention_map.shape min_distances = min_distance_to_nearest_edge(verts, H, W) # (B, C) - force = -m / (min_distances ** 2) + force = -strength / (min_distances ** 2) return force -def repulsive_force(attention_map, xi, pos_vertex, pos_target): - """ Repulsive force + +def repulsive_force(strength, pos_vertex, pos_target): + """ Repulsive force repels the vertices in the direction away from the target Arguments: attention_map: torch.Tensor - The attention map to calculate the force. Shape: (B, H, W, C) - xi: float - The global force coefficient + strength: float - The global force coefficient pos_vertex: torch.Tensor - The position of the vertex. Shape: (B, C, 2) - pos_target: torch.Tensor - The position of the target. Shape: (B, C, 2) + pos_target: torch.Tensor - The position of the target. Shape: (2) Returns: - torch.Tensor - The multi-target force. Shape: (B, C) + torch.Tensor - The force away from the target. Shape: (B, C, 2) """ - force = (-xi) ** 2 - norm_pos = (pos_vertex - pos_target).norm(dim=-1) - return force / norm_pos + d_pos = pos_vertex - pos_target # (B, C, 2) + d_pos_norm = d_pos.norm(dim=-1, keepdim=True) # normalize the direction + d_pos /= d_pos_norm + force = (-strength) ** 2 + return force * d_pos def multi_target_force(attention_map, omega, xi, pos_vertex, pos_target): @@ -228,8 +245,16 @@ def get_attention_scores(to_q_map, to_k_map, dtype): if __name__ == '__main__': - # conflict detection + # repulsive force B, H, W, C = 1, 64, 64, 1 + verts = torch.tensor([[[8, 8], [16, 48], [31, 31]]], dtype=torch.float16, device='cuda') # B C 2 + target = torch.tensor([[[32, 32]]], dtype=torch.float16, device='cuda') # B 1 2 + r_force = repulsive_force(1, verts, target) + + + + + # conflict detection attention_map = torch.ones(B, H, W, C).to('cuda') # B H W C region = torch.zeros((B, H, W), dtype=torch.float16, device='cuda') # B H W C # set the left half of region to 1 From 8f320212bd2852d929cf450c4f8dfa2028874f38 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 16:31:12 -0700 Subject: [PATCH 11/55] working on translate img fn --- scripts/tcg.py | 97 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 82 insertions(+), 15 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 98e1ddf..c54c986 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -2,6 +2,7 @@ import logging import gradio as gr import torch +import torch.nn.functional as F if __name__ == '__main__' and os.environ.get('INCANT_DEBUG', None): sys.path.append(f'{os.getcwd()}') @@ -88,31 +89,46 @@ def get_xyz_axis_options(self) -> dict: return {} -def displacement_force(attention_map, verts, f_rep_strength, f_margin_strength): +def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margin_strength): """ Given a set of vertices, calculate the displacement force given by the sum of margin force and repulsive force. Arguments: attention_map: torch.Tensor - The attention map to calculate the force. Shape: (B, H, W, C) verts: torch.Tensor - The vertices of the attention map. Shape: (B, C, 2) + target : torch.Tensor - The vertices of the targets. Shape: (2) f_rep_strength: float - The strength of the repulsive force f_margin_strength: float - The strength of the margin force """ B, H, W, C = attention_map.shape - f_rep = repulsive_force(f_rep_strength, verts, calculate_centroid(attention_map)) + f_rep = repulsive_force(f_rep_strength, verts, target_pos) f_margin = margin_force(f_margin_strength, H, W, verts) return f_rep + f_margin def min_distance_to_nearest_edge(verts, h, w): - """ Calculate the distances of the vertices from the nearest edge given the height and width of the image + """ Calculate the distances and direction to the nearest edge bounded by (H, W) for each channel's vertices Arguments: - verts: torch.Tensor - The vertices of the attention map. Shape: (B, C, 2) + verts: torch.Tensor - The vertices. Shape: (B, C, 2) h: int - The height of the image w: int - The width of the image + Returns: + torch.Tensor, torch.Tensor: + - The minimum distance of each vertex to the nearest edge. Shape: (B, C) + - The direction to the nearest edge. Shape: (B, C, 2) """ - x_coords, y_coords = verts[:, :, 0], verts[:, :, 1] # coordinates - distances_to_edges = torch.stack([y_coords, h - y_coords, x_coords, w - x_coords], dim=-1) # (B, C, 4) - min_distances = torch.min(distances_to_edges, dim=-1).values # (B, C) - return min_distances + x = verts[..., 0] + y = verts[..., 1] + + # Calculate distances to the edges + distances = torch.stack([y, h - y, x, w - x], dim=-1) + + # Find the minimum distance and the corresponding edge + min_distances, min_indices = distances.min(dim=-1) + + # Map edge indices to direction vectors + directions = torch.tensor([[0, -1], [0, 1], [-1, 0], [1, 0]]).to(verts.device) + nearest_edge_dir = directions[min_indices] + + return min_distances, nearest_edge_dir def margin_force(strength, H, W, verts): @@ -125,9 +141,10 @@ def margin_force(strength, H, W, verts): Returns: torch.Tensor - The force for each vertex. Shape: (B, C, 2) """ - min_distances = min_distance_to_nearest_edge(verts, H, W) # (B, C) + min_distances, nearest_edge_dir = min_distance_to_nearest_edge(verts, H, W) # (B, C), (B, C, 2) + min_distances = min_distances.unsqueeze(-1) # (B, C, 1) force = -strength / (min_distances ** 2) - return force + return force * nearest_edge_dir def repulsive_force(strength, pos_vertex, pos_target): @@ -161,7 +178,6 @@ def multi_target_force(attention_map, omega, xi, pos_vertex, pos_target): force = -xi ** 2 pass - def calculate_centroid(attention_map): """ Calculate the centroid of the attention map @@ -214,6 +230,54 @@ def detect_conflict(attention_map, region, theta): return conflict +### TODO: do this +def translate_image(image, tx, ty): + """ + Translate an image tensor by (tx, ty). + + Parameters: + - image: The image tensor of shape (B, C, H, W) + - tx: The translation along the x-axis (B, C, 2) + - ty: The translation along the y-axis (B, C, 2) + + Returns: + - Translated image tensor + """ + B, C, H, W = image.size() + + # Create an affine transformation matrix for translation + theta = torch.tensor([ [1, 0, 0], [0, 1, 0]], dtype=image.dtype, device=image.device) + theta = theta.unsqueeze(0).repeat(B, 1, 1) + + # sgfdgfdfgsdfg + ... + + # Create the grid + grid = F.affine_grid(theta, image.size(), align_corners=False) + + # Apply the grid to the image using grid_sample + translated_image = F.grid_sample(image, grid, mode='bilinear', padding_mode='zeros', align_corners=False) + + return translated_image + + +def apply_displacements(attention_map, displacements): + """ Update the attention map based on the displacements. + The attention map is updated by displacing the attention values based on the displacements. + - Areas that are displaced out of the attention map are discarded. + - Areas that are displaced into the attention map are initialized with zeros. + Arguments: + attention_map: torch.Tensor - The attention map to update. Shape: (B, H, W, C) + displacements: torch.Tensor - The displacements to apply. Shape: (B, C, 2) + """ + B, H, W, C = attention_map.shape + attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) + # apply displacements + attention_map = translate_image(attention_map, displacements[..., 0], displacements[..., 1]) + + attention_map = attention_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) + return attention_map + def get_attention_scores(to_q_map, to_k_map, dtype): """ Calculate the attention scores for the given query and key maps @@ -245,13 +309,16 @@ def get_attention_scores(to_q_map, to_k_map, dtype): if __name__ == '__main__': - # repulsive force - B, H, W, C = 1, 64, 64, 1 - verts = torch.tensor([[[8, 8], [16, 48], [31, 31]]], dtype=torch.float16, device='cuda') # B C 2 + B, H, W, C = 1, 64, 64, 3 + verts = torch.tensor([[[1, 2], [16, 48], [31, 31]]], dtype=torch.float16, device='cuda') # B C 2 target = torch.tensor([[[32, 32]]], dtype=torch.float16, device='cuda') # B 1 2 - r_force = repulsive_force(1, verts, target) + attention_map = torch.ones(B, H, W, C).to('cuda') # B H W C + s_margin = 10.0 + s_repl = 1.0 + displ_force = displacement_force(attention_map, verts, target, s_repl, s_margin) + new_attention_map = apply_displacements(attention_map, displ_force) # conflict detection From 9c7486d66cc96ba537b1cb68a4596e8c8e8aa93d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sat, 25 May 2024 19:17:25 -0700 Subject: [PATCH 12/55] wip translation fn --- scripts/tcg.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index c54c986..823c213 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -234,6 +234,7 @@ def detect_conflict(attention_map, region, theta): def translate_image(image, tx, ty): """ Translate an image tensor by (tx, ty). + https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html Parameters: - image: The image tensor of shape (B, C, H, W) @@ -243,22 +244,37 @@ def translate_image(image, tx, ty): Returns: - Translated image tensor """ + + #image = image.unsqueeze(dim=1) + B, C, H, W = image.size() - # Create an affine transformation matrix for translation - theta = torch.tensor([ [1, 0, 0], [0, 1, 0]], dtype=image.dtype, device=image.device) - theta = theta.unsqueeze(0).repeat(B, 1, 1) + # Create an grid matrix for the translation + c_dim = torch.linspace(-1, 1, C, device=image.device, dtype=image.dtype) # channel dim from [-1 to 1] + h_dim = torch.linspace(-1, 1, H, device=image.device, dtype=image.dtype) # height dim from [-1 to 1] + w_dim = torch.linspace(-1, 1, W, device=image.device, dtype=image.dtype) # width dim to [-1 to 1] + + c_dim = c_dim.view(C, 1, 1).repeat(1, H, W) + h_dim = h_dim.view(1, H, 1).repeat(C, 1, W) + w_dim = w_dim.view(1, 1, W).repeat(1, H, 1) - # sgfdgfdfgsdfg - ... + # translate each dim by the displacements + h_dim = h_dim + ty.squeeze(0).view(C, 1, 1) + w_dim = w_dim + tx.squeeze(0).view(C, 1, 1) - # Create the grid - grid = F.affine_grid(theta, image.size(), align_corners=False) + c_dim = c_dim.unsqueeze(-1) + h_dim = h_dim.unsqueeze(-1) + w_dim = w_dim.unsqueeze(-1) + + # Create 4D grid for 5D input + grid = torch.cat([c_dim, h_dim, w_dim], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1, 1) # (B, C, H, W, 3) + + image = image.unsqueeze(dim=1) # (B, 1, C, H, W) # Apply the grid to the image using grid_sample translated_image = F.grid_sample(image, grid, mode='bilinear', padding_mode='zeros', align_corners=False) - return translated_image + return translated_image.squeeze(1) def apply_displacements(attention_map, displacements): @@ -272,6 +288,7 @@ def apply_displacements(attention_map, displacements): """ B, H, W, C = attention_map.shape attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) + # apply displacements attention_map = translate_image(attention_map, displacements[..., 0], displacements[..., 1]) @@ -309,12 +326,12 @@ def get_attention_scores(to_q_map, to_k_map, dtype): if __name__ == '__main__': - B, H, W, C = 1, 64, 64, 3 - verts = torch.tensor([[[1, 2], [16, 48], [31, 31]]], dtype=torch.float16, device='cuda') # B C 2 + B, H, W, C = 2, 64, 64, 6 + verts = torch.tensor([[[1, 2], [16, 48], [31, 31], [63, 63], [48, 12], [62,2]]], dtype=torch.float16, device='cuda') # B C 2 target = torch.tensor([[[32, 32]]], dtype=torch.float16, device='cuda') # B 1 2 attention_map = torch.ones(B, H, W, C).to('cuda') # B H W C - s_margin = 10.0 + s_margin = 1.0 s_repl = 1.0 displ_force = displacement_force(attention_map, verts, target, s_repl, s_margin) From 40cfdeb5f9ff8a3de684700bdf42beaf59289a1b Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 12:25:18 -0700 Subject: [PATCH 13/55] add translate_image_2d --- scripts/tcg.py | 98 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 90 insertions(+), 8 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 823c213..5fbca5b 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -8,7 +8,7 @@ sys.path.append(f'{os.getcwd()}') sys.path.append(f'{os.getcwd()}/extensions/sd-webui-incantations') from scripts.ui_wrapper import UIWrapper -from scripts.incant_utils import module_hooks +from scripts.incant_utils import module_hooks, plot_tools """ WIP Implementation of https://arxiv.org/abs/2404.11824 @@ -230,6 +230,50 @@ def detect_conflict(attention_map, region, theta): return conflict +def translate_image_2d(image, tx, ty): + """ + Translate an image tensor by (tx, ty). + https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + + Parameters: + - image: The image tensor of shape (B, 1, H, W) + - tx: The translation along the x-axis (B, 2) + - ty: The translation along the y-axis (B, 2) + + Returns: + - Translated image tensor + """ + + #image = image.unsqueeze(dim=1) + + B, C, H, W = image.size() + + # Create an grid matrix for the translation + c_dim = torch.tensor([0], device=image.device, dtype=image.dtype) + h_dim = torch.linspace(-1, 1, H, device=image.device, dtype=image.dtype) # height dim from [-1 to 1] + w_dim = torch.linspace(-1, 1, W, device=image.device, dtype=image.dtype) # width dim to [-1 to 1] + + c_dim = c_dim.view(C, 1, 1).repeat(1, H, W) + h_dim = h_dim.view(1, H, 1).repeat(C, 1, W) + w_dim = w_dim.view(1, 1, W).repeat(1, H, 1) + + # translate each dim by the displacements + h_dim = h_dim + ty.squeeze(0).view(C, 1, 1) + w_dim = w_dim + tx.squeeze(0).view(C, 1, 1) + + c_dim = c_dim.unsqueeze(-1) + h_dim = h_dim.unsqueeze(-1) + w_dim = w_dim.unsqueeze(-1) + + # Create 4D grid for 5D input + grid = torch.cat([h_dim, w_dim], dim=-1).repeat(1, 1, 1, 1) # (B, H, W, 2) + + # Apply the grid to the image using grid_sample + translated_image = F.grid_sample(image, grid, mode='nearest', padding_mode='zeros', align_corners=True) + + return translated_image + + ### TODO: do this def translate_image(image, tx, ty): """ @@ -250,7 +294,10 @@ def translate_image(image, tx, ty): B, C, H, W = image.size() # Create an grid matrix for the translation - c_dim = torch.linspace(-1, 1, C, device=image.device, dtype=image.dtype) # channel dim from [-1 to 1] + if C > 1: + c_dim = torch.linspace(-1, 1, C, device=image.device, dtype=image.dtype) # channel dim from [-1 to 1] + else: + c_dim = torch.tensor([0], device=image.device, dtype=image.dtype) h_dim = torch.linspace(-1, 1, H, device=image.device, dtype=image.dtype) # height dim from [-1 to 1] w_dim = torch.linspace(-1, 1, W, device=image.device, dtype=image.dtype) # width dim to [-1 to 1] @@ -272,7 +319,7 @@ def translate_image(image, tx, ty): image = image.unsqueeze(dim=1) # (B, 1, C, H, W) # Apply the grid to the image using grid_sample - translated_image = F.grid_sample(image, grid, mode='bilinear', padding_mode='zeros', align_corners=False) + translated_image = F.grid_sample(image, grid, mode='nearest', padding_mode='zeros', align_corners=True) return translated_image.squeeze(1) @@ -285,13 +332,16 @@ def apply_displacements(attention_map, displacements): Arguments: attention_map: torch.Tensor - The attention map to update. Shape: (B, H, W, C) displacements: torch.Tensor - The displacements to apply. Shape: (B, C, 2) + Returns: + torch.Tensor - The updated attention map. Shape: (B, H, W, C) """ B, H, W, C = attention_map.shape attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) # apply displacements - attention_map = translate_image(attention_map, displacements[..., 0], displacements[..., 1]) - + attention_map = translate_image_2d(attention_map, displacements[..., 0], displacements[..., 1]) +# attention_map = translate_image(attention_map, displacements[..., 0], displacements[..., 1]) +# attention_map = attention_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) return attention_map @@ -326,10 +376,42 @@ def get_attention_scores(to_q_map, to_k_map, dtype): if __name__ == '__main__': - B, H, W, C = 2, 64, 64, 6 - verts = torch.tensor([[[1, 2], [16, 48], [31, 31], [63, 63], [48, 12], [62,2]]], dtype=torch.float16, device='cuda') # B C 2 + tempdir = os.path.join(os.getcwd(), 'temp') + os.makedirs(tempdir, exist_ok=True) + + # macro for saving to png + _png = lambda attnmap, name, title: plot_tools.plot_attention_map( + attnmap[0, :, :, 0], + save_path=os.path.join(tempdir, f'{name}.png'), + title=f'{title}', + ) + + B, H, W, C = 1, 8, 8, 1 + dtype = torch.float16 + device = 'cuda' + + # initialize a map with all ones + attention_map = torch.ones((B, H, W, C)).to(device, dtype) # B H W C + + # color half of it with zeros + attention_map[:, :, :W//2] = 0 + + _png(attention_map, 0, 'Initial Attn Map') + + displacements = torch.tensor([0.1, 0], dtype=torch.float16, device='cuda').repeat(B, C, 1) # 2 + new_attention_map = apply_displacements(attention_map, displacements) + + _png(new_attention_map, 1, 'Displaced Attn Map') + + plot_tools.plot_attention_map( + new_attention_map[0, :, :, 0], + save_path = _png('1'), + title='Attention Map Displaced [0.5, 0.5]' + ) + + verts = torch.tensor([[[16, 16]]], dtype=torch.float16, device='cuda') # B C 2 + #verts = torch.tensor([[[1, 2], [16, 48], [31, 31], [63, 63], [48, 12], [62,2]]], dtype=torch.float16, device='cuda') # B C 2 target = torch.tensor([[[32, 32]]], dtype=torch.float16, device='cuda') # B 1 2 - attention_map = torch.ones(B, H, W, C).to('cuda') # B H W C s_margin = 1.0 s_repl = 1.0 From e267e9adb8d7dacb7f8758729dd134ba7ed1966a Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 12:49:21 -0700 Subject: [PATCH 14/55] implement translate 2d --- scripts/tcg.py | 46 +++++++++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 5fbca5b..6198ff7 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -97,6 +97,8 @@ def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margi target : torch.Tensor - The vertices of the targets. Shape: (2) f_rep_strength: float - The strength of the repulsive force f_margin_strength: float - The strength of the margin force + Returns: + torch.Tensor - The displacement force for each vertex. Shape: (B, C, 2) """ B, H, W, C = attention_map.shape f_rep = repulsive_force(f_rep_strength, verts, target_pos) @@ -230,15 +232,14 @@ def detect_conflict(attention_map, region, theta): return conflict -def translate_image_2d(image, tx, ty): +def translate_image_2d(image, txy): """ Translate an image tensor by (tx, ty). https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html Parameters: - - image: The image tensor of shape (B, 1, H, W) - - tx: The translation along the x-axis (B, 2) - - ty: The translation along the y-axis (B, 2) + - image: The image tensor of shape (1, C, H, W) + - txy: The translation along y and x-axis (C, 2) Returns: - Translated image tensor @@ -246,30 +247,40 @@ def translate_image_2d(image, tx, ty): #image = image.unsqueeze(dim=1) - B, C, H, W = image.size() + C, H, W = image.size() + # swap N and C + image = image.unsqueeze(dim=1).to(torch.float32) # (C, 1, H, W) + + # grid bounds, not doing this means losing information at the edges at low resolution + pos = 1 - 1e-6 # + neg = -pos # Create an grid matrix for the translation - c_dim = torch.tensor([0], device=image.device, dtype=image.dtype) - h_dim = torch.linspace(-1, 1, H, device=image.device, dtype=image.dtype) # height dim from [-1 to 1] - w_dim = torch.linspace(-1, 1, W, device=image.device, dtype=image.dtype) # width dim to [-1 to 1] + h_dim = torch.linspace(neg, pos, H, device=image.device, dtype=image.dtype) # height dim from [-1 to 1] + w_dim = torch.linspace(neg, pos, W, device=image.device, dtype=image.dtype) # width dim to [-1 to 1] - c_dim = c_dim.view(C, 1, 1).repeat(1, H, W) + if C > 1: + c_dim = torch.linspace(-1, 1, C, device=image.device, dtype=image.dtype) + else: + c_dim = torch.tensor([0], device=image.device, dtype=image.dtype) + + # c_dim = b_dim.view(C, 1, 1).repeat(1, H, W) h_dim = h_dim.view(1, H, 1).repeat(C, 1, W) - w_dim = w_dim.view(1, 1, W).repeat(1, H, 1) + w_dim = w_dim.view(1, 1, W).repeat(C, H, 1) # translate each dim by the displacements - h_dim = h_dim + ty.squeeze(0).view(C, 1, 1) - w_dim = w_dim + tx.squeeze(0).view(C, 1, 1) + # h_dim = h_dim + ty.squeeze(0).view(C, 1, 1) + # w_dim = w_dim + tx.squeeze(0).view(C, 1, 1) c_dim = c_dim.unsqueeze(-1) h_dim = h_dim.unsqueeze(-1) w_dim = w_dim.unsqueeze(-1) # Create 4D grid for 5D input - grid = torch.cat([h_dim, w_dim], dim=-1).repeat(1, 1, 1, 1) # (B, H, W, 2) + grid = torch.cat([w_dim, h_dim], dim=-1).repeat(1, 1, 1, 1) # (C, H, W, 2) # Apply the grid to the image using grid_sample - translated_image = F.grid_sample(image, grid, mode='nearest', padding_mode='zeros', align_corners=True) + translated_image = F.grid_sample(image, grid, mode='nearest', padding_mode='zeros', align_corners=False) return translated_image @@ -337,12 +348,13 @@ def apply_displacements(attention_map, displacements): """ B, H, W, C = attention_map.shape attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) - + out_attn_map = attention_map.detach().clone() + for batch_idx in range(B): # apply displacements - attention_map = translate_image_2d(attention_map, displacements[..., 0], displacements[..., 1]) + out_attn_map[batch_idx] = translate_image_2d(attention_map[batch_idx], displacements[batch_idx]) # attention_map = translate_image(attention_map, displacements[..., 0], displacements[..., 1]) # - attention_map = attention_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) + out_attn_map = out_attn_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) return attention_map From e599bb90d4d46083ddbcefad2fea707907e71f64 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 12:55:42 -0700 Subject: [PATCH 15/55] apply transformation --- scripts/tcg.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 6198ff7..8512dc3 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -265,19 +265,23 @@ def translate_image_2d(image, txy): c_dim = torch.tensor([0], device=image.device, dtype=image.dtype) # c_dim = b_dim.view(C, 1, 1).repeat(1, H, W) - h_dim = h_dim.view(1, H, 1).repeat(C, 1, W) - w_dim = w_dim.view(1, 1, W).repeat(C, H, 1) + h_dim = h_dim.view(1, H).repeat(C, 1) + w_dim = w_dim.view(1, W).repeat(C, 1) # translate each dim by the displacements - # h_dim = h_dim + ty.squeeze(0).view(C, 1, 1) - # w_dim = w_dim + tx.squeeze(0).view(C, 1, 1) + tx, ty = txy[..., 0], txy[..., 1] + h_dim = h_dim + tx.unsqueeze(-1) + w_dim = w_dim + ty.unsqueeze(-1) - c_dim = c_dim.unsqueeze(-1) + h_dim = h_dim.unsqueeze(dim=-1).repeat(1, 1, W) # (C, H, W) + w_dim = w_dim.unsqueeze(dim=1).repeat(1, H, 1) # (C, H, W) + + #c_dim = c_dim.unsqueeze(-1) h_dim = h_dim.unsqueeze(-1) w_dim = w_dim.unsqueeze(-1) # Create 4D grid for 5D input - grid = torch.cat([w_dim, h_dim], dim=-1).repeat(1, 1, 1, 1) # (C, H, W, 2) + grid = torch.cat([w_dim, h_dim], dim=-1) # (C, H, W, 2) # Apply the grid to the image using grid_sample translated_image = F.grid_sample(image, grid, mode='nearest', padding_mode='zeros', align_corners=False) From 973a948f1f98dcb639bfbb7c91a280e29e1e9d7a Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 13:45:45 -0700 Subject: [PATCH 16/55] fix weird flip --- scripts/tcg.py | 100 ++++++++++++++++++++++--------------------------- 1 file changed, 44 insertions(+), 56 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 8512dc3..f3b47f4 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -109,16 +109,16 @@ def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margi def min_distance_to_nearest_edge(verts, h, w): """ Calculate the distances and direction to the nearest edge bounded by (H, W) for each channel's vertices Arguments: - verts: torch.Tensor - The vertices. Shape: (B, C, 2) + verts: torch.Tensor - The vertices. Shape: (B, C, 2), where the last 2 dims are (y, x) h: int - The height of the image w: int - The width of the image Returns: torch.Tensor, torch.Tensor: - The minimum distance of each vertex to the nearest edge. Shape: (B, C) - - The direction to the nearest edge. Shape: (B, C, 2) + - The direction to the nearest edge. Shape: (B, C, 2), where the last 2 dims are (y, x) """ - x = verts[..., 0] - y = verts[..., 1] + y = verts[..., 0] # y-axis is 0! + x = verts[..., 1] # Calculate distances to the edges distances = torch.stack([y, h - y, x, w - x], dim=-1) @@ -127,7 +127,7 @@ def min_distance_to_nearest_edge(verts, h, w): min_distances, min_indices = distances.min(dim=-1) # Map edge indices to direction vectors - directions = torch.tensor([[0, -1], [0, 1], [-1, 0], [1, 0]]).to(verts.device) + directions = torch.tensor([[-1, 0], [1, 0], [0, 1], [0, -1]]).to(verts.device) nearest_edge_dir = directions[min_indices] return min_distances, nearest_edge_dir @@ -232,61 +232,54 @@ def detect_conflict(attention_map, region, theta): return conflict -def translate_image_2d(image, txy): +def translate_image_2d(image, tyx): """ - Translate an image tensor by (tx, ty). + Translate an image tensor by (ty, tx). https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html Parameters: - - image: The image tensor of shape (1, C, H, W) - - txy: The translation along y and x-axis (C, 2) + - image: The image tensor of shape (B, C, H, W) where B = 1 + - tyx: The translation along y and x-axis for each channel (C, 2), where the last 2 dims are the translation by [Y, X] Returns: - Translated image tensor """ - #image = image.unsqueeze(dim=1) - - C, H, W = image.size() + B, C, H, W = image.size() - # swap N and C - image = image.unsqueeze(dim=1).to(torch.float32) # (C, 1, H, W) + # swap B and C + image = image.transpose(0, 1) # (C, B, H, W) # grid bounds, not doing this means losing information at the edges at low resolution pos = 1 - 1e-6 # neg = -pos # Create an grid matrix for the translation - h_dim = torch.linspace(neg, pos, H, device=image.device, dtype=image.dtype) # height dim from [-1 to 1] - w_dim = torch.linspace(neg, pos, W, device=image.device, dtype=image.dtype) # width dim to [-1 to 1] - - if C > 1: - c_dim = torch.linspace(-1, 1, C, device=image.device, dtype=image.dtype) - else: - c_dim = torch.tensor([0], device=image.device, dtype=image.dtype) + # (-1, -1) is left top pixel, (1, 1) is right bottom pixel + h_dim = torch.linspace(neg, pos, H, device=image.device, dtype=image.dtype) # height dim from [-1 (top) to 1 (bottom)] + w_dim = torch.linspace(neg, pos, W, device=image.device, dtype=image.dtype) # width dim to [-1 (left) to 1 (right)] - # c_dim = b_dim.view(C, 1, 1).repeat(1, H, W) - h_dim = h_dim.view(1, H).repeat(C, 1) - w_dim = w_dim.view(1, W).repeat(C, 1) + h_dim = h_dim.view(1, H, 1).repeat(C, 1, W) + w_dim = w_dim.view(1, 1, W).repeat(C, H, 1) # translate each dim by the displacements - tx, ty = txy[..., 0], txy[..., 1] - h_dim = h_dim + tx.unsqueeze(-1) - w_dim = w_dim + ty.unsqueeze(-1) + ty, tx = tyx[..., 0], tyx[..., 1] # C, C + h_dim = h_dim + ty.view(C, 1, 1) + w_dim = w_dim + tx.view(C, 1, 1) - h_dim = h_dim.unsqueeze(dim=-1).repeat(1, 1, W) # (C, H, W) - w_dim = w_dim.unsqueeze(dim=1).repeat(1, H, 1) # (C, H, W) + #h_dim = h_dim.unsqueeze(dim=-1).repeat(1, W, 1) # (C, H, W) + #w_dim = w_dim.unsqueeze(dim=1).repeat(1, 1, H) # (C, H, W) #c_dim = c_dim.unsqueeze(-1) - h_dim = h_dim.unsqueeze(-1) - w_dim = w_dim.unsqueeze(-1) + h_dim = h_dim.unsqueeze(-1) # (C, H, W, 1) + w_dim = w_dim.unsqueeze(-1) # (C, H, W, 1) # Create 4D grid for 5D input - grid = torch.cat([w_dim, h_dim], dim=-1) # (C, H, W, 2) + grid = torch.cat([h_dim, w_dim], dim=-1) # (C, H, W, 2) # Apply the grid to the image using grid_sample - translated_image = F.grid_sample(image, grid, mode='nearest', padding_mode='zeros', align_corners=False) + translated_image = F.grid_sample(image, grid, mode='nearest', padding_mode='zeros', align_corners=False) # C N H W - return translated_image + return translated_image.transpose(0, 1) # N C H W ### TODO: do this @@ -346,20 +339,21 @@ def apply_displacements(attention_map, displacements): - Areas that are displaced into the attention map are initialized with zeros. Arguments: attention_map: torch.Tensor - The attention map to update. Shape: (B, H, W, C) - displacements: torch.Tensor - The displacements to apply. Shape: (B, C, 2) + displacements: torch.Tensor - The displacements to apply. Shape: (B, C, 2), where the last 2 dims are the translation by [Y, X] Returns: torch.Tensor - The updated attention map. Shape: (B, H, W, C) """ B, H, W, C = attention_map.shape - attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) + attention_map = attention_map.permute(0, 3, 2, 1) # (B, H, W, C) -> (B, C, H, W) + #attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) out_attn_map = attention_map.detach().clone() - for batch_idx in range(B): + # apply displacements - out_attn_map[batch_idx] = translate_image_2d(attention_map[batch_idx], displacements[batch_idx]) -# attention_map = translate_image(attention_map, displacements[..., 0], displacements[..., 1]) + for batch_idx in range(B): + out_attn_map[batch_idx] = translate_image_2d(attention_map[batch_idx].unsqueeze(0), displacements[batch_idx]).squeeze(0) # out_attn_map = out_attn_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) - return attention_map + return out_attn_map def get_attention_scores(to_q_map, to_k_map, dtype): @@ -402,28 +396,22 @@ def get_attention_scores(to_q_map, to_k_map, dtype): title=f'{title}', ) - B, H, W, C = 1, 8, 8, 1 + B, H, W, C = 1, 64, 64, 1 dtype = torch.float16 device = 'cuda' # initialize a map with all ones - attention_map = torch.ones((B, H, W, C)).to(device, dtype) # B H W C - - # color half of it with zeros - attention_map[:, :, :W//2] = 0 - + attention_map = torch.zeros((B, H, W, C)).to(device, dtype) # B H W C + attention_map[:, :, W//4:2*W//4] = 1.0 _png(attention_map, 0, 'Initial Attn Map') - displacements = torch.tensor([0.1, 0], dtype=torch.float16, device='cuda').repeat(B, C, 1) # 2 - new_attention_map = apply_displacements(attention_map, displacements) - - _png(new_attention_map, 1, 'Displaced Attn Map') - - plot_tools.plot_attention_map( - new_attention_map[0, :, :, 0], - save_path = _png('1'), - title='Attention Map Displaced [0.5, 0.5]' - ) + # apply a simple transformation + # translate Y by -1, translate x by 0 + for i in range(3): + ofs = round(0.5 * (i+1), 2) + displacements = torch.tensor([ofs, 0], dtype=torch.float16, device='cuda').repeat(B, C, 1) # 2 + new_attention_map = apply_displacements(attention_map, displacements) + _png(new_attention_map, 1+i, f'Move Initial [{ofs}, 0]') verts = torch.tensor([[[16, 16]]], dtype=torch.float16, device='cuda') # B C 2 #verts = torch.tensor([[[1, 2], [16, 48], [31, 31], [63, 63], [48, 12], [62,2]]], dtype=torch.float16, device='cuda') # B C 2 From 0d8ae5bef88af512ad0ccdc72be643c2a5aec1d7 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 13:53:24 -0700 Subject: [PATCH 17/55] fix weird translation, switch to align corners true --- scripts/tcg.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index f3b47f4..2195ecb 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -251,8 +251,11 @@ def translate_image_2d(image, tyx): image = image.transpose(0, 1) # (C, B, H, W) # grid bounds, not doing this means losing information at the edges at low resolution - pos = 1 - 1e-6 # - neg = -pos + # hack to prevent out of boudns when align_corners is false + # pos = 1 - 1e-3 # hack to prevent out of bounds + # neg = -pos + pos, neg = 1, -1 + # Create an grid matrix for the translation # (-1, -1) is left top pixel, (1, 1) is right bottom pixel h_dim = torch.linspace(neg, pos, H, device=image.device, dtype=image.dtype) # height dim from [-1 (top) to 1 (bottom)] @@ -263,6 +266,8 @@ def translate_image_2d(image, tyx): # translate each dim by the displacements ty, tx = tyx[..., 0], tyx[..., 1] # C, C + ty *= -1 # invert y for some weird reason + tx *= -1 # invert x for some weird reason h_dim = h_dim + ty.view(C, 1, 1) w_dim = w_dim + tx.view(C, 1, 1) @@ -277,7 +282,7 @@ def translate_image_2d(image, tyx): grid = torch.cat([h_dim, w_dim], dim=-1) # (C, H, W, 2) # Apply the grid to the image using grid_sample - translated_image = F.grid_sample(image, grid, mode='nearest', padding_mode='zeros', align_corners=False) # C N H W + translated_image = F.grid_sample(image, grid, mode='nearest', padding_mode='zeros', align_corners=True) # C N H W return translated_image.transpose(0, 1) # N C H W @@ -388,6 +393,7 @@ def get_attention_scores(to_q_map, to_k_map, dtype): if __name__ == '__main__': tempdir = os.path.join(os.getcwd(), 'temp') os.makedirs(tempdir, exist_ok=True) + img_idx = 0 # macro for saving to png _png = lambda attnmap, name, title: plot_tools.plot_attention_map( @@ -402,7 +408,7 @@ def get_attention_scores(to_q_map, to_k_map, dtype): # initialize a map with all ones attention_map = torch.zeros((B, H, W, C)).to(device, dtype) # B H W C - attention_map[:, :, W//4:2*W//4] = 1.0 + attention_map[:, :, W//4:3*W//4] = 1.0 _png(attention_map, 0, 'Initial Attn Map') # apply a simple transformation @@ -411,7 +417,15 @@ def get_attention_scores(to_q_map, to_k_map, dtype): ofs = round(0.5 * (i+1), 2) displacements = torch.tensor([ofs, 0], dtype=torch.float16, device='cuda').repeat(B, C, 1) # 2 new_attention_map = apply_displacements(attention_map, displacements) - _png(new_attention_map, 1+i, f'Move Initial [{ofs}, 0]') + _png(new_attention_map, img_idx+1, f'Move Initial [{ofs}, 0]') + img_idx += 1 + + for i in range(3): + ofs = round(0.5 * (i+1), 2) + displacements = torch.tensor([0, ofs], dtype=torch.float16, device='cuda').repeat(B, C, 1) # 2 + new_attention_map = apply_displacements(attention_map, displacements) + _png(new_attention_map, img_idx+1, f'Move Initial [0, {ofs}]') + img_idx += 1 verts = torch.tensor([[[16, 16]]], dtype=torch.float16, device='cuda') # B C 2 #verts = torch.tensor([[[1, 2], [16, 48], [31, 31], [63, 63], [48, 12], [62,2]]], dtype=torch.float16, device='cuda') # B C 2 From c1c65e566a3510414307a3cbe48e93504aecab63 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 14:19:51 -0700 Subject: [PATCH 18/55] centroid seems to work --- scripts/tcg.py | 64 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 2195ecb..0e2a968 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -389,6 +389,38 @@ def get_attention_scores(to_q_map, to_k_map, dtype): return attn_probs +def plot_point(image, point, radius=1, color=1.0): + """ Plot a point on an image tensor + Arguments: + image: torch.Tensor - The image tensor to plot the point on. Shape: (B, H, W, C) + point: tuple - The point to plot (y, x) + radius: int - The radius of the point + color: tuple - The color of the point + Returns: + torch.Tensor - The image tensor with the point plotted + """ + y, x = point + image[:, y - radius:y + radius + 1, x - radius:x + radius + 1, :] = color + + +def color_region(image, yx, ab, color=1.0, mode='set'): + """ Color in a region of an image tensor + Arguments: + image: torch.Tensor - The image tensor to plot the point on. Shape: (B, H, W, C) + yx: (int, int) - The y-coordinate and x-coordinate of the upper left corner of the region + ab: (int, int) - The x-coordinate and y-coordinate of the lower right corner of the region + color: tuple - The color of the region + mode: str - The mode of coloring. 'set' to set the region to the color, 'add' to add the color to the region + Returns: + torch.Tensor - The image tensor with the point plotted + """ + y, x = yx + a, b = ab + if mode == 'set': + image[:, y:a, x:b, :] = color + elif mode == 'add': + image[:, y:a, x:b, :] += color + if __name__ == '__main__': tempdir = os.path.join(os.getcwd(), 'temp') @@ -406,9 +438,39 @@ def get_attention_scores(to_q_map, to_k_map, dtype): dtype = torch.float16 device = 'cuda' + # plotted points as proxies for vertices + vert_list = [ + [3*H//4, W//2], # (lower middle) + [H//4, W//4], # (upper left middle quadrant) + [H//2, W] # (right middle) + ] + target_position = [H//2, W//2] + # upper left, lower right + verts = torch.tensor([vert_list], dtype=torch.float16, device='cuda') # B C 2 + + # region to represent the target region + region_yx = [H//8, W//8] + region_ab = [3*H//8, 3*W//8] + # initialize a map with all ones attention_map = torch.zeros((B, H, W, C)).to(device, dtype) # B H W C - attention_map[:, :, W//4:3*W//4] = 1.0 + + + # calculate centroid of region + centroid = calculate_centroid(attention_map) # (B, C, 2) + centroid = centroid.squeeze(0).squeeze(0).cpu().numpy().astype(int) + + # color a middleish region + attention_map[:, :, W//4:3*W//4] = 0.5 + + # plot verts + for v in vert_list: + plot_point(attention_map, v, radius=1) + + # color the target region and plot centroid last + color_region(attention_map, region_yx, region_ab, color=0.5, mode='add') + plot_point(attention_map, centroid, radius=3, color=1) + _png(attention_map, 0, 'Initial Attn Map') # apply a simple transformation From 167bc8e13c0b309f06566bd3df6551af7d0fa296 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 14:26:36 -0700 Subject: [PATCH 19/55] fix centroid calculation divide by zero --- scripts/tcg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 0e2a968..fc18f99 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -196,7 +196,7 @@ def calculate_centroid(attention_map): w_coords = torch.arange(W).view(1, 1, W, 1).to(attention_map.device) # Sum of attention scores for each channel - attention_sum = torch.sum(attention_map, dim=(1, 2)) # shape: (B, C) + attention_sum = torch.sum(attention_map, dim=(1, 2)) + torch.finfo(attention_map.dtype).eps # shape: (B, C) # Weighted sum of the coordinates h_weighted_sum = torch.sum(h_coords * attention_map, dim=(1,2)) # (B, C) From 47f715994344d69fe64d06a73d8092a58e06d59f Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 14:50:52 -0700 Subject: [PATCH 20/55] testing region moving --- scripts/tcg.py | 92 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 55 insertions(+), 37 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index fc18f99..aaf06dc 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -93,7 +93,7 @@ def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margi """ Given a set of vertices, calculate the displacement force given by the sum of margin force and repulsive force. Arguments: attention_map: torch.Tensor - The attention map to calculate the force. Shape: (B, H, W, C) - verts: torch.Tensor - The vertices of the attention map. Shape: (B, C, 2) + verts: torch.Tensor - The centroid vertices of the attention map. Shape: (B, C, 2) target : torch.Tensor - The vertices of the targets. Shape: (2) f_rep_strength: float - The strength of the repulsive force f_margin_strength: float - The strength of the margin force @@ -160,7 +160,7 @@ def repulsive_force(strength, pos_vertex, pos_target): torch.Tensor - The force away from the target. Shape: (B, C, 2) """ d_pos = pos_vertex - pos_target # (B, C, 2) - d_pos_norm = d_pos.norm(dim=-1, keepdim=True) # normalize the direction + d_pos_norm = d_pos.norm(dim=-1, keepdim=True) + torch.finfo(d_pos.dtype).eps # normalize the direction d_pos /= d_pos_norm force = (-strength) ** 2 return force * d_pos @@ -400,7 +400,11 @@ def plot_point(image, point, radius=1, color=1.0): torch.Tensor - The image tensor with the point plotted """ y, x = point - image[:, y - radius:y + radius + 1, x - radius:x + radius + 1, :] = color + y_min = (y- radius).to(torch.int32) + y_max = (y + radius + 1).to(torch.int32) + x_min = (x - radius).to(torch.int32) + x_max = (x + radius + 1).to(torch.int32) + image[:, y_min:y_max, x_min:x_max, :] = color def color_region(image, yx, ab, color=1.0, mode='set'): @@ -440,9 +444,9 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # plotted points as proxies for vertices vert_list = [ - [3*H//4, W//2], # (lower middle) - [H//4, W//4], # (upper left middle quadrant) - [H//2, W] # (right middle) + #[3*H//4, W//2], # (lower middle) + [H//4+1, W//4+1], # (upper left middle quadrant) + #[H//2, W] # (right middle) ] target_position = [H//2, W//2] # upper left, lower right @@ -455,49 +459,63 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # initialize a map with all ones attention_map = torch.zeros((B, H, W, C)).to(device, dtype) # B H W C + # color the target region + color_region(attention_map, region_yx, region_ab, color=0.5, mode='add') + + _png(attention_map, 0, 'Initial Attn Map') # calculate centroid of region centroid = calculate_centroid(attention_map) # (B, C, 2) - centroid = centroid.squeeze(0).squeeze(0).cpu().numpy().astype(int) - - # color a middleish region - attention_map[:, :, W//4:3*W//4] = 0.5 + centroid_points = centroid.squeeze(0).squeeze(0) + #centroid_points = centroid.squeeze(0).squeeze(0).cpu().numpy().astype(int) + # plot region centroid + plot_point(attention_map, centroid_points, radius=1, color=1) + _png(attention_map, 1, 'Attn Map Region + Centroid') # plot verts - for v in vert_list: - plot_point(attention_map, v, radius=1) - - # color the target region and plot centroid last - color_region(attention_map, region_yx, region_ab, color=0.5, mode='add') - plot_point(attention_map, centroid, radius=3, color=1) - - _png(attention_map, 0, 'Initial Attn Map') + attn_map_points = attention_map.detach().clone() + for v in verts.squeeze(0): + plot_point(attn_map_points, v, radius=1) + _png(attn_map_points, 2, 'Points') # apply a simple transformation # translate Y by -1, translate x by 0 - for i in range(3): - ofs = round(0.5 * (i+1), 2) - displacements = torch.tensor([ofs, 0], dtype=torch.float16, device='cuda').repeat(B, C, 1) # 2 - new_attention_map = apply_displacements(attention_map, displacements) - _png(new_attention_map, img_idx+1, f'Move Initial [{ofs}, 0]') - img_idx += 1 - - for i in range(3): - ofs = round(0.5 * (i+1), 2) - displacements = torch.tensor([0, ofs], dtype=torch.float16, device='cuda').repeat(B, C, 1) # 2 - new_attention_map = apply_displacements(attention_map, displacements) - _png(new_attention_map, img_idx+1, f'Move Initial [0, {ofs}]') - img_idx += 1 - - verts = torch.tensor([[[16, 16]]], dtype=torch.float16, device='cuda') # B C 2 - #verts = torch.tensor([[[1, 2], [16, 48], [31, 31], [63, 63], [48, 12], [62,2]]], dtype=torch.float16, device='cuda') # B C 2 - target = torch.tensor([[[32, 32]]], dtype=torch.float16, device='cuda') # B 1 2 + # for i in range(3): + # ofs = round(0.5 * (i+1), 2) + # displacements = torch.tensor([ofs, 0], dtype=torch.float16, device='cuda').repeat(B, C, 1) # 2 + # new_attention_map = apply_displacements(attention_map, displacements) + # _png(new_attention_map, img_idx+1, f'Move Initial [{ofs}, 0]') + # img_idx += 1 + + # for i in range(3): + # ofs = round(0.5 * (i+1), 2) + # displacements = torch.tensor([0, ofs], dtype=torch.float16, device='cuda').repeat(B, C, 1) # 2 + # new_attention_map = apply_displacements(attention_map, displacements) + # _png(new_attention_map, img_idx+1, f'Move Initial [0, {ofs}]') + # img_idx += 1 s_margin = 1.0 s_repl = 1.0 - displ_force = displacement_force(attention_map, verts, target, s_repl, s_margin) - new_attention_map = apply_displacements(attention_map, displ_force) + # simulate displacement forces on our points + img_idx = 3 + for i in range(3): + attn_map_points = torch.zeros_like(attention_map) + #attn_map_points = attention_map.detach().clone() + for v in verts: + plot_point(attn_map_points, v.squeeze(0), radius=1) + displ_force = displacement_force(attention_map, verts, centroid, s_repl, s_margin) # B C 2 + new_attention_map = apply_displacements(attention_map, displ_force) + + color_region(attn_map_points, region_yx, region_ab, color=0.5, mode='add') + plot_point(attn_map_points, centroid_points, radius=1, color=1) + + _png(new_attention_map, img_idx+i, f'Displacement Forces') + + new_vert_pos = verts + displ_force + delta_vert_pos = new_vert_pos - verts + verts += delta_vert_pos + # verts = torch.tensor([vert_list], dtype=torch.float16, device='cuda') # B C 2 # conflict detection From 9a0c88a5cf4a5808b2af0f0a7795c1687c2ca702 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 15:00:24 -0700 Subject: [PATCH 21/55] update test --- scripts/tcg.py | 75 +++++++++++++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index aaf06dc..43a4f02 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -434,7 +434,7 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # macro for saving to png _png = lambda attnmap, name, title: plot_tools.plot_attention_map( attnmap[0, :, :, 0], - save_path=os.path.join(tempdir, f'{name}.png'), + save_path=os.path.join(tempdir, f'{name:04}.png'), title=f'{title}', ) @@ -473,7 +473,11 @@ def color_region(image, yx, ab, color=1.0, mode='set'): _png(attention_map, 1, 'Attn Map Region + Centroid') # plot verts - attn_map_points = attention_map.detach().clone() + attn_map_points = torch.randn_like(attention_map) + + # attn_centroid = calculate_centroid(attn_map_points) # (B, C, 2) + + #attn_map_points = attention_map.detach().clone() for v in verts.squeeze(0): plot_point(attn_map_points, v, radius=1) _png(attn_map_points, 2, 'Points') @@ -494,48 +498,57 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # _png(new_attention_map, img_idx+1, f'Move Initial [0, {ofs}]') # img_idx += 1 - s_margin = 1.0 - s_repl = 1.0 + s_margin = 0.2 + s_repl = 0.2 # simulate displacement forces on our points img_idx = 3 - for i in range(3): - attn_map_points = torch.zeros_like(attention_map) - #attn_map_points = attention_map.detach().clone() - for v in verts: - plot_point(attn_map_points, v.squeeze(0), radius=1) - displ_force = displacement_force(attention_map, verts, centroid, s_repl, s_margin) # B C 2 - new_attention_map = apply_displacements(attention_map, displ_force) + iters = 10 + for i in range(iters): + # copy the map + attn_map_points = attn_map_points.detach().clone() + displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin) # B C 2 - color_region(attn_map_points, region_yx, region_ab, color=0.5, mode='add') - plot_point(attn_map_points, centroid_points, radius=1, color=1) + # new map with displacements + displaced_map = apply_displacements(attn_map_points, displ_force) - _png(new_attention_map, img_idx+i, f'Displacement Forces') + # copy to put on top visualizations + copied_map = displaced_map.detach().clone() + + for v in verts: + plot_point(copied_map, v.squeeze(0), radius=1) + color_region(copied_map, region_yx, region_ab, color=1.0, mode='add') + plot_point(copied_map, centroid_points, radius=1, color=1) + + _png(copied_map, img_idx+i, f'Displacement Forces') new_vert_pos = verts + displ_force delta_vert_pos = new_vert_pos - verts verts += delta_vert_pos + + # copy back to the original map + attn_map_points = displaced_map # verts = torch.tensor([vert_list], dtype=torch.float16, device='cuda') # B C 2 - # conflict detection - attention_map = torch.ones(B, H, W, C).to('cuda') # B H W C - region = torch.zeros((B, H, W), dtype=torch.float16, device='cuda') # B H W C - # set the left half of region to 1 - region[:, :, :W//2] = 1 - theta = 0.5 # Example threshold - conflict_detection = detect_conflict(attention_map, region, theta) - print(conflict_detection) + # # conflict detection + # attention_map = torch.ones(B, H, W, C).to('cuda') # B H W C + # region = torch.zeros((B, H, W), dtype=torch.float16, device='cuda') # B H W C + # # set the left half of region to 1 + # region[:, :, :W//2] = 1 + # theta = 0.5 # Example threshold + # conflict_detection = detect_conflict(attention_map, region, theta) + # print(conflict_detection) - # Create a simple attention map with known values - attention_map = torch.zeros((B, H, W, C), device='cuda') # Shape (batch_size, height, width, channels) - attention_map[0, H//2, W//2, 0] = 1.0 # Put all attention on the center + # # Create a simple attention map with known values + # attention_map = torch.zeros((B, H, W, C), device='cuda') # Shape (batch_size, height, width, channels) + # attention_map[0, H//2, W//2, 0] = 1.0 # Put all attention on the center - # Calculate centroids - centroids = calculate_centroid(attention_map) # (B, C, 2) + # # Calculate centroids + # centroids = calculate_centroid(attention_map) # (B, C, 2) - # Expected centroid is the center of the attention map (2, 2) - expected_centroid = torch.tensor([[[H/2, W/2]]], device='cuda') + # # Expected centroid is the center of the attention map (2, 2) + # expected_centroid = torch.tensor([[[H/2, W/2]]], device='cuda') - # Check if the calculated centroid matches the expected centroid - assert torch.allclose(centroids, expected_centroid), f"Expected {expected_centroid}, but got {centroids}" \ No newline at end of file + # # Check if the calculated centroid matches the expected centroid + # assert torch.allclose(centroids, expected_centroid), f"Expected {expected_centroid}, but got {centroids}" \ No newline at end of file From cd134079ecd23af83e3f0731faa415c303896509 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 17:44:02 -0700 Subject: [PATCH 22/55] margin and repulsive forces seem to work --- scripts/tcg.py | 206 +++++++++++++++++++++++++++++++------------------ 1 file changed, 129 insertions(+), 77 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 43a4f02..a90c6e6 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -18,6 +18,10 @@ """ logger = logging.getLogger(__name__) +if os.environ.get('INCANT_DEBUG', None): + # suppress excess logging + logging.getLogger("PIL").setLevel(logging.WARNING) + logging.getLogger("matplotlib").setLevel(logging.WARNING) class TCGExtensionScript(UIWrapper): def __init__(self): @@ -101,11 +105,42 @@ def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margi torch.Tensor - The displacement force for each vertex. Shape: (B, C, 2) """ B, H, W, C = attention_map.shape - f_rep = repulsive_force(f_rep_strength, verts, target_pos) - f_margin = margin_force(f_margin_strength, H, W, verts) + clamp_min = -1 + clamp_max = 1 + clamp = lambda x: torch.clamp(x, min=clamp_min, max=clamp_max) + + f_rep = clamp(repulsive_force(f_rep_strength, verts, target_pos)) + f_margin = clamp(margin_force(f_margin_strength, H, W, verts)) + + logger.debug(f"Repulsive force: {f_rep}, Margin force: {f_margin}") return f_rep + f_margin +def distances_to_nearest_edges(verts, h, w): + """ Calculate the distances and direction to the nearest edge bounded by (H, W) for each channel's vertices + Arguments: + verts: torch.Tensor - The vertices. Shape: (B, C, 2), where the last 2 dims are (y, x) + h: int - The height of the image + w: int - The width of the image + Returns: + torch.Tensor, torch.Tensor: + - The minimum distance of each vertex to the nearest edge. Shape: (B, C, 1) + - The direction to the nearest edge. Shape: (B, C, 4, 2), where the last 2 dims are (y, x) + """ + # y axis is 0! + y = verts[..., 0] # (B, C, 2) + x = verts[..., 1] # (B, C, 2) + B, C, _ = verts.shape + + distances = torch.stack([y, h - y, x, w - x], dim=-1) # (B, C, 4) + + directions = torch.tensor([[-1, 0], [1, 0], [0, -1], [0, 1]]).view(1, 1, 4, 2).repeat(B, C, 1, 1) # (4, 2) -> (B, C, 4, 2) + directions = directions.to(verts.device) + + return distances, directions + + + def min_distance_to_nearest_edge(verts, h, w): """ Calculate the distances and direction to the nearest edge bounded by (H, W) for each channel's vertices Arguments: @@ -120,14 +155,18 @@ def min_distance_to_nearest_edge(verts, h, w): y = verts[..., 0] # y-axis is 0! x = verts[..., 1] - # Calculate distances to the edges - distances = torch.stack([y, h - y, x, w - x], dim=-1) + # Calculate distances to the edges (y, h-y, x, w-x) + # y: distance to top edge + # h - y: distance to bottom edge + # x: distance to left edge + # w - x: distance to right edge + distances = torch.abs(torch.stack([y, h - y, x, w - x], dim=-1)) - # Find the minimum distance and the corresponding edge + # Find the minimum distance and the corresponding closest edge min_distances, min_indices = distances.min(dim=-1) # Map edge indices to direction vectors - directions = torch.tensor([[-1, 0], [1, 0], [0, 1], [0, -1]]).to(verts.device) + directions = torch.tensor([[-1, 0], [1, 0], [0, -1], [0, 1]]).to(verts.device) nearest_edge_dir = directions[min_indices] return min_distances, nearest_edge_dir @@ -143,10 +182,15 @@ def margin_force(strength, H, W, verts): Returns: torch.Tensor - The force for each vertex. Shape: (B, C, 2) """ - min_distances, nearest_edge_dir = min_distance_to_nearest_edge(verts, H, W) # (B, C), (B, C, 2) - min_distances = min_distances.unsqueeze(-1) # (B, C, 1) - force = -strength / (min_distances ** 2) - return force * nearest_edge_dir + distances, edge_dirs = distances_to_nearest_edges(verts, H, W) # (B, C, 4), (B, C, 4, 2) + distances = distances.unsqueeze(-1) + + #distances = distances.unsqueeze(-1) # (B, C, 1) + force_multiplier = -strength / (distances ** 2 + torch.finfo(distances.dtype).eps) + forces = force_multiplier * edge_dirs # (B, C, 4, 2) + forces = forces.sum(dim=-2) # (B, C, 2) # sum over the 4 directions to get total force + + return forces def repulsive_force(strength, pos_vertex, pos_target): @@ -159,11 +203,12 @@ def repulsive_force(strength, pos_vertex, pos_target): Returns: torch.Tensor - The force away from the target. Shape: (B, C, 2) """ - d_pos = pos_vertex - pos_target # (B, C, 2) + d_pos = pos_target - pos_vertex # (B, C, 2) d_pos_norm = d_pos.norm(dim=-1, keepdim=True) + torch.finfo(d_pos.dtype).eps # normalize the direction - d_pos /= d_pos_norm - force = (-strength) ** 2 - return force * d_pos + d_pos = d_pos / d_pos_norm + # d_pos /= d_pos_norm + force = -(strength ** 2) + return force / d_pos def multi_target_force(attention_map, omega, xi, pos_vertex, pos_target): @@ -186,28 +231,28 @@ def calculate_centroid(attention_map): Arguments: attention_map: torch.Tensor - The attention map to calculate the centroid. Shape: (B, H, W, C) Returns: - torch.Tensor - The centroid of the attention map. Shape: (B, C, 2) + torch.Tensor - The centroid of the attention map. Shape: (B, C, 2), where the last 2 dims are (y, x) """ # Get the height and width B, H, W, C = attention_map.shape - h_coords = torch.arange(H).view(1, H, 1, 1).to(attention_map.device) - w_coords = torch.arange(W).view(1, 1, W, 1).to(attention_map.device) + # Create a coordinate grid + y_coords = torch.arange(H).reshape(1, H, 1, 1).expand(B, H, W, C).to(attention_map.device) + x_coords = torch.arange(W).reshape(1, 1, W, 1).expand(B, H, W, C).to(attention_map.device) - # Sum of attention scores for each channel - attention_sum = torch.sum(attention_map, dim=(1, 2)) + torch.finfo(attention_map.dtype).eps # shape: (B, C) + # Flatten the height and width dimensions + flattened_matrix = attention_map.reshape(B, -1, C) + y_coords = y_coords.reshape(B, -1, C) + x_coords = x_coords.reshape(B, -1, C) - # Weighted sum of the coordinates - h_weighted_sum = torch.sum(h_coords * attention_map, dim=(1,2)) # (B, C) - w_weighted_sum = torch.sum(w_coords * attention_map, dim=(1,2)) # (B, C) - - # Calculate the centroids - centroid_h = h_weighted_sum / attention_sum - centroid_w = w_weighted_sum / attention_sum - - centroids = torch.stack([centroid_h, centroid_w], dim=-1) # (B, C, 2) + # Calculate weighted sums + total_weight = flattened_matrix.sum(dim=1, keepdim=True) + centroid_y = (y_coords * flattened_matrix).sum(dim=1, keepdim=True) / total_weight + centroid_x = (x_coords * flattened_matrix).sum(dim=1, keepdim=True) / total_weight + # Combine x and y centroids + centroids = torch.cat([centroid_y, centroid_x], dim=-1) return centroids @@ -266,8 +311,6 @@ def translate_image_2d(image, tyx): # translate each dim by the displacements ty, tx = tyx[..., 0], tyx[..., 1] # C, C - ty *= -1 # invert y for some weird reason - tx *= -1 # invert x for some weird reason h_dim = h_dim + ty.view(C, 1, 1) w_dim = w_dim + tx.view(C, 1, 1) @@ -282,7 +325,7 @@ def translate_image_2d(image, tyx): grid = torch.cat([h_dim, w_dim], dim=-1) # (C, H, W, 2) # Apply the grid to the image using grid_sample - translated_image = F.grid_sample(image, grid, mode='nearest', padding_mode='zeros', align_corners=True) # C N H W + translated_image = F.grid_sample(image, grid, mode='bicubic', padding_mode='zeros', align_corners=True) # C N H W return translated_image.transpose(0, 1) # N C H W @@ -444,8 +487,9 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # plotted points as proxies for vertices vert_list = [ - #[3*H//4, W//2], # (lower middle) - [H//4+1, W//4+1], # (upper left middle quadrant) + [3*H//4, W//2], # (lower middle) + #[H//4+1, W//4+1], # (upper left middle quadrant) + #[H//4+1, W//4+1], # (upper left middle quadrant) #[H//2, W] # (right middle) ] target_position = [H//2, W//2] @@ -453,82 +497,90 @@ def color_region(image, yx, ab, color=1.0, mode='set'): verts = torch.tensor([vert_list], dtype=torch.float16, device='cuda') # B C 2 # region to represent the target region - region_yx = [H//8, W//8] - region_ab = [3*H//8, 3*W//8] + region_yx = [H//4, W//4] + region_ab = [3*H//4, 3*W//4] + + #region_yx = [H//4, 3*H//4] + #region_ab = [W//4, 3*W//4] # initialize a map with all ones attention_map = torch.zeros((B, H, W, C)).to(device, dtype) # B H W C # color the target region - color_region(attention_map, region_yx, region_ab, color=0.5, mode='add') + color_region(attention_map, region_yx, region_ab, color=1.0, mode='set') _png(attention_map, 0, 'Initial Attn Map') # calculate centroid of region centroid = calculate_centroid(attention_map) # (B, C, 2) centroid_points = centroid.squeeze(0).squeeze(0) - #centroid_points = centroid.squeeze(0).squeeze(0).cpu().numpy().astype(int) + # plot region centroid plot_point(attention_map, centroid_points, radius=1, color=1) _png(attention_map, 1, 'Attn Map Region + Centroid') # plot verts - attn_map_points = torch.randn_like(attention_map) + attn_map_points = torch.zeros_like(attention_map) + d_region_yx = [1*H//8, 1*W//8] + d_region_ab = [2*H//8, 2*W//8] + color_region(attn_map_points, d_region_yx, d_region_ab, color=0.5, mode='set') + + # set areas outside the region to 0 + #attn_map_points = attn_map_points * attention_map + verts = calculate_centroid(attn_map_points) # (B, C, 2) + for v in verts.squeeze(0): + plot_point(attn_map_points, v, radius=1, color=1) + + _png(attn_map_points, 2, 'Attn Map Proxy Map') # attn_centroid = calculate_centroid(attn_map_points) # (B, C, 2) #attn_map_points = attention_map.detach().clone() - for v in verts.squeeze(0): - plot_point(attn_map_points, v, radius=1) - _png(attn_map_points, 2, 'Points') - - # apply a simple transformation - # translate Y by -1, translate x by 0 - # for i in range(3): - # ofs = round(0.5 * (i+1), 2) - # displacements = torch.tensor([ofs, 0], dtype=torch.float16, device='cuda').repeat(B, C, 1) # 2 - # new_attention_map = apply_displacements(attention_map, displacements) - # _png(new_attention_map, img_idx+1, f'Move Initial [{ofs}, 0]') - # img_idx += 1 - - # for i in range(3): - # ofs = round(0.5 * (i+1), 2) - # displacements = torch.tensor([0, ofs], dtype=torch.float16, device='cuda').repeat(B, C, 1) # 2 - # new_attention_map = apply_displacements(attention_map, displacements) - # _png(new_attention_map, img_idx+1, f'Move Initial [0, {ofs}]') - # img_idx += 1 - - s_margin = 0.2 - s_repl = 0.2 + # for v in verts.squeeze(0): + # plot_point(attn_map_points, v, radius=1) + # _png(attn_map_points, 3, 'Points') + + # strengths + s_margin = 1.0 + s_repl = 0.1 + + # displacement forces + d_down = torch.tensor([[[0.1, 0]]], dtype=torch.float16, device='cuda') # B C 2 # simulate displacement forces on our points - img_idx = 3 - iters = 10 + img_idx = 4 + iters = 100 + steps = 5 for i in range(iters): # copy the map - attn_map_points = attn_map_points.detach().clone() displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin) # B C 2 - # new map with displacements - displaced_map = apply_displacements(attn_map_points, displ_force) + # fix me + displ_force = -displ_force + + attn_map_points = apply_displacements(attn_map_points, displ_force) + + new_vert_pos = verts + displ_force + new_centroid = calculate_centroid(attn_map_points) # (B, C, 2) + + logger.debug(f'Displacement Force: {displ_force}, Centroid: {new_centroid}') + + verts = new_centroid # copy to put on top visualizations - copied_map = displaced_map.detach().clone() + #copied_map = torch.zeros_like(displaced_map).to(device, dtype) + copied_map = attn_map_points.detach().clone() - for v in verts: - plot_point(copied_map, v.squeeze(0), radius=1) - color_region(copied_map, region_yx, region_ab, color=1.0, mode='add') - plot_point(copied_map, centroid_points, radius=1, color=1) + color_region(copied_map, region_yx, region_ab, color=0.1, mode='add') - _png(copied_map, img_idx+i, f'Displacement Forces') + for v in verts: + plot_point(copied_map, v.squeeze(0), radius=1, color=1.0) - new_vert_pos = verts + displ_force - delta_vert_pos = new_vert_pos - verts - verts += delta_vert_pos + #if i % 10 == 0: + _png(copied_map, img_idx+i, f'Displacement Forces Step {i}') - # copy back to the original map - attn_map_points = displaced_map - # verts = torch.tensor([vert_list], dtype=torch.float16, device='cuda') # B C 2 + # # copy back to the original map + # attn_map_points = displaced_map # # conflict detection From a0628ddf590a27a64d7c487bc308ec1249e75754 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 19:43:48 -0700 Subject: [PATCH 23/55] fix directions --- scripts/tcg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index a90c6e6..a5156a1 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -134,7 +134,7 @@ def distances_to_nearest_edges(verts, h, w): distances = torch.stack([y, h - y, x, w - x], dim=-1) # (B, C, 4) - directions = torch.tensor([[-1, 0], [1, 0], [0, -1], [0, 1]]).view(1, 1, 4, 2).repeat(B, C, 1, 1) # (4, 2) -> (B, C, 4, 2) + directions = torch.tensor([[1, 0], [-1, 0], [0, 1], [0, -1]]).view(1, 1, 4, 2).repeat(B, C, 1, 1) # (4, 2) -> (B, C, 4, 2) directions = directions.to(verts.device) return distances, directions @@ -203,7 +203,7 @@ def repulsive_force(strength, pos_vertex, pos_target): Returns: torch.Tensor - The force away from the target. Shape: (B, C, 2) """ - d_pos = pos_target - pos_vertex # (B, C, 2) + d_pos = pos_vertex - pos_target # (B, C, 2) d_pos_norm = d_pos.norm(dim=-1, keepdim=True) + torch.finfo(d_pos.dtype).eps # normalize the direction d_pos = d_pos / d_pos_norm # d_pos /= d_pos_norm @@ -556,7 +556,7 @@ def color_region(image, yx, ab, color=1.0, mode='set'): displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin) # B C 2 # fix me - displ_force = -displ_force + # displ_force = -displ_force attn_map_points = apply_displacements(attn_map_points, displ_force) From c9eb00b115e8c2209cc36ae9bade2e9585521e3a Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 22:47:52 -0700 Subject: [PATCH 24/55] sort of fix calculating warping force --- scripts/tcg.py | 85 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 69 insertions(+), 16 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index a5156a1..f4f8af0 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -140,7 +140,6 @@ def distances_to_nearest_edges(verts, h, w): return distances, directions - def min_distance_to_nearest_edge(verts, h, w): """ Calculate the distances and direction to the nearest edge bounded by (H, W) for each channel's vertices Arguments: @@ -192,6 +191,55 @@ def margin_force(strength, H, W, verts): return forces +def warping_force(attention_map, verts, displacements, h, w): + """ Rescales the attention map based on the displacements. + Arguments: + attention_map: torch.Tensor - The attention map to update. Shape: (B, H, W, C) + verts: torch.Tensor - The centroid vertices of the attention map. Shape: (B, C, 2) + displacements: torch.Tensor - The displacements to apply. Shape: (B, C, 2), where the last 2 dims are the translation by [Y, X] + h: int - The height of the image + w: int - The width of the image + Returns: + torch.Tensor - The updated attention map. Shape: (B, H, W, C) + """ + B, H, W, C = attention_map.shape + + # relative to H and W + old_centroids = verts # (B, C, 2) + new_centroids = old_centroids + displacements # (B, C, 2) + + #delta_centroids = new_centroids - old_centroids # (B, C, 2) + + # calculate scaling factors, which are the min of 1 or the calculated scaling factor + # is it w or h? + s_y = (h - 1)/new_centroids[..., 0] # (B, C) + s_x = (w - 1)/new_centroids[..., 1] # (B, C) + torch.clamp_max(s_y, 1.0, out=s_y) + torch.clamp_max(s_x, 1.0, out=s_x) + + # displacements + delta_h = displacements[..., 0] - new_centroids[..., 0]# (B, C) + delta_w = displacements[..., 1] - new_centroids[..., 1] # (B, C) + + # construct affine transformation matrices (sx, 0, delta_x - o_new_x), (0, sy, delta_y - o_new_y) + theta = torch.tensor([[1, 0, 0],[0, 1, 0]], dtype=torch.float32, device=attention_map.device) + theta = theta.unsqueeze(0).repeat(C, 1, 1) + theta[:, 0, 0] = s_x + theta[:, 1, 1] = s_y + theta[:, 0, 2] = delta_w / w + theta[:, 1, 2] = delta_h / h + + # apply the affine transformation + grid = F.affine_grid(theta, [B, C, H, W], align_corners=False) # (C, H, W, 2) + attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) + attention_map = attention_map.to(torch.float32) + out_attn_map = F.grid_sample(attention_map, grid, mode='bicubic', padding_mode='zeros', align_corners=False) + attention_map = attention_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) + + out_attn_map = out_attn_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) + return out_attn_map, new_centroids + + def repulsive_force(strength, pos_vertex, pos_target): """ Repulsive force repels the vertices in the direction away from the target @@ -380,28 +428,29 @@ def translate_image(image, tx, ty): return translated_image.squeeze(1) -def apply_displacements(attention_map, displacements): +def apply_displacements(attention_map, verts, displacements): """ Update the attention map based on the displacements. The attention map is updated by displacing the attention values based on the displacements. - Areas that are displaced out of the attention map are discarded. - Areas that are displaced into the attention map are initialized with zeros. Arguments: attention_map: torch.Tensor - The attention map to update. Shape: (B, H, W, C) + verts: torch.Tensor - The centroid vertices of the attention map. Shape: (B, C, 2) displacements: torch.Tensor - The displacements to apply. Shape: (B, C, 2), where the last 2 dims are the translation by [Y, X] Returns: torch.Tensor - The updated attention map. Shape: (B, H, W, C) """ B, H, W, C = attention_map.shape - attention_map = attention_map.permute(0, 3, 2, 1) # (B, H, W, C) -> (B, C, H, W) - #attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) out_attn_map = attention_map.detach().clone() - + out_verts = verts.detach().clone() # apply displacements for batch_idx in range(B): - out_attn_map[batch_idx] = translate_image_2d(attention_map[batch_idx].unsqueeze(0), displacements[batch_idx]).squeeze(0) + out_attn_map[batch_idx], out_verts[batch_idx] = warping_force(attention_map[batch_idx].unsqueeze(0), verts[batch_idx].unsqueeze(0), displacements[batch_idx], H, W) + out_attn_map[batch_idx] = out_attn_map[batch_idx].squeeze(0) +# out_attn_map[batch_idx] = translate_image_2d(attention_map[batch_idx].unsqueeze(0), displacements[batch_idx]).squeeze(0) # - out_attn_map = out_attn_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) - return out_attn_map + #out_attn_map = out_attn_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) + return out_attn_map, out_verts def get_attention_scores(to_q_map, to_k_map, dtype): @@ -432,6 +481,8 @@ def get_attention_scores(to_q_map, to_k_map, dtype): return attn_probs +####################### +### Debug stuff def plot_point(image, point, radius=1, color=1.0): """ Plot a point on an image tensor Arguments: @@ -542,7 +593,7 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # strengths s_margin = 1.0 - s_repl = 0.1 + s_repl = 1.0 # displacement forces d_down = torch.tensor([[[0.1, 0]]], dtype=torch.float16, device='cuda') # B C 2 @@ -555,17 +606,19 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # copy the map displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin) # B C 2 - # fix me - # displ_force = -displ_force + # check for nan + if torch.isnan(displ_force).any(): + logger.warning(f'Nan in displ_force at iter {i}') - attn_map_points = apply_displacements(attn_map_points, displ_force) + # fix me + attn_map_points, out_verts = apply_displacements(attn_map_points, verts, displ_force) - new_vert_pos = verts + displ_force - new_centroid = calculate_centroid(attn_map_points) # (B, C, 2) + # new_vert_pos = verts + displ_force + # new_centroid = calculate_centroid(attn_map_points) # (B, C, 2) - logger.debug(f'Displacement Force: {displ_force}, Centroid: {new_centroid}') + logger.debug(f'Displacement Force: {displ_force}, Centroid: {out_verts}') - verts = new_centroid + verts = out_verts # copy to put on top visualizations #copied_map = torch.zeros_like(displaced_map).to(device, dtype) From 30d7c7a12355aeaddd11af4c8f5674be7070d019 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 22:59:28 -0700 Subject: [PATCH 25/55] debugging warping force --- scripts/tcg.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index f4f8af0..9821719 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -204,6 +204,12 @@ def warping_force(attention_map, verts, displacements, h, w): """ B, H, W, C = attention_map.shape + # rescale verts to -1 to 1 + px_to_norm = torch.tensor([H, W], dtype=verts.dtype, device=verts.device) + + # verts = verts / px_to_norm * 2 - 1 # (B, C, 2) + # displacements = displacements / px_to_norm * 2 - 1 + # relative to H and W old_centroids = verts # (B, C, 2) new_centroids = old_centroids + displacements # (B, C, 2) @@ -213,7 +219,9 @@ def warping_force(attention_map, verts, displacements, h, w): # calculate scaling factors, which are the min of 1 or the calculated scaling factor # is it w or h? s_y = (h - 1)/new_centroids[..., 0] # (B, C) - s_x = (w - 1)/new_centroids[..., 1] # (B, C) + s_x = (w - 1)/new_centroids[..., 1] # (B, C) + #s_y = 1/new_centroids[..., 0] # (B, C) + #s_x = 1/new_centroids[..., 1] # (B, C) torch.clamp_max(s_y, 1.0, out=s_y) torch.clamp_max(s_x, 1.0, out=s_x) @@ -226,17 +234,19 @@ def warping_force(attention_map, verts, displacements, h, w): theta = theta.unsqueeze(0).repeat(C, 1, 1) theta[:, 0, 0] = s_x theta[:, 1, 1] = s_y - theta[:, 0, 2] = delta_w / w - theta[:, 1, 2] = delta_h / h + theta[:, 0, 2] = delta_h / h + theta[:, 1, 2] = delta_w / w # apply the affine transformation - grid = F.affine_grid(theta, [B, C, H, W], align_corners=False) # (C, H, W, 2) + grid = F.affine_grid(theta, [B, C, H, W], align_corners=True) # (C, H, W, 2) attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) attention_map = attention_map.to(torch.float32) - out_attn_map = F.grid_sample(attention_map, grid, mode='bicubic', padding_mode='zeros', align_corners=False) + out_attn_map = F.grid_sample(attention_map, grid, mode='bilinear', padding_mode='zeros', align_corners=True) attention_map = attention_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) out_attn_map = out_attn_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) + + # rescale centroids to pixel space return out_attn_map, new_centroids @@ -573,7 +583,7 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # plot verts attn_map_points = torch.zeros_like(attention_map) d_region_yx = [1*H//8, 1*W//8] - d_region_ab = [2*H//8, 2*W//8] + d_region_ab = [6*H//8, 3*W//8] color_region(attn_map_points, d_region_yx, d_region_ab, color=0.5, mode='set') # set areas outside the region to 0 From 01f201ca3f7882ec3aed5a1f4c238afd2f646e20 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Sun, 26 May 2024 23:05:44 -0700 Subject: [PATCH 26/55] wip bbox fn --- scripts/tcg.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/scripts/tcg.py b/scripts/tcg.py index 9821719..37c10c3 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -284,6 +284,39 @@ def multi_target_force(attention_map, omega, xi, pos_vertex, pos_target): pass +def calculate_region(attention_map): + """ Given an attention map of shape [B, H, W, C], calculate a bounding box over each C + Arguments: + attention_map: torch.Tensor - The attention map to calculate the bounding box. Shape: (B, H, W, C) + Returns: + torch.Tensor - The bounding box of the region. Shape: (B, C, 4), where the last 4 dims are (y, x, a, b) + y, x: The top left corner of the bounding box + a, b: The height and width of the bounding box + """ + B, H, W, C = attention_map.shape + # Calculate the sum of attention map along the height and width dimensions + sum_map = attention_map.sum(dim=(1, 2)) # (B, C) + # Find the indices of the maximum attention value for each channel + max_indices = sum_map.argmax(dim=1, keepdim=True) # (B, C) + # Initialize the bounding box tensor + bounding_box = torch.zeros((B, C, 4), dtype=torch.int32, device=attention_map.device) + # Iterate over each channel + for batch_idx in range(B): + for channel_idx in range(C): + # Calculate the row and column indices of the maximum attention value + row_index = max_indices[batch_idx, channel_idx] // W + col_index = max_indices[batch_idx, channel_idx] % W + # Calculate the top left corner coordinates of the bounding box + y = max(0, row_index - 1) + x = max(0, col_index - 1) + # Calculate the height and width of the bounding box + a = min(H - y, row_index + 2) - y + b = min(W - x, col_index + 2) - x + # Store the bounding box coordinates in the tensor + bounding_box[batch_idx, channel_idx] = torch.tensor([y, x, a, b]) + return bounding_box + + def calculate_centroid(attention_map): """ Calculate the centroid of the attention map Arguments: @@ -614,6 +647,10 @@ def color_region(image, yx, ab, color=1.0, mode='set'): steps = 5 for i in range(iters): # copy the map + bbox_map = calculate_region(attn_map_points) # (B, C, 4) + + + displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin) # B C 2 # check for nan From 24b4d5c9a3f00967f5cf16f347fc7d6d268544ac Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 27 May 2024 15:19:54 -0700 Subject: [PATCH 27/55] maybe fix warping force in case displacement is 0,0 --- scripts/tcg.py | 64 ++++++++++++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 37c10c3..5e0af54 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -93,7 +93,11 @@ def get_xyz_axis_options(self) -> dict: return {} -def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margin_strength): + +debug_coord = lambda x: (round(x[0,0,0].item(),3), round(x[0,0,1].item(), 3)) + + +def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margin_strength, clamp=0): """ Given a set of vertices, calculate the displacement force given by the sum of margin force and repulsive force. Arguments: attention_map: torch.Tensor - The attention map to calculate the force. Shape: (B, H, W, C) @@ -105,14 +109,16 @@ def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margi torch.Tensor - The displacement force for each vertex. Shape: (B, C, 2) """ B, H, W, C = attention_map.shape - clamp_min = -1 - clamp_max = 1 - clamp = lambda x: torch.clamp(x, min=clamp_min, max=clamp_max) + f_clamp = lambda x: x + if clamp > 0: + clamp_min = -clamp + clamp_max = clamp + f_clamp = lambda x: torch.clamp(x, min=clamp_min, max=clamp_max) - f_rep = clamp(repulsive_force(f_rep_strength, verts, target_pos)) - f_margin = clamp(margin_force(f_margin_strength, H, W, verts)) + f_rep = f_clamp(repulsive_force(f_rep_strength, verts, target_pos)) + f_margin = f_clamp(margin_force(f_margin_strength, H, W, verts)) - logger.debug(f"Repulsive force: {f_rep}, Margin force: {f_margin}") + logger.debug(f"Repulsive force: {debug_coord(f_rep)}, Margin force: {debug_coord(f_margin)}") return f_rep + f_margin @@ -204,38 +210,37 @@ def warping_force(attention_map, verts, displacements, h, w): """ B, H, W, C = attention_map.shape - # rescale verts to -1 to 1 - px_to_norm = torch.tensor([H, W], dtype=verts.dtype, device=verts.device) + # - # verts = verts / px_to_norm * 2 - 1 # (B, C, 2) - # displacements = displacements / px_to_norm * 2 - 1 - - # relative to H and W old_centroids = verts # (B, C, 2) new_centroids = old_centroids + displacements # (B, C, 2) - #delta_centroids = new_centroids - old_centroids # (B, C, 2) + # check if new_centroids are out of bounds + min_bounds = torch.tensor([0, 0], dtype=torch.float32, device=attention_map.device) + max_bounds = torch.tensor([h-1, w-1], dtype=torch.float32, device=attention_map.device) + oob_new_centroids = torch.clamp(new_centroids, min_bounds, max_bounds) + + # diferenct between old and new centroids + correction = oob_new_centroids - new_centroids + new_centroids = new_centroids + correction - # calculate scaling factors, which are the min of 1 or the calculated scaling factor - # is it w or h? s_y = (h - 1)/new_centroids[..., 0] # (B, C) s_x = (w - 1)/new_centroids[..., 1] # (B, C) - #s_y = 1/new_centroids[..., 0] # (B, C) - #s_x = 1/new_centroids[..., 1] # (B, C) torch.clamp_max(s_y, 1.0, out=s_y) torch.clamp_max(s_x, 1.0, out=s_x) # displacements - delta_h = displacements[..., 0] - new_centroids[..., 0]# (B, C) - delta_w = displacements[..., 1] - new_centroids[..., 1] # (B, C) + o_new = old_centroids + displacements - new_centroids + delta_h = old_centroids + displacements[..., 0] - new_centroids[..., 0]# (B, C) + delta_w = old_centroids + displacements[..., 1] - new_centroids[..., 1] # (B, C) # construct affine transformation matrices (sx, 0, delta_x - o_new_x), (0, sy, delta_y - o_new_y) theta = torch.tensor([[1, 0, 0],[0, 1, 0]], dtype=torch.float32, device=attention_map.device) theta = theta.unsqueeze(0).repeat(C, 1, 1) theta[:, 0, 0] = s_x theta[:, 1, 1] = s_y - theta[:, 0, 2] = delta_h / h - theta[:, 1, 2] = delta_w / w + theta[:, 0, 2] = o_new[..., 1] / w # X + theta[:, 1, 2] = o_new[..., 0] / h # Y # apply the affine transformation grid = F.affine_grid(theta, [B, C, H, W], align_corners=True) # (C, H, W, 2) @@ -262,7 +267,8 @@ def repulsive_force(strength, pos_vertex, pos_target): torch.Tensor - The force away from the target. Shape: (B, C, 2) """ d_pos = pos_vertex - pos_target # (B, C, 2) - d_pos_norm = d_pos.norm(dim=-1, keepdim=True) + torch.finfo(d_pos.dtype).eps # normalize the direction + d_pos_norm = d_pos.norm() + torch.finfo(d_pos.dtype).eps # normalize the direction + #d_pos_norm = d_pos.norm(dim=-1, keepdim=True) + torch.finfo(d_pos.dtype).eps # normalize the direction d_pos = d_pos / d_pos_norm # d_pos /= d_pos_norm force = -(strength ** 2) @@ -479,7 +485,7 @@ def apply_displacements(attention_map, verts, displacements): Arguments: attention_map: torch.Tensor - The attention map to update. Shape: (B, H, W, C) verts: torch.Tensor - The centroid vertices of the attention map. Shape: (B, C, 2) - displacements: torch.Tensor - The displacements to apply. Shape: (B, C, 2), where the last 2 dims are the translation by [Y, X] + displacements: torch.Tensor - The displacements to apply in pixel space. Shape: (B, C, 2), where the last 2 dims are the translation by [Y, X] Returns: torch.Tensor - The updated attention map. Shape: (B, H, W, C) """ @@ -636,9 +642,10 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # strengths s_margin = 1.0 - s_repl = 1.0 + s_repl = 0.0 # displacement forces + d_zero = torch.tensor([[[0.0, 0]]], dtype=torch.float16, device='cuda') # B C 2 d_down = torch.tensor([[[0.1, 0]]], dtype=torch.float16, device='cuda') # B C 2 # simulate displacement forces on our points @@ -649,8 +656,6 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # copy the map bbox_map = calculate_region(attn_map_points) # (B, C, 4) - - displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin) # B C 2 # check for nan @@ -658,12 +663,15 @@ def color_region(image, yx, ab, color=1.0, mode='set'): logger.warning(f'Nan in displ_force at iter {i}') # fix me + displ_force = d_zero + attn_map_points, out_verts = apply_displacements(attn_map_points, verts, displ_force) # new_vert_pos = verts + displ_force # new_centroid = calculate_centroid(attn_map_points) # (B, C, 2) - logger.debug(f'Displacement Force: {displ_force}, Centroid: {out_verts}') + + logger.debug(f'Displacement Force: {debug_coord(displ_force)}, Centroid: {debug_coord(out_verts)}') verts = out_verts From 5ac21e62df693ab2a8f437c4c28150e96ae8c172 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 27 May 2024 15:29:21 -0700 Subject: [PATCH 28/55] fix sampling --- scripts/tcg.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 5e0af54..6e8945a 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -230,15 +230,17 @@ def warping_force(attention_map, verts, displacements, h, w): torch.clamp_max(s_x, 1.0, out=s_x) # displacements - o_new = old_centroids + displacements - new_centroids - delta_h = old_centroids + displacements[..., 0] - new_centroids[..., 0]# (B, C) - delta_w = old_centroids + displacements[..., 1] - new_centroids[..., 1] # (B, C) + o_new = displacements - correction + #delta_h = old_centroids + displacements[..., 0] - new_centroids[..., 0]# (B, C) + #delta_w = old_centroids + displacements[..., 1] - new_centroids[..., 1] # (B, C) # construct affine transformation matrices (sx, 0, delta_x - o_new_x), (0, sy, delta_y - o_new_y) theta = torch.tensor([[1, 0, 0],[0, 1, 0]], dtype=torch.float32, device=attention_map.device) theta = theta.unsqueeze(0).repeat(C, 1, 1) theta[:, 0, 0] = s_x theta[:, 1, 1] = s_y + #theta[:, 0, 2] = o_new[..., 1] / w # X + #theta[:, 1, 2] = o_new[..., 0] / h # Y theta[:, 0, 2] = o_new[..., 1] / w # X theta[:, 1, 2] = o_new[..., 0] / h # Y @@ -246,7 +248,7 @@ def warping_force(attention_map, verts, displacements, h, w): grid = F.affine_grid(theta, [B, C, H, W], align_corners=True) # (C, H, W, 2) attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) attention_map = attention_map.to(torch.float32) - out_attn_map = F.grid_sample(attention_map, grid, mode='bilinear', padding_mode='zeros', align_corners=True) + out_attn_map = F.grid_sample(attention_map, grid, mode='bicubic', padding_mode='zeros', align_corners=True) attention_map = attention_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) out_attn_map = out_attn_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) @@ -642,7 +644,7 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # strengths s_margin = 1.0 - s_repl = 0.0 + s_repl = 1.0 # displacement forces d_zero = torch.tensor([[[0.0, 0]]], dtype=torch.float16, device='cuda') # B C 2 @@ -656,14 +658,14 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # copy the map bbox_map = calculate_region(attn_map_points) # (B, C, 4) - displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin) # B C 2 + displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin, clamp = 2) # B C 2 # check for nan if torch.isnan(displ_force).any(): logger.warning(f'Nan in displ_force at iter {i}') # fix me - displ_force = d_zero + # displ_force = d_zero attn_map_points, out_verts = apply_displacements(attn_map_points, verts, displ_force) @@ -687,9 +689,6 @@ def color_region(image, yx, ab, color=1.0, mode='set'): #if i % 10 == 0: _png(copied_map, img_idx+i, f'Displacement Forces Step {i}') - # # copy back to the original map - # attn_map_points = displaced_map - # # conflict detection # attention_map = torch.ones(B, H, W, C).to('cuda') # B H W C From 903f8bddac1df4506cec1e82e92b55e9859dd345 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 27 May 2024 16:03:13 -0700 Subject: [PATCH 29/55] experimenting with hyperparameters --- scripts/tcg.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 6e8945a..23bbaad 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -228,6 +228,8 @@ def warping_force(attention_map, verts, displacements, h, w): s_x = (w - 1)/new_centroids[..., 1] # (B, C) torch.clamp_max(s_y, 1.0, out=s_y) torch.clamp_max(s_x, 1.0, out=s_x) + if s_x < 0.99 or s_y < 0.99: + logger.debug(f"Scaling factor: {s_x}, {s_y}") # displacements o_new = displacements - correction @@ -245,10 +247,10 @@ def warping_force(attention_map, verts, displacements, h, w): theta[:, 1, 2] = o_new[..., 0] / h # Y # apply the affine transformation - grid = F.affine_grid(theta, [B, C, H, W], align_corners=True) # (C, H, W, 2) + grid = F.affine_grid(theta, [B, C, H, W], align_corners=False) # (C, H, W, 2) attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) attention_map = attention_map.to(torch.float32) - out_attn_map = F.grid_sample(attention_map, grid, mode='bicubic', padding_mode='zeros', align_corners=True) + out_attn_map = F.grid_sample(attention_map, grid, mode='bicubic', padding_mode='zeros', align_corners=False) attention_map = attention_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) out_attn_map = out_attn_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) @@ -635,16 +637,9 @@ def color_region(image, yx, ab, color=1.0, mode='set'): _png(attn_map_points, 2, 'Attn Map Proxy Map') - # attn_centroid = calculate_centroid(attn_map_points) # (B, C, 2) - - #attn_map_points = attention_map.detach().clone() - # for v in verts.squeeze(0): - # plot_point(attn_map_points, v, radius=1) - # _png(attn_map_points, 3, 'Points') - # strengths - s_margin = 1.0 - s_repl = 1.0 + s_margin = 500.0 + s_repl = 10.0 # displacement forces d_zero = torch.tensor([[[0.0, 0]]], dtype=torch.float16, device='cuda') # B C 2 @@ -658,7 +653,7 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # copy the map bbox_map = calculate_region(attn_map_points) # (B, C, 4) - displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin, clamp = 2) # B C 2 + displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin, clamp = 10) # B C 2 # check for nan if torch.isnan(displ_force).any(): @@ -675,7 +670,8 @@ def color_region(image, yx, ab, color=1.0, mode='set'): logger.debug(f'Displacement Force: {debug_coord(displ_force)}, Centroid: {debug_coord(out_verts)}') - verts = out_verts + verts = calculate_centroid(attn_map_points) # (B, C, 2) + # verts = out_verts # copy to put on top visualizations #copied_map = torch.zeros_like(displaced_map).to(device, dtype) From fe68cadf361c0e875456ab845094de796d915f6d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 27 May 2024 16:29:48 -0700 Subject: [PATCH 30/55] Fix calculate_centroid shape when channel > 1 --- scripts/tcg.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 23bbaad..446dce0 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -348,12 +348,12 @@ def calculate_centroid(attention_map): x_coords = x_coords.reshape(B, -1, C) # Calculate weighted sums - total_weight = flattened_matrix.sum(dim=1, keepdim=True) - centroid_y = (y_coords * flattened_matrix).sum(dim=1, keepdim=True) / total_weight - centroid_x = (x_coords * flattened_matrix).sum(dim=1, keepdim=True) / total_weight + total_weight = flattened_matrix.sum(dim=1) + centroid_y = (y_coords * flattened_matrix).sum(dim=1) / total_weight + centroid_x = (x_coords * flattened_matrix).sum(dim=1) / total_weight # Combine x and y centroids - centroids = torch.cat([centroid_y, centroid_x], dim=-1) + centroids = torch.stack([centroid_y, centroid_x], dim=-1) return centroids From 43f046d0171d4351d2754abbd0e2fc76c68f2a74 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 27 May 2024 16:33:52 -0700 Subject: [PATCH 31/55] Fix debug statement dim error --- scripts/tcg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 446dce0..719b016 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -228,7 +228,7 @@ def warping_force(attention_map, verts, displacements, h, w): s_x = (w - 1)/new_centroids[..., 1] # (B, C) torch.clamp_max(s_y, 1.0, out=s_y) torch.clamp_max(s_x, 1.0, out=s_x) - if s_x < 0.99 or s_y < 0.99: + if torch.any(s_x < 0.99) or torch.any(s_y < 0.99): logger.debug(f"Scaling factor: {s_x}, {s_y}") # displacements From 4fe4911e957740369c3ed811036689226fe58ae6 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 27 May 2024 16:40:19 -0700 Subject: [PATCH 32/55] fix warping force batch dims --- scripts/tcg.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 719b016..5a403c9 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -198,19 +198,17 @@ def margin_force(strength, H, W, verts): return forces def warping_force(attention_map, verts, displacements, h, w): - """ Rescales the attention map based on the displacements. + """ Rescales the attention map based on the displacements. Expects a batch size of 1 to operate on all channels at once. Arguments: - attention_map: torch.Tensor - The attention map to update. Shape: (B, H, W, C) - verts: torch.Tensor - The centroid vertices of the attention map. Shape: (B, C, 2) - displacements: torch.Tensor - The displacements to apply. Shape: (B, C, 2), where the last 2 dims are the translation by [Y, X] + attention_map: torch.Tensor - The attention map to update. Shape: (1, H, W, C) + verts: torch.Tensor - The centroid vertices of the attention map. Shape: (1, C, 2) + displacements: torch.Tensor - The displacements to apply. Shape: (1, C, 2), where the last 2 dims are the translation by [Y, X] h: int - The height of the image w: int - The width of the image Returns: torch.Tensor - The updated attention map. Shape: (B, H, W, C) """ - B, H, W, C = attention_map.shape - - # + _, H, W, C = attention_map.shape old_centroids = verts # (B, C, 2) new_centroids = old_centroids + displacements # (B, C, 2) @@ -233,27 +231,24 @@ def warping_force(attention_map, verts, displacements, h, w): # displacements o_new = displacements - correction - #delta_h = old_centroids + displacements[..., 0] - new_centroids[..., 0]# (B, C) - #delta_w = old_centroids + displacements[..., 1] - new_centroids[..., 1] # (B, C) # construct affine transformation matrices (sx, 0, delta_x - o_new_x), (0, sy, delta_y - o_new_y) theta = torch.tensor([[1, 0, 0],[0, 1, 0]], dtype=torch.float32, device=attention_map.device) theta = theta.unsqueeze(0).repeat(C, 1, 1) theta[:, 0, 0] = s_x theta[:, 1, 1] = s_y - #theta[:, 0, 2] = o_new[..., 1] / w # X - #theta[:, 1, 2] = o_new[..., 0] / h # Y theta[:, 0, 2] = o_new[..., 1] / w # X theta[:, 1, 2] = o_new[..., 0] / h # Y # apply the affine transformation - grid = F.affine_grid(theta, [B, C, H, W], align_corners=False) # (C, H, W, 2) - attention_map = attention_map.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W) + grid = F.affine_grid(theta, [C, 1, H, W], align_corners=False) # (C, H, W, 2) + + attention_map = attention_map.permute(3, 0, 1, 2) # (B, H, W, C) -> (C, B, H, W) attention_map = attention_map.to(torch.float32) out_attn_map = F.grid_sample(attention_map, grid, mode='bicubic', padding_mode='zeros', align_corners=False) - attention_map = attention_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) + attention_map = attention_map.permute(1, 2, 3, 0) # (C, B, H, W) -> (B, H, W, C) - out_attn_map = out_attn_map.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C) + out_attn_map = out_attn_map.permute(1, 2, 3, 0) # (C, B, H, W) -> (B, H, W, C) # rescale centroids to pixel space return out_attn_map, new_centroids From 2ffee0840897166a5973ebbed913d4b41fde8b89 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 27 May 2024 17:20:50 -0700 Subject: [PATCH 33/55] support translating multiple channels --- scripts/tcg.py | 137 ++++++++++++++++++++++++------------------------- 1 file changed, 66 insertions(+), 71 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 5a403c9..678fb3d 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -2,6 +2,7 @@ import logging import gradio as gr import torch +import random import torch.nn.functional as F if __name__ == '__main__' and os.environ.get('INCANT_DEBUG', None): @@ -245,7 +246,7 @@ def warping_force(attention_map, verts, displacements, h, w): attention_map = attention_map.permute(3, 0, 1, 2) # (B, H, W, C) -> (C, B, H, W) attention_map = attention_map.to(torch.float32) - out_attn_map = F.grid_sample(attention_map, grid, mode='bicubic', padding_mode='zeros', align_corners=False) + out_attn_map = F.grid_sample(attention_map, grid, mode='bilinear', padding_mode='zeros', align_corners=False) attention_map = attention_map.permute(1, 2, 3, 0) # (C, B, H, W) -> (B, H, W, C) out_attn_map = out_attn_map.permute(1, 2, 3, 0) # (C, B, H, W) -> (B, H, W, C) @@ -329,26 +330,27 @@ def calculate_centroid(attention_map): Returns: torch.Tensor - The centroid of the attention map. Shape: (B, C, 2), where the last 2 dims are (y, x) """ - - # Get the height and width - B, H, W, C = attention_map.shape + # necessary to avoid inf + attention_map = attention_map.to(torch.float32) + + # Create tensors of the y and x coordinates + y_coords = torch.arange(H, dtype=attention_map.dtype, device=attention_map.device).view(1, H, 1, 1) + x_coords = torch.arange(W, dtype=attention_map.dtype, device=attention_map.device).view(1, 1, W, 1) + + # Calculate the weighted sums of the coordinates + weighted_sum_y = torch.sum(y_coords * attention_map, dim=[1, 2]) + weighted_sum_x = torch.sum(x_coords * attention_map, dim=[1, 2]) + + # Calculate the total weights + total_weights = torch.sum(attention_map, dim=[1, 2]) + + # Calculate the centroids + centroid_y = weighted_sum_y / total_weights + centroid_x = weighted_sum_x / total_weights - # Create a coordinate grid - y_coords = torch.arange(H).reshape(1, H, 1, 1).expand(B, H, W, C).to(attention_map.device) - x_coords = torch.arange(W).reshape(1, 1, W, 1).expand(B, H, W, C).to(attention_map.device) - - # Flatten the height and width dimensions - flattened_matrix = attention_map.reshape(B, -1, C) - y_coords = y_coords.reshape(B, -1, C) - x_coords = x_coords.reshape(B, -1, C) - - # Calculate weighted sums - total_weight = flattened_matrix.sum(dim=1) - centroid_y = (y_coords * flattened_matrix).sum(dim=1) / total_weight - centroid_x = (x_coords * flattened_matrix).sum(dim=1) / total_weight - # Combine x and y centroids - centroids = torch.stack([centroid_y, centroid_x], dim=-1) + centroids = torch.stack([centroid_y, centroid_x], dim=-1) + return centroids @@ -363,10 +365,11 @@ def detect_conflict(attention_map, region, theta): torch.Tensor: Conflict detection result of shape (B, K), with values 0 or 1 indicating conflict between tokens and the region. """ # Ensure region is the same shape as the spatial dimensions of attention_map - assert region.shape[1:] == attention_map.shape[1:3], "Region mask must match spatial dimensions of attention map" + assert region.shape[1:] == attention_map.shape[1:], "Region mask must match spatial dimensions of attention map" # Calculate the mean attention within the region - region = region.unsqueeze(-1) # Add channel dimension: (B, H, W) -> (B, H, W, 1) + #region = region.unsqueeze(-1) # Add channel dimension: (B, H, W) -> (B, H, W, 1) attention_in_region = attention_map * region # Element-wise multiplication + #mean_attention_in_region = attention_in_region[attention_in_region > 0] mean_attention_in_region = torch.sum(attention_in_region, dim=(1, 2)) / torch.sum(region, dim=(1, 2)) # Mean over (H, W) # Compare with threshold theta conflict = (mean_attention_in_region > theta).float() # Convert boolean to float (0 or 1) @@ -580,7 +583,7 @@ def color_region(image, yx, ab, color=1.0, mode='set'): title=f'{title}', ) - B, H, W, C = 1, 64, 64, 1 + B, H, W, C = 1, 64, 64, 3 dtype = torch.float16 device = 'cuda' @@ -610,12 +613,17 @@ def color_region(image, yx, ab, color=1.0, mode='set'): _png(attention_map, 0, 'Initial Attn Map') + attention_map_region = attention_map.detach().clone() + # calculate centroid of region centroid = calculate_centroid(attention_map) # (B, C, 2) - centroid_points = centroid.squeeze(0).squeeze(0) + #centroid_points = centroid.squeeze(0).squeeze(0) + centroid_points = centroid # plot region centroid - plot_point(attention_map, centroid_points, radius=1, color=1) + for batch_idx in range(B): + for channel_idx in range(C): + plot_point(attention_map[batch_idx, ..., channel_idx].unsqueeze(0).unsqueeze(-1), list(centroid_points[batch_idx, channel_idx]), radius=1, color=1) _png(attention_map, 1, 'Attn Map Region + Centroid') # plot verts @@ -624,6 +632,19 @@ def color_region(image, yx, ab, color=1.0, mode='set'): d_region_ab = [6*H//8, 3*W//8] color_region(attn_map_points, d_region_yx, d_region_ab, color=0.5, mode='set') + c0_region_yx = [12,24] + c0_region_ab = [36, 48] + + c1_region_yx = [36,36] + c1_region_ab = [48, 48] + + c2_region_yx = [48, 48] + c2_region_ab = [63, 63] + if attn_map_points.shape[-1] > 1: + color_region(attn_map_points[..., 0].unsqueeze(-1), c0_region_yx, c0_region_ab, color=1, mode='set') + color_region(attn_map_points[..., 1].unsqueeze(-1), c1_region_yx, c1_region_ab, color=1, mode='set') + color_region(attn_map_points[..., 2].unsqueeze(-1), c2_region_yx, c2_region_ab, color=1, mode='set') + # set areas outside the region to 0 #attn_map_points = attn_map_points * attention_map verts = calculate_centroid(attn_map_points) # (B, C, 2) @@ -633,8 +654,8 @@ def color_region(image, yx, ab, color=1.0, mode='set'): _png(attn_map_points, 2, 'Attn Map Proxy Map') # strengths - s_margin = 500.0 - s_repl = 10.0 + s_margin = 1.0 + s_repl = 1.0 # displacement forces d_zero = torch.tensor([[[0.0, 0]]], dtype=torch.float16, device='cuda') # B C 2 @@ -645,60 +666,34 @@ def color_region(image, yx, ab, color=1.0, mode='set'): iters = 100 steps = 5 for i in range(iters): - # copy the map - bbox_map = calculate_region(attn_map_points) # (B, C, 4) + # Check for conflicts between target region and attention map + theta = 0.001 + conflict_detection = detect_conflict(attn_map_points, attention_map_region, theta) # (B, C) + logger.debug(f'Conflict Detection: {conflict_detection}') + if not conflict_detection.any(): + logger.info(f'No conflict detected at iter {i}') + break + verts = calculate_centroid(attn_map_points) # (B, C, 2) + # Displacement forces displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin, clamp = 10) # B C 2 - - # check for nan if torch.isnan(displ_force).any(): logger.warning(f'Nan in displ_force at iter {i}') - # fix me - # displ_force = d_zero - + # apply displacements and calculate new centroids attn_map_points, out_verts = apply_displacements(attn_map_points, verts, displ_force) - - # new_vert_pos = verts + displ_force - # new_centroid = calculate_centroid(attn_map_points) # (B, C, 2) - - - logger.debug(f'Displacement Force: {debug_coord(displ_force)}, Centroid: {debug_coord(out_verts)}') - verts = calculate_centroid(attn_map_points) # (B, C, 2) - # verts = out_verts - # copy to put on top visualizations - #copied_map = torch.zeros_like(displaced_map).to(device, dtype) + # debug output copied_map = attn_map_points.detach().clone() - color_region(copied_map, region_yx, region_ab, color=0.1, mode='add') - for v in verts: - plot_point(copied_map, v.squeeze(0), radius=1, color=1.0) - - #if i % 10 == 0: - _png(copied_map, img_idx+i, f'Displacement Forces Step {i}') - - - # # conflict detection - # attention_map = torch.ones(B, H, W, C).to('cuda') # B H W C - # region = torch.zeros((B, H, W), dtype=torch.float16, device='cuda') # B H W C - # # set the left half of region to 1 - # region[:, :, :W//2] = 1 - # theta = 0.5 # Example threshold - # conflict_detection = detect_conflict(attention_map, region, theta) - # print(conflict_detection) + # for v in verts: + # plot_point(copied_map, v.squeeze(0), radius=1, color=1.0) - # # Create a simple attention map with known values - # attention_map = torch.zeros((B, H, W, C), device='cuda') # Shape (batch_size, height, width, channels) - # attention_map[0, H//2, W//2, 0] = 1.0 # Put all attention on the center - - # # Calculate centroids - # centroids = calculate_centroid(attention_map) # (B, C, 2) - - # # Expected centroid is the center of the attention map (2, 2) - # expected_centroid = torch.tensor([[[H/2, W/2]]], device='cuda') - - # # Check if the calculated centroid matches the expected centroid - # assert torch.allclose(centroids, expected_centroid), f"Expected {expected_centroid}, but got {centroids}" \ No newline at end of file + logger.debug(f'Displacement Force: {debug_coord(displ_force)}, Centroid: {debug_coord(out_verts)}') + for c in range(C): + ofs = c * 100 + #plot_point(copied_map[:c], img_idx+ofs+i+c, radius=1, color=1.0) + _png(copied_map[..., c].unsqueeze(-1), img_idx+ofs+i, f'Displacement Forces Channel {c} Step {i}') + #_png(copied_map, img_idx+i, f'Displacement Forces Step {i}') \ No newline at end of file From 93e2ddc6e5da89f0df14377c8b2f6c5bf342598d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 27 May 2024 20:31:20 -0700 Subject: [PATCH 34/55] loosen assert reqs --- scripts/tcg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 678fb3d..f13933e 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -365,7 +365,7 @@ def detect_conflict(attention_map, region, theta): torch.Tensor: Conflict detection result of shape (B, K), with values 0 or 1 indicating conflict between tokens and the region. """ # Ensure region is the same shape as the spatial dimensions of attention_map - assert region.shape[1:] == attention_map.shape[1:], "Region mask must match spatial dimensions of attention map" + assert region.shape[1:3] == attention_map.shape[1:3], "Region mask must match spatial dimensions of attention map" # Calculate the mean attention within the region #region = region.unsqueeze(-1) # Add channel dimension: (B, H, W) -> (B, H, W, 1) attention_in_region = attention_map * region # Element-wise multiplication From 91a8758058477ef6d40d8b48ef90dfb6143b1070 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 27 May 2024 21:10:34 -0700 Subject: [PATCH 35/55] preliminary impl --- scripts/tcg.py | 99 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 86 insertions(+), 13 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index f13933e..e254cf4 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -9,7 +9,8 @@ sys.path.append(f'{os.getcwd()}') sys.path.append(f'{os.getcwd()}/extensions/sd-webui-incantations') from scripts.ui_wrapper import UIWrapper -from scripts.incant_utils import module_hooks, plot_tools +from scripts.incant_utils import module_hooks, plot_tools, prompt_utils +from modules import shared, script_callbacks """ WIP Implementation of https://arxiv.org/abs/2404.11824 @@ -48,21 +49,86 @@ def before_process_batch(self, p, active, *args, **kwargs): active = getattr(p, 'tcg_active', active) if not active: return + + batch_size = p.batch_size + height, width = p.height, p.width + hw = height * width + + token_count, max_length = prompt_utils.get_token_count(p.prompt, p.steps, is_positive=True) + min_idx = 1 + max_idx = token_count+1 + token_indices = list(range(min_idx, max_idx)) def tcg_forward_hook(module, input, kwargs, output): - pass + # calc attn scores + q_map = module.tcg_to_q_map + k_map = module.tcg_to_k_map + attn_scores = get_attention_scores(q_map, k_map, dtype=q_map.dtype) # (2*B, H*W, C) + B, HW, C = attn_scores.shape + + downscale_h = round((HW * (height / width)) ** 0.5) + attn_scores = attn_scores.view(2, attn_scores.size(0)//2, downscale_h, HW//downscale_h, attn_scores.size(-1)).mean(dim=0) # (2*B, HW, C) -> (B, H, W, C) + + # slice attn map + attn_map = attn_scores[..., module.tcg_token_indices] + attn_map = attn_map.detach().clone() # (B, H, W, K) where K is the subset of tokens + + # region mask + region_mask = module.tcg_region_mask + if region_mask.shape[1:3] != attn_map.shape[1:3]: + region_mask = region_mask.permute(0, 3, 1, 2) # (B, H, W, 1) -> (B, 1, H, W) + region_mask = F.interpolate(region_mask, size=(attn_map.shape[1:3]), mode='nearest') + region_mask = region_mask.permute(0, 2, 3, 1) # (B, 1, H, W) -> (B, H, W, 1) + module.tcg_region_mask = region_mask + + region_mask_centroid = calculate_centroid(region_mask) # (B, C, 2) + + # detect conflicts and return if none + theta = 0.0000001 # parameterize this + conflicts = detect_conflict(attn_map, region_mask, theta) # (B, C) + if not torch.any(conflicts > 0.01): + return + + + centroids = calculate_centroid(attn_map) # (B, C, 2) + s_margin = 5.0 + s_repl = 1.0 + displ_force = displacement_force(attn_map, centroids, region_mask_centroid, s_repl, s_margin, clamp = 10) # B C 2 + + # reassign the attn map + output_attn_map = output.detach().clone() + #output_attn_map[..., module.tcg_token_indices] = output.detach().clone() + modified_attn_map, out_centroids = apply_displacements(output_attn_map[..., module.tcg_token_indices].unsqueeze(0), centroids, displ_force) + output_attn_map[..., module.tcg_token_indices] = modified_attn_map.squeeze(0) + + output = output_attn_map + + + def tcg_to_q_hook(module, input, kwargs, output): setattr(module.tcg_parent_module[0], 'tcg_to_q_map', output) def tcg_to_k_hook(module, input, kwargs, output): setattr(module.tcg_parent_module[0], 'tcg_to_k_map', output) + + def cfg_denoised_callback(params: script_callbacks.CFGDenoisedParams): + pass + + script_callbacks.on_cfg_denoised(cfg_denoised_callback) + + # TODO: Parameterize this + mask_H, mask_W = 64, 64 + temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) + temp_region_mask[0, 0:mask_H-1, 0:mask_W//2] = 1.0 # mask the left half of the thing for module in self.get_modules(): if not module.network_layer_name.endswith('attn2'): continue module_hooks.modules_add_field(module, 'tcg_to_q_map', None) module_hooks.modules_add_field(module, 'tcg_to_k_map', None) + module_hooks.modules_add_field(module, 'tcg_region_mask', temp_region_mask) + module_hooks.modules_add_field(module, 'tcg_token_indices', torch.tensor(token_indices, dtype=torch.int32, device=shared.device)) module_hooks.modules_add_field(module.to_q, 'tcg_parent_module', [module]) module_hooks.modules_add_field(module.to_k, 'tcg_parent_module', [module]) module_hooks.module_add_forward_hook(module.to_q, tcg_to_q_hook, with_kwargs=True) @@ -73,11 +139,14 @@ def postprocess_batch(self, p, *args, **kwargs): self.unhook_callbacks() def unhook_callbacks(self) -> None: + script_callbacks.remove_current_script_callbacks() for module in self.get_modules(): module_hooks.remove_module_forward_hook(module.to_q, 'tcg_to_q_hook') module_hooks.remove_module_forward_hook(module.to_k, 'tcg_to_k_hook') module_hooks.modules_remove_field(module, 'tcg_to_q_map') module_hooks.modules_remove_field(module, 'tcg_to_k_map') + module_hooks.modules_remove_field(module, 'tcg_region_mask') + module_hooks.modules_remove_field(module, 'tcg_token_indices') module_hooks.modules_remove_field(module.to_q, 'tcg_parent_module') module_hooks.modules_remove_field(module.to_k, 'tcg_parent_module') @@ -97,7 +166,6 @@ def get_xyz_axis_options(self) -> dict: debug_coord = lambda x: (round(x[0,0,0].item(),3), round(x[0,0,1].item(), 3)) - def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margin_strength, clamp=0): """ Given a set of vertices, calculate the displacement force given by the sum of margin force and repulsive force. Arguments: @@ -332,6 +400,7 @@ def calculate_centroid(attention_map): """ # necessary to avoid inf attention_map = attention_map.to(torch.float32) + B, H, W, C = attention_map.shape # Create tensors of the y and x coordinates y_coords = torch.arange(H, dtype=attention_map.dtype, device=attention_map.device).view(1, H, 1, 1) @@ -342,7 +411,7 @@ def calculate_centroid(attention_map): weighted_sum_x = torch.sum(x_coords * attention_map, dim=[1, 2]) # Calculate the total weights - total_weights = torch.sum(attention_map, dim=[1, 2]) + total_weights = torch.sum(attention_map, dim=[1, 2]) + torch.finfo(attention_map.dtype).eps # Calculate the centroids centroid_y = weighted_sum_y / total_weights @@ -368,7 +437,7 @@ def detect_conflict(attention_map, region, theta): assert region.shape[1:3] == attention_map.shape[1:3], "Region mask must match spatial dimensions of attention map" # Calculate the mean attention within the region #region = region.unsqueeze(-1) # Add channel dimension: (B, H, W) -> (B, H, W, 1) - attention_in_region = attention_map * region # Element-wise multiplication + attention_in_region = attention_map * region.unsqueeze(-1) # Element-wise multiplication #mean_attention_in_region = attention_in_region[attention_in_region > 0] mean_attention_in_region = torch.sum(attention_in_region, dim=(1, 2)) / torch.sum(region, dim=(1, 2)) # Mean over (H, W) # Compare with threshold theta @@ -519,16 +588,19 @@ def get_attention_scores(to_q_map, to_k_map, dtype): # attn_probs = attn_scores.softmax(dim=-1).to(device=shared.device, dtype=to_q_map.dtype) attn_probs = to_q_map @ to_k_map.transpose(-1, -2) + channel_dim = to_q_map.shape[-1] + attn_probs /= (channel_dim ** 0.5) + attn_probs = attn_probs.softmax(dim=-1).to(device=shared.device, dtype=to_q_map.dtype) - # avoid nan by converting to float32 and subtracting max - attn_probs = attn_probs.to(dtype=torch.float32) # - attn_probs -= torch.max(attn_probs) + # # avoid nan by converting to float32 and subtracting max + # attn_probs = attn_probs.to(dtype=torch.float32) # + # attn_probs -= torch.max(attn_probs) - torch.exp(attn_probs, out = attn_probs) - summed = attn_probs.sum(dim=-1, keepdim=True, dtype=torch.float32) - attn_probs /= summed + # torch.exp(attn_probs, out = attn_probs) + # summed = attn_probs.sum(dim=-1, keepdim=True, dtype=torch.float32) + # attn_probs /= summed + torch.finfo(torch.float32).eps - attn_probs = attn_probs.to(dtype=dtype) + #attn_probs = attn_probs.to(dtype=dtype) return attn_probs @@ -572,6 +644,7 @@ def color_region(image, yx, ab, color=1.0, mode='set'): if __name__ == '__main__': + tempdir = os.path.join(os.getcwd(), 'temp') os.makedirs(tempdir, exist_ok=True) img_idx = 0 @@ -657,7 +730,7 @@ def color_region(image, yx, ab, color=1.0, mode='set'): s_margin = 1.0 s_repl = 1.0 - # displacement forces + # test displacement forces d_zero = torch.tensor([[[0.0, 0]]], dtype=torch.float16, device='cuda') # B C 2 d_down = torch.tensor([[[0.1, 0]]], dtype=torch.float16, device='cuda') # B C 2 From e3b986868f3c0d5ce9eecb44306ef09d23dd240d Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 27 May 2024 21:42:28 -0700 Subject: [PATCH 36/55] implement basic controls --- scripts/tcg.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index e254cf4..89b99df 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -36,7 +36,10 @@ def title(self) -> str: def setup_ui(self, is_img2img) -> list: with gr.Accordion('TCG', open=True): active = gr.Checkbox(label="Active", value=True) - opts = [active] + strength = gr.Slider(label="Strength", value=1.0, minimum=-200.0, maximum=200.0, step=1.0) + f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=-10.0, maximum=10.0, step=0.1) + f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=-10.0, maximum=10.0, step=0.1) + opts = [active, strength, f_margin, f_repl] for opt in opts: opt.do_not_save_to_config = True return opts @@ -44,7 +47,7 @@ def setup_ui(self, is_img2img) -> list: def get_modules(self): return module_hooks.get_modules( module_name_filter='CrossAttention') - def before_process_batch(self, p, active, *args, **kwargs): + def before_process_batch(self, p, active, strength, f_margin, f_repl, *args, **kwargs): self.unhook_callbacks() active = getattr(p, 'tcg_active', active) if not active: @@ -89,10 +92,8 @@ def tcg_forward_hook(module, input, kwargs, output): if not torch.any(conflicts > 0.01): return - centroids = calculate_centroid(attn_map) # (B, C, 2) - s_margin = 5.0 - s_repl = 1.0 + logger.debug(centroids) displ_force = displacement_force(attn_map, centroids, region_mask_centroid, s_repl, s_margin, clamp = 10) # B C 2 # reassign the attn map @@ -101,10 +102,12 @@ def tcg_forward_hook(module, input, kwargs, output): modified_attn_map, out_centroids = apply_displacements(output_attn_map[..., module.tcg_token_indices].unsqueeze(0), centroids, displ_force) output_attn_map[..., module.tcg_token_indices] = modified_attn_map.squeeze(0) - output = output_attn_map - - + loss = output - output_attn_map + loss = loss / (loss.norm() + torch.finfo(loss.dtype).eps) + + output += strength * loss + #output = output_attn_map def tcg_to_q_hook(module, input, kwargs, output): setattr(module.tcg_parent_module[0], 'tcg_to_q_map', output) @@ -120,7 +123,7 @@ def cfg_denoised_callback(params: script_callbacks.CFGDenoisedParams): # TODO: Parameterize this mask_H, mask_W = 64, 64 temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) - temp_region_mask[0, 0:mask_H-1, 0:mask_W//2] = 1.0 # mask the left half of the thing + temp_region_mask[0, mask_H//4:3*mask_H//4, mask_W//4:3*mask_W//4] = 1.0 # mask the center of the canvas for module in self.get_modules(): if not module.network_layer_name.endswith('attn2'): @@ -143,6 +146,7 @@ def unhook_callbacks(self) -> None: for module in self.get_modules(): module_hooks.remove_module_forward_hook(module.to_q, 'tcg_to_q_hook') module_hooks.remove_module_forward_hook(module.to_k, 'tcg_to_k_hook') + module_hooks.remove_module_forward_hook(module, 'tcg_forward_hook') module_hooks.modules_remove_field(module, 'tcg_to_q_map') module_hooks.modules_remove_field(module, 'tcg_to_k_map') module_hooks.modules_remove_field(module, 'tcg_region_mask') From 4881468b4b5fc59db41a187f56c41ebae71ff097 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Mon, 27 May 2024 22:12:24 -0700 Subject: [PATCH 37/55] add controls and xyz --- scripts/tcg.py | 60 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 89b99df..8d70b46 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -10,7 +10,7 @@ sys.path.append(f'{os.getcwd()}/extensions/sd-webui-incantations') from scripts.ui_wrapper import UIWrapper from scripts.incant_utils import module_hooks, plot_tools, prompt_utils -from modules import shared, script_callbacks +from modules import shared, scripts, script_callbacks """ WIP Implementation of https://arxiv.org/abs/2404.11824 @@ -28,17 +28,22 @@ class TCGExtensionScript(UIWrapper): def __init__(self): self.infotext_fields: list = [] - self.paste_field_names: list = [] + self.paste_field_names: list = [ + "tcg_active", + "tcg_strength", + "tcg_f_margin", + "tcg_f_repl" + ] def title(self) -> str: raise 'TCG [arXiv:2404.11824]' def setup_ui(self, is_img2img) -> list: with gr.Accordion('TCG', open=True): - active = gr.Checkbox(label="Active", value=True) - strength = gr.Slider(label="Strength", value=1.0, minimum=-200.0, maximum=200.0, step=1.0) - f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=-10.0, maximum=10.0, step=0.1) - f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=-10.0, maximum=10.0, step=0.1) + active = gr.Checkbox(label="Active", value=True, elem_id="tcg_active") + strength = gr.Slider(label="Strength", value=1.0, minimum=-200.0, maximum=200.0, step=1.0, elem_id="tcg_strength") + f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_margin") + f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_repl") opts = [active, strength, f_margin, f_repl] for opt in opts: opt.do_not_save_to_config = True @@ -52,6 +57,9 @@ def before_process_batch(self, p, active, strength, f_margin, f_repl, *args, **k active = getattr(p, 'tcg_active', active) if not active: return + strength = getattr(p, 'tcg_strength', strength) + f_margin = getattr(p, 'tcg_f_margin', f_margin) + f_repl = getattr(p, 'tcg_f_repl', f_repl) batch_size = p.batch_size height, width = p.height, p.width @@ -80,7 +88,7 @@ def tcg_forward_hook(module, input, kwargs, output): region_mask = module.tcg_region_mask if region_mask.shape[1:3] != attn_map.shape[1:3]: region_mask = region_mask.permute(0, 3, 1, 2) # (B, H, W, 1) -> (B, 1, H, W) - region_mask = F.interpolate(region_mask, size=(attn_map.shape[1:3]), mode='nearest') + region_mask = F.interpolate(region_mask, size=(attn_map.shape[1:3]), mode='bilinear') region_mask = region_mask.permute(0, 2, 3, 1) # (B, 1, H, W) -> (B, H, W, 1) module.tcg_region_mask = region_mask @@ -90,11 +98,12 @@ def tcg_forward_hook(module, input, kwargs, output): theta = 0.0000001 # parameterize this conflicts = detect_conflict(attn_map, region_mask, theta) # (B, C) if not torch.any(conflicts > 0.01): - return + logger.debug("No conflicts detected") + #return centroids = calculate_centroid(attn_map) # (B, C, 2) logger.debug(centroids) - displ_force = displacement_force(attn_map, centroids, region_mask_centroid, s_repl, s_margin, clamp = 10) # B C 2 + displ_force = displacement_force(attn_map, centroids, region_mask_centroid, f_repl, f_margin, clamp = 10) # B C 2 # reassign the attn map output_attn_map = output.detach().clone() @@ -102,7 +111,7 @@ def tcg_forward_hook(module, input, kwargs, output): modified_attn_map, out_centroids = apply_displacements(output_attn_map[..., module.tcg_token_indices].unsqueeze(0), centroids, displ_force) output_attn_map[..., module.tcg_token_indices] = modified_attn_map.squeeze(0) - loss = output - output_attn_map + loss = output_attn_map - output loss = loss / (loss.norm() + torch.finfo(loss.dtype).eps) output += strength * loss @@ -123,7 +132,7 @@ def cfg_denoised_callback(params: script_callbacks.CFGDenoisedParams): # TODO: Parameterize this mask_H, mask_W = 64, 64 temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) - temp_region_mask[0, mask_H//4:3*mask_H//4, mask_W//4:3*mask_W//4] = 1.0 # mask the center of the canvas + temp_region_mask[0, 1*mask_H//8 : 7*mask_H//8 , 2*mask_W//8 : 5*mask_W//8] = 1.0 # mask the left half ish of the canvas for module in self.get_modules(): if not module.network_layer_name.endswith('attn2'): @@ -164,7 +173,14 @@ def process_batch(self, p, *args, **kwargs): pass def get_xyz_axis_options(self) -> dict: - return {} + xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ in ("xyz_grid.py", "scripts.xyz_grid")][0].module + extra_axis_options = { + xyz_grid.AxisOption("[TCG] Active", str, tcg_apply_override('tcg_active', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)), + xyz_grid.AxisOption("[TCG] Strength", float, tcg_apply_field("tcg_strength")), + xyz_grid.AxisOption("[TCG] Repulsion Force", float, tcg_apply_field("tcg_f_repl")), + xyz_grid.AxisOption("[TCG] Margin Force", float, tcg_apply_field("tcg_f_margin")), + } + return extra_axis_options @@ -753,7 +769,7 @@ def color_region(image, yx, ab, color=1.0, mode='set'): verts = calculate_centroid(attn_map_points) # (B, C, 2) # Displacement forces - displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin, clamp = 10) # B C 2 + displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin, clamp = 0) # B C 2 if torch.isnan(displ_force).any(): logger.warning(f'Nan in displ_force at iter {i}') @@ -773,4 +789,20 @@ def color_region(image, yx, ab, color=1.0, mode='set'): ofs = c * 100 #plot_point(copied_map[:c], img_idx+ofs+i+c, radius=1, color=1.0) _png(copied_map[..., c].unsqueeze(-1), img_idx+ofs+i, f'Displacement Forces Channel {c} Step {i}') - #_png(copied_map, img_idx+i, f'Displacement Forces Step {i}') \ No newline at end of file + #_png(copied_map, img_idx+i, f'Displacement Forces Step {i}') + +# XYZ Plot +# Based on @mcmonkey4eva's XYZ Plot implementation here: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/blob/master/scripts/dynamic_thresholding.py +def tcg_apply_override(field, boolean: bool = False): + def fun(p, x, xs): + if boolean: + x = True if x.lower() == "true" else False + setattr(p, field, x) + return fun + +def tcg_apply_field(field): + def fun(p, x, xs): + if not hasattr(p, "tcg_active"): + p.tcg_active = True + setattr(p, field, x) + return fun \ No newline at end of file From 309ef40074a5a9960ad2a4522b43a505328b12ce Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Tue, 28 May 2024 14:21:09 -0700 Subject: [PATCH 38/55] add soft threshold fn --- scripts/tcg.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/scripts/tcg.py b/scripts/tcg.py index 8d70b46..2f42890 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -211,6 +211,27 @@ def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margi return f_rep + f_margin +def soft_threshold(attention_map, threshold=0.5, sharpness=10): + """ Soft threshold the attention map channels based on the given threshold. Derived from arXiv:2306.00986 + Arguments: + attention_map: torch.Tensor - The attention map to threshold. Shape: (B, H, W, C) + threshold: float - The threshold value between 0.0 and 1.0 relative to the minimum/maximum attention value + sharpness: float - The sharpness of the thresholding function + Returns: + torch.Tensor - The attention map thresholded over all C. Shape: (B, H, W, C) + """ + def normalize_map(attnmap): + flattened_attnmap = attnmap.view(attnmap.shape[0], -1, attnmap.shape[-1]) + min_val = torch.min(flattened_attnmap, dim=1, keepdim=True).values.unsqueeze(dim=1) + max_val = torch.max(flattened_attnmap, dim=1, keepdim=True).values.unsqueeze(dim=1) + normalized_attn = attnmap - min_val / (max_val - min_val) + return normalized_attn + threshold = max(0.0, min(1.0, threshold)) + normalized_attn = normalize_map(attention_map) + normalized_attn = normalize_map(torch.sigmoid(sharpness * (normalized_attn - threshold))) + return normalized_attn + + def distances_to_nearest_edges(verts, h, w): """ Calculate the distances and direction to the nearest edge bounded by (H, W) for each channel's vertices Arguments: From e8a6ee1a11e40650c001ae8e5624cac53462c01e Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Tue, 28 May 2024 14:21:31 -0700 Subject: [PATCH 39/55] hack to fix conflict detector --- scripts/tcg.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 2f42890..a222bd5 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -478,7 +478,11 @@ def detect_conflict(attention_map, region, theta): assert region.shape[1:3] == attention_map.shape[1:3], "Region mask must match spatial dimensions of attention map" # Calculate the mean attention within the region #region = region.unsqueeze(-1) # Add channel dimension: (B, H, W) -> (B, H, W, 1) - attention_in_region = attention_map * region.unsqueeze(-1) # Element-wise multiplication + # HACK: fixme + if region.dim() != attention_map.dim(): + attention_in_region = attention_map * region.unsqueeze(-1) # Element-wise multiplication + else: + attention_in_region = attention_map * region #mean_attention_in_region = attention_in_region[attention_in_region > 0] mean_attention_in_region = torch.sum(attention_in_region, dim=(1, 2)) / torch.sum(region, dim=(1, 2)) # Mean over (H, W) # Compare with threshold theta From 34bf4558b6d5e339a50fc32d69ee13a78b2d0489 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Tue, 28 May 2024 14:21:52 -0700 Subject: [PATCH 40/55] update routine to zero displacements where unneeded --- scripts/tcg.py | 47 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index a222bd5..f71f745 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -701,6 +701,15 @@ def color_region(image, yx, ab, color=1.0, mode='set'): title=f'{title}', ) + def _png_batch(attnmap, name, title): + for batch_idx, _ in enumerate(attnmap): + for channel_idx in range(attnmap.shape[-1]): + plot_tools.plot_attention_map( + attnmap[batch_idx, :, :, channel_idx], + save_path=os.path.join(tempdir, f'batch{batch_idx:04}_{name:04}_{channel_idx:02}.png'), + title=f'{title} - Channel {channel_idx}', + ) + B, H, W, C = 1, 64, 64, 3 dtype = torch.float16 device = 'cuda' @@ -728,10 +737,12 @@ def color_region(image, yx, ab, color=1.0, mode='set'): # color the target region color_region(attention_map, region_yx, region_ab, color=1.0, mode='set') - _png(attention_map, 0, 'Initial Attn Map') attention_map_region = attention_map.detach().clone() + attention_map_region = torch.zeros((B, H, W, 1)).to(device, dtype) # B H W 1 + color_region(attention_map_region, region_yx, region_ab, color=1.0, mode='set') + _png(attention_map_region, 1, 'Attn Map Region + Centroid') # calculate centroid of region centroid = calculate_centroid(attention_map) # (B, C, 2) @@ -742,13 +753,13 @@ def color_region(image, yx, ab, color=1.0, mode='set'): for batch_idx in range(B): for channel_idx in range(C): plot_point(attention_map[batch_idx, ..., channel_idx].unsqueeze(0).unsqueeze(-1), list(centroid_points[batch_idx, channel_idx]), radius=1, color=1) - _png(attention_map, 1, 'Attn Map Region + Centroid') + #_png(attention_map, 1, 'Attn Map Region + Centroid') # plot verts attn_map_points = torch.zeros_like(attention_map) d_region_yx = [1*H//8, 1*W//8] d_region_ab = [6*H//8, 3*W//8] - color_region(attn_map_points, d_region_yx, d_region_ab, color=0.5, mode='set') + # color_region(attn_map_points, d_region_yx, d_region_ab, color=0.5, mode='set') c0_region_yx = [12,24] c0_region_ab = [36, 48] @@ -760,8 +771,13 @@ def color_region(image, yx, ab, color=1.0, mode='set'): c2_region_ab = [63, 63] if attn_map_points.shape[-1] > 1: color_region(attn_map_points[..., 0].unsqueeze(-1), c0_region_yx, c0_region_ab, color=1, mode='set') - color_region(attn_map_points[..., 1].unsqueeze(-1), c1_region_yx, c1_region_ab, color=1, mode='set') - color_region(attn_map_points[..., 2].unsqueeze(-1), c2_region_yx, c2_region_ab, color=1, mode='set') + color_region(attn_map_points[..., 1].unsqueeze(-1), c1_region_yx, c1_region_ab, color=0.5, mode='set') + color_region(attn_map_points[..., 2].unsqueeze(-1), c2_region_yx, c2_region_ab, color=0.2, mode='set') + _png_batch(attn_map_points, 3, 'Thresholded Attn Map Points - Channel 0') + + attn_map_points_thresholded = soft_threshold(attn_map_points, threshold=0.5, sharpness=10) + #_png(attn_map_points_thresholded, 3, 'Thresholded Attn Map Points') + _png_batch(attn_map_points_thresholded, 4, 'Thresholded Attn Map Points') # set areas outside the region to 0 #attn_map_points = attn_map_points * attention_map @@ -769,23 +785,26 @@ def color_region(image, yx, ab, color=1.0, mode='set'): for v in verts.squeeze(0): plot_point(attn_map_points, v, radius=1, color=1) - _png(attn_map_points, 2, 'Attn Map Proxy Map') + _png(attn_map_points, 6, 'Attn Map Proxy Map') # strengths - s_margin = 1.0 + s_margin = 5.0 s_repl = 1.0 # test displacement forces - d_zero = torch.tensor([[[0.0, 0]]], dtype=torch.float16, device='cuda') # B C 2 + displ_zero = torch.tensor([0.0, 0.0], dtype=torch.float16, device='cuda') # B C 2 + #d_zero = torch.tensor([[[0.0, 0]]], dtype=torch.float16, device='cuda') # B C 2 d_down = torch.tensor([[[0.1, 0]]], dtype=torch.float16, device='cuda') # B C 2 # simulate displacement forces on our points - img_idx = 4 + img_idx = 7 iters = 100 steps = 5 for i in range(iters): + logger.debug('Step %d', i) + # Check for conflicts between target region and attention map - theta = 0.001 + theta = 0.01 conflict_detection = detect_conflict(attn_map_points, attention_map_region, theta) # (B, C) logger.debug(f'Conflict Detection: {conflict_detection}') if not conflict_detection.any(): @@ -793,8 +812,14 @@ def color_region(image, yx, ab, color=1.0, mode='set'): break verts = calculate_centroid(attn_map_points) # (B, C, 2) + # Displacement forces - displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin, clamp = 0) # B C 2 + displ_force = displacement_force(attn_map_points, verts, centroid, s_repl, s_margin, clamp = 10) # B C 2 + + # zero displacement where conflict is none + displ_force = displ_force * conflict_detection.unsqueeze(-1) + + if torch.isnan(displ_force).any(): logger.warning(f'Nan in displ_force at iter {i}') From d501d5dab549c2231da9307f721d3fd97ac0b1e6 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Tue, 28 May 2024 14:59:57 -0700 Subject: [PATCH 41/55] fix normalize map fn --- scripts/tcg.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index f71f745..265efc6 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -221,10 +221,13 @@ def soft_threshold(attention_map, threshold=0.5, sharpness=10): torch.Tensor - The attention map thresholded over all C. Shape: (B, H, W, C) """ def normalize_map(attnmap): - flattened_attnmap = attnmap.view(attnmap.shape[0], -1, attnmap.shape[-1]) - min_val = torch.min(flattened_attnmap, dim=1, keepdim=True).values.unsqueeze(dim=1) - max_val = torch.max(flattened_attnmap, dim=1, keepdim=True).values.unsqueeze(dim=1) - normalized_attn = attnmap - min_val / (max_val - min_val) + B, H, W, C = attnmap.shape + flattened_attnmap = attnmap.view(attnmap.shape[0], H*W, attnmap.shape[-1]).transpose(-1, -2) # B, C, H*W + min_val = torch.min(flattened_attnmap, dim=-1).values.unsqueeze(-1) # (B, C, 1) + max_val = torch.max(flattened_attnmap, dim=-1).values.unsqueeze(-1) # (B, C, 1) + normalized_attn = (flattened_attnmap - min_val) / ((max_val - min_val) + torch.finfo(attnmap.dtype).eps) + normalized_attn = normalized_attn.view(B, C, H*W).transpose(-1, -2) # B, H*W, C + normalized_attn = normalized_attn.view(B, H, W, C) return normalized_attn threshold = max(0.0, min(1.0, threshold)) normalized_attn = normalize_map(attention_map) From be0563490175834ea8908d371b9dd98d551f09aa Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Tue, 28 May 2024 15:57:49 -0700 Subject: [PATCH 42/55] add more parameters and selfguidance term --- scripts/tcg.py | 117 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 95 insertions(+), 22 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 265efc6..40b9d2b 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -5,12 +5,25 @@ import random import torch.nn.functional as F +if os.environ.get('INCANT_DEBUG', None): + # suppress excess logging + logging.getLogger("PIL").setLevel(logging.WARNING) + logging.getLogger("matplotlib").setLevel(logging.WARNING) + if __name__ == '__main__' and os.environ.get('INCANT_DEBUG', None): sys.path.append(f'{os.getcwd()}') sys.path.append(f'{os.getcwd()}/extensions/sd-webui-incantations') +else: + from scripts.incant_utils import module_hooks, plot_tools, prompt_utils + from modules import shared, scripts, script_callbacks + +from scripts.incant_utils import plot_tools from scripts.ui_wrapper import UIWrapper -from scripts.incant_utils import module_hooks, plot_tools, prompt_utils -from modules import shared, scripts, script_callbacks + + +logger = logging.getLogger(__name__) +logging.basicConfig() +logger.setLevel(logging.DEBUG) """ WIP Implementation of https://arxiv.org/abs/2404.11824 @@ -19,11 +32,6 @@ """ -logger = logging.getLogger(__name__) -if os.environ.get('INCANT_DEBUG', None): - # suppress excess logging - logging.getLogger("PIL").setLevel(logging.WARNING) - logging.getLogger("matplotlib").setLevel(logging.WARNING) class TCGExtensionScript(UIWrapper): def __init__(self): @@ -32,7 +40,11 @@ def __init__(self): "tcg_active", "tcg_strength", "tcg_f_margin", - "tcg_f_repl" + "tcg_f_repl", + "tcg_theta", + "tcg_attn_threshold", + "tcg_sharpness", + "tcg_selfguidance_scale", ] def title(self) -> str: @@ -41,10 +53,14 @@ def title(self) -> str: def setup_ui(self, is_img2img) -> list: with gr.Accordion('TCG', open=True): active = gr.Checkbox(label="Active", value=True, elem_id="tcg_active") - strength = gr.Slider(label="Strength", value=1.0, minimum=-200.0, maximum=200.0, step=1.0, elem_id="tcg_strength") + strength = gr.Slider(label="Strength", value=1.0, minimum=-5.0, maximum=5.0, step=0.1, elem_id="tcg_strength") f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_margin") f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_repl") - opts = [active, strength, f_margin, f_repl] + theta = gr.Slider(label="Conflict Threshold", value=0.01, minimum=0.0, maximum=1.0, step=0.001, elem_id="tcg_theta") + threshold = gr.Slider(label="Soft Threshold", value=0.5, minimum=0.0, maximum=1.0, step=0.01, elem_id="tcg_attn_threshold") + sharpness = gr.Slider(label="Threshold Sharpness", value=10.0, minimum=0.1, maximum=20.0, step=0.1, elem_id="tcg_sharpness") + selfguidance_scale = gr.Slider(label="Self-Guidance Scale", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_selfguidance_scale") + opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale] for opt in opts: opt.do_not_save_to_config = True return opts @@ -52,7 +68,7 @@ def setup_ui(self, is_img2img) -> list: def get_modules(self): return module_hooks.get_modules( module_name_filter='CrossAttention') - def before_process_batch(self, p, active, strength, f_margin, f_repl, *args, **kwargs): + def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, *args, **kwargs): self.unhook_callbacks() active = getattr(p, 'tcg_active', active) if not active: @@ -60,6 +76,10 @@ def before_process_batch(self, p, active, strength, f_margin, f_repl, *args, **k strength = getattr(p, 'tcg_strength', strength) f_margin = getattr(p, 'tcg_f_margin', f_margin) f_repl = getattr(p, 'tcg_f_repl', f_repl) + theta = getattr(p, 'tcg_theta', theta) + threshold= getattr(p, 'tcg_attn_threshold', threshold) + sharpness = getattr(p, 'tcg_sharpness', sharpness) + selfguidance_scale = getattr(p, 'tcg_selfguidance_scale', selfguidance_scale) batch_size = p.batch_size height, width = p.height, p.width @@ -74,15 +94,37 @@ def tcg_forward_hook(module, input, kwargs, output): # calc attn scores q_map = module.tcg_to_q_map k_map = module.tcg_to_k_map + # select k tokens + k_map = k_map.transpose(-1, -2)[..., module.tcg_token_indices].transpose(-1,-2) attn_scores = get_attention_scores(q_map, k_map, dtype=q_map.dtype) # (2*B, H*W, C) + attn_scores = attn_scores.to(torch.float32) B, HW, C = attn_scores.shape downscale_h = round((HW * (height / width)) ** 0.5) attn_scores = attn_scores.view(2, attn_scores.size(0)//2, downscale_h, HW//downscale_h, attn_scores.size(-1)).mean(dim=0) # (2*B, HW, C) -> (B, H, W, C) # slice attn map - attn_map = attn_scores[..., module.tcg_token_indices] - attn_map = attn_map.detach().clone() # (B, H, W, K) where K is the subset of tokens + attn_map = attn_scores.detach().clone() # (B, H, W, K) where K is the subset of tokens + + # threshold it + # also represents object shape + attn_map = soft_threshold(attn_map, threshold=threshold, sharpness=sharpness) + + # self-guidance + inner_dims = attn_map.shape[1:-1] + attn_map = attn_map.view(attn_map.size(0), -1, attn_map.size(-1)) + + shape_sum = torch.sum(attn_map, dim=1) # (B, HW) + + obj_appearance = shape_sum * attn_map + obj_appearance /= shape_sum + + self_guidance = obj_appearance + self_guidance = self_guidance.to(output.dtype) + self_guidance_factor = output.detach().clone() + self_guidance_factor[..., module.tcg_token_indices] = self_guidance + + attn_map = attn_map.view(attn_map.size(0), *inner_dims, attn_map.size(-1)) # region mask region_mask = module.tcg_region_mask @@ -95,26 +137,32 @@ def tcg_forward_hook(module, input, kwargs, output): region_mask_centroid = calculate_centroid(region_mask) # (B, C, 2) # detect conflicts and return if none - theta = 0.0000001 # parameterize this conflicts = detect_conflict(attn_map, region_mask, theta) # (B, C) if not torch.any(conflicts > 0.01): logger.debug("No conflicts detected") - #return + return centroids = calculate_centroid(attn_map) # (B, C, 2) - logger.debug(centroids) + #logger.debug(centroids) + displ_force = displacement_force(attn_map, centroids, region_mask_centroid, f_repl, f_margin, clamp = 10) # B C 2 + # zero out displacement force + displ_force = displ_force * conflicts.unsqueeze(-1) + logger.debug("Displacements: %s", displ_force) + # reassign the attn map output_attn_map = output.detach().clone() - #output_attn_map[..., module.tcg_token_indices] = output.detach().clone() modified_attn_map, out_centroids = apply_displacements(output_attn_map[..., module.tcg_token_indices].unsqueeze(0), centroids, displ_force) output_attn_map[..., module.tcg_token_indices] = modified_attn_map.squeeze(0) - loss = output_attn_map - output - loss = loss / (loss.norm() + torch.finfo(loss.dtype).eps) + loss = output - output_attn_map + loss = normalize_map(loss) + loss **= 2 + output += strength * loss + output += selfguidance_scale * self_guidance_factor #output = output_attn_map @@ -179,6 +227,10 @@ def get_xyz_axis_options(self) -> dict: xyz_grid.AxisOption("[TCG] Strength", float, tcg_apply_field("tcg_strength")), xyz_grid.AxisOption("[TCG] Repulsion Force", float, tcg_apply_field("tcg_f_repl")), xyz_grid.AxisOption("[TCG] Margin Force", float, tcg_apply_field("tcg_f_margin")), + xyz_grid.AxisOption("[TCG] Conflict Threshold", float, tcg_apply_field("tcg_theta")), + xyz_grid.AxisOption("[TCG] Soft Threshold", float, tcg_apply_field("tcg_attn_threshold")), + xyz_grid.AxisOption("[TCG] Threshold Sharpness", float, tcg_apply_field("tcg_sharpness")), + xyz_grid.AxisOption("[TCG] Self-Guidance Scale", float, tcg_apply_field("tcg_selfguidance_scale")), } return extra_axis_options @@ -211,6 +263,21 @@ def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margi return f_rep + f_margin +def normalize_map(attnmap): + """ Normalize the attention map over the channel dimension + Arguments: + attnmap: torch.Tensor - The attention map to normalize. Shape: (B, HW, C) + Returns: + torch.Tensor - The attention map normalized to (0, 1). Shape: (B, HW, C) + """ + flattened_attnmap = attnmap.transpose(-1, -2) + min_val = torch.min(flattened_attnmap, dim=-1).values.unsqueeze(-1) # (B, C, 1) + max_val = torch.max(flattened_attnmap, dim=-1).values.unsqueeze(-1) # (B, C, 1) + normalized_attn = (flattened_attnmap - min_val) / ((max_val - min_val) + torch.finfo(attnmap.dtype).eps) + normalized_attn = normalized_attn.transpose(-1, -2) + return normalized_attn + + def soft_threshold(attention_map, threshold=0.5, sharpness=10): """ Soft threshold the attention map channels based on the given threshold. Derived from arXiv:2306.00986 Arguments: @@ -220,7 +287,13 @@ def soft_threshold(attention_map, threshold=0.5, sharpness=10): Returns: torch.Tensor - The attention map thresholded over all C. Shape: (B, H, W, C) """ - def normalize_map(attnmap): + def _normalize_map(attnmap): + """ Normalize the attention map over the channel dimension + Arguments: + attnmap: torch.Tensor - The attention map to normalize. Shape: (B, H, W, C) or (B, HW, C) + Returns: + torch.Tensor - The attention map normalized to (0, 1). Shape: (B, H, W, C) + """ B, H, W, C = attnmap.shape flattened_attnmap = attnmap.view(attnmap.shape[0], H*W, attnmap.shape[-1]).transpose(-1, -2) # B, C, H*W min_val = torch.min(flattened_attnmap, dim=-1).values.unsqueeze(-1) # (B, C, 1) @@ -230,8 +303,8 @@ def normalize_map(attnmap): normalized_attn = normalized_attn.view(B, H, W, C) return normalized_attn threshold = max(0.0, min(1.0, threshold)) - normalized_attn = normalize_map(attention_map) - normalized_attn = normalize_map(torch.sigmoid(sharpness * (normalized_attn - threshold))) + normalized_attn = _normalize_map(attention_map) + normalized_attn = _normalize_map(torch.sigmoid(sharpness * (normalized_attn - threshold))) return normalized_attn From 36934d8a50bb3b5011a799cbb01ddb2c85cb98f4 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Tue, 28 May 2024 23:05:59 -0700 Subject: [PATCH 43/55] try to fix completely wrong dimensions when swapping token output --- scripts/tcg.py | 154 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 124 insertions(+), 30 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 40b9d2b..4b752e7 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -1,5 +1,7 @@ import os, sys import logging +import numpy as np +from PIL import Image import gradio as gr import torch import random @@ -25,6 +27,7 @@ logging.basicConfig() logger.setLevel(logging.DEBUG) + """ WIP Implementation of https://arxiv.org/abs/2404.11824 Author: v0xie @@ -32,7 +35,6 @@ """ - class TCGExtensionScript(UIWrapper): def __init__(self): self.infotext_fields: list = [] @@ -52,23 +54,48 @@ def title(self) -> str: def setup_ui(self, is_img2img) -> list: with gr.Accordion('TCG', open=True): - active = gr.Checkbox(label="Active", value=True, elem_id="tcg_active") - strength = gr.Slider(label="Strength", value=1.0, minimum=-5.0, maximum=5.0, step=0.1, elem_id="tcg_strength") - f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_margin") - f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_repl") - theta = gr.Slider(label="Conflict Threshold", value=0.01, minimum=0.0, maximum=1.0, step=0.001, elem_id="tcg_theta") - threshold = gr.Slider(label="Soft Threshold", value=0.5, minimum=0.0, maximum=1.0, step=0.01, elem_id="tcg_attn_threshold") - sharpness = gr.Slider(label="Threshold Sharpness", value=10.0, minimum=0.1, maximum=20.0, step=0.1, elem_id="tcg_sharpness") - selfguidance_scale = gr.Slider(label="Self-Guidance Scale", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_selfguidance_scale") - opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale] + with gr.Row(): + image_mask = gr.Image(type='pil', image_mode='L', label="Mask", interactive=False, height = 256) + with gr.Row(): + generate_mask = gr.Button("Generate Mask", elem_id="tcg_btn_generate_mask") + with gr.Row(): + with gr.Column(): + left = gr.Slider(label="Left", value=0.2, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_left") + with gr.Column(): + with gr.Row(): + top = gr.Slider(label="Top", value=0.3, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_top") + with gr.Row(): + bottom = gr.Slider(label="Bottom", value=0.7, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_bottom") + with gr.Column(): + right = gr.Slider(label="Right", value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_right") + with gr.Row(): + active = gr.Checkbox(label="Active", value=True, elem_id="tcg_active") + strength = gr.Slider(label="Strength", value=1.0, minimum=-5.0, maximum=5.0, step=0.1, elem_id="tcg_strength") + f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_margin") + f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_repl") + theta = gr.Slider(label="Conflict Threshold", value=0.01, minimum=0.0, maximum=1.0, step=0.001, elem_id="tcg_theta") + threshold = gr.Slider(label="Soft Threshold", value=0.5, minimum=0.0, maximum=1.0, step=0.01, elem_id="tcg_attn_threshold") + sharpness = gr.Slider(label="Threshold Sharpness", value=10.0, minimum=0.1, maximum=20.0, step=0.1, elem_id="tcg_sharpness") + selfguidance_scale = gr.Slider(label="Self-Guidance Scale", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_selfguidance_scale") + with gr.Row(): + start_step = gr.Slider(label="Start Step", value=0, minimum=0, maximum=100, step=1, elem_id="tcg_start_step") + end_step = gr.Slider(label="End Step", value=0, minimum=1, maximum=100, step=1, elem_id="tcg_end_step") + opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step] for opt in opts: opt.do_not_save_to_config = True + + generate_mask.click( + create_mask, + inputs = [left, right, top, bottom], + outputs = [image_mask] + ) + return opts def get_modules(self): return module_hooks.get_modules( module_name_filter='CrossAttention') - def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, *args, **kwargs): + def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step, *args, **kwargs): self.unhook_callbacks() active = getattr(p, 'tcg_active', active) if not active: @@ -80,24 +107,40 @@ def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, thr threshold= getattr(p, 'tcg_attn_threshold', threshold) sharpness = getattr(p, 'tcg_sharpness', sharpness) selfguidance_scale = getattr(p, 'tcg_selfguidance_scale', selfguidance_scale) + start_step = getattr(p, 'tcg_start_step', start_step) + end_step = getattr(p, 'tcg_end_step', end_step) batch_size = p.batch_size height, width = p.height, p.width hw = height * width + setattr(p, 'tcg_current_step', 0) + token_count, max_length = prompt_utils.get_token_count(p.prompt, p.steps, is_positive=True) min_idx = 1 max_idx = token_count+1 token_indices = list(range(min_idx, max_idx)) def tcg_forward_hook(module, input, kwargs, output): + current_step = module.tcg_current_step + if not start_step <= current_step <= end_step: + return + # calc attn scores q_map = module.tcg_to_q_map k_map = module.tcg_to_k_map + v_map = module.tcg_to_v_map + + batch_size, seq_len, inner_dim = output.shape + heads = module.heads + # select k tokens - k_map = k_map.transpose(-1, -2)[..., module.tcg_token_indices].transpose(-1,-2) - attn_scores = get_attention_scores(q_map, k_map, dtype=q_map.dtype) # (2*B, H*W, C) - attn_scores = attn_scores.to(torch.float32) + # k_map = k_map.transpose(-1, -2)[..., module.tcg_token_indices].transpose(-1,-2) + + # calc attn scores for q and k + out_attn_scores = get_attention_scores(q_map, k_map, dtype=q_map.dtype) # (2*B, H*W, C) + out_attn_scores = out_attn_scores.to(torch.float32) + attn_scores = out_attn_scores[..., module.tcg_token_indices] # (B, H*W, C) B, HW, C = attn_scores.shape downscale_h = round((HW * (height / width)) ** 0.5) @@ -121,7 +164,8 @@ def tcg_forward_hook(module, input, kwargs, output): self_guidance = obj_appearance self_guidance = self_guidance.to(output.dtype) - self_guidance_factor = output.detach().clone() + self_guidance_factor = out_attn_scores.detach().clone() + #self_guidance_factor = output.detach().clone() self_guidance_factor[..., module.tcg_token_indices] = self_guidance attn_map = attn_map.view(attn_map.size(0), *inner_dims, attn_map.size(-1)) @@ -143,26 +187,38 @@ def tcg_forward_hook(module, input, kwargs, output): return centroids = calculate_centroid(attn_map) # (B, C, 2) - #logger.debug(centroids) displ_force = displacement_force(attn_map, centroids, region_mask_centroid, f_repl, f_margin, clamp = 10) # B C 2 # zero out displacement force displ_force = displ_force * conflicts.unsqueeze(-1) - logger.debug("Displacements: %s", displ_force) + + #displ_force *= strength + #logger.debug("Displacements: %s", displ_force) + + # apply displacements to the attn map + output_attn_map = attn_map.detach().clone() + modified_attn_map, out_centroids = apply_displacements(output_attn_map, centroids, displ_force) - # reassign the attn map - output_attn_map = output.detach().clone() - modified_attn_map, out_centroids = apply_displacements(output_attn_map[..., module.tcg_token_indices].unsqueeze(0), centroids, displ_force) - output_attn_map[..., module.tcg_token_indices] = modified_attn_map.squeeze(0) + # do the rest of attention + output_attn_map = output_attn_map.view(output_attn_map.size(0), -1, output_attn_map.size(-1)) + out_attn_scores[..., module.tcg_token_indices] = output_attn_map + new_output = out_attn_scores @ v_map.to(out_attn_scores.dtype) # (B, HW, C) @ (B, HW, C) -> (B, HW, C) + new_output = new_output.to(output.dtype) - loss = output - output_attn_map + new_output = module.to_out[0](new_output) + new_output = module.to_out[1](new_output) + + logger.debug("old: %s\nnew: %s\ndisplacements: %s\ndisplaced centroids: %s", centroids, out_centroids, displ_force, out_centroids - centroids) + + #loss = strength * torch.norm(output - output_attn_map, dim=-1) ** 2 + selfguidance_scale * self_guidance_factor + loss = output - new_output loss = normalize_map(loss) loss **= 2 - + loss *= strength + loss += selfguidance_scale * self_guidance_factor - output += strength * loss - output += selfguidance_scale * self_guidance_factor + output += loss #output = output_attn_map @@ -171,28 +227,39 @@ def tcg_to_q_hook(module, input, kwargs, output): def tcg_to_k_hook(module, input, kwargs, output): setattr(module.tcg_parent_module[0], 'tcg_to_k_map', output) + + def tcg_to_v_hook(module, input, kwargs, output): + setattr(module.tcg_parent_module[0], 'tcg_to_v_map', output) def cfg_denoised_callback(params: script_callbacks.CFGDenoisedParams): - pass + for module in self.get_modules(): + setattr(module, 'tcg_current_step', p.tcg_current_step) + p.tcg_current_step += 1 script_callbacks.on_cfg_denoised(cfg_denoised_callback) - # TODO: Parameterize this - mask_H, mask_W = 64, 64 - temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) - temp_region_mask[0, 1*mask_H//8 : 7*mask_H//8 , 2*mask_W//8 : 5*mask_W//8] = 1.0 # mask the left half ish of the canvas + mask_H, mask_W = image_mask.size + temp_region_mask = torch.from_numpy(np.array(image_mask)).unsqueeze(-1).unsqueeze(0).to(torch.float32).to(shared.device) # (1, H, W, 1) + temp_region_mask = temp_region_mask.repeat(batch_size, 1, 1, 1) # (B, H, W, 1) + # temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) + #temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) + #temp_region_mask[0, 1*mask_H//8 : 7*mask_H//8 , 2*mask_W//8 : 5*mask_W//8] = 1.0 # mask the left half ish of the canvas for module in self.get_modules(): if not module.network_layer_name.endswith('attn2'): continue + module_hooks.modules_add_field(module, 'tcg_current_step', 0) module_hooks.modules_add_field(module, 'tcg_to_q_map', None) module_hooks.modules_add_field(module, 'tcg_to_k_map', None) + module_hooks.modules_add_field(module, 'tcg_to_v_map', None) module_hooks.modules_add_field(module, 'tcg_region_mask', temp_region_mask) module_hooks.modules_add_field(module, 'tcg_token_indices', torch.tensor(token_indices, dtype=torch.int32, device=shared.device)) module_hooks.modules_add_field(module.to_q, 'tcg_parent_module', [module]) module_hooks.modules_add_field(module.to_k, 'tcg_parent_module', [module]) + module_hooks.modules_add_field(module.to_v, 'tcg_parent_module', [module]) module_hooks.module_add_forward_hook(module.to_q, tcg_to_q_hook, with_kwargs=True) module_hooks.module_add_forward_hook(module.to_k, tcg_to_k_hook, with_kwargs=True) + module_hooks.module_add_forward_hook(module.to_v, tcg_to_v_hook, with_kwargs=True) module_hooks.module_add_forward_hook(module, tcg_forward_hook, with_kwargs=True) def postprocess_batch(self, p, *args, **kwargs): @@ -203,13 +270,17 @@ def unhook_callbacks(self) -> None: for module in self.get_modules(): module_hooks.remove_module_forward_hook(module.to_q, 'tcg_to_q_hook') module_hooks.remove_module_forward_hook(module.to_k, 'tcg_to_k_hook') + module_hooks.remove_module_forward_hook(module.to_v, 'tcg_to_v_hook') module_hooks.remove_module_forward_hook(module, 'tcg_forward_hook') + module_hooks.modules_remove_field(module, 'tcg_current_step') module_hooks.modules_remove_field(module, 'tcg_to_q_map') module_hooks.modules_remove_field(module, 'tcg_to_k_map') + module_hooks.modules_remove_field(module, 'tcg_to_v_map') module_hooks.modules_remove_field(module, 'tcg_region_mask') module_hooks.modules_remove_field(module, 'tcg_token_indices') module_hooks.modules_remove_field(module.to_q, 'tcg_parent_module') module_hooks.modules_remove_field(module.to_k, 'tcg_parent_module') + module_hooks.modules_remove_field(module.to_v, 'tcg_parent_module') def before_process(self, p, *args, **kwargs): pass @@ -917,6 +988,29 @@ def _png_batch(attnmap, name, title): _png(copied_map[..., c].unsqueeze(-1), img_idx+ofs+i, f'Displacement Forces Channel {c} Step {i}') #_png(copied_map, img_idx+i, f'Displacement Forces Step {i}') + +def create_mask(left, right, top, bottom, width=256, height=256): + """ Create a PIL.Image mask for the region bounded by the given normalized coordinates + Arguments: + left: float - The left coordinate of the region + right: float - The right coordinate of the region + top: float - The top coordinate of the region + bottom: float - The bottom coordinate of the region + width: int - The width of the mask + height: int - The height of the mask + Returns: + PIL.Image - The mask image + """ + mask = np.zeros((height, width), dtype=np.uint8) + x0, x1 = int(left * width), int(right * width) + y0, y1 = int(top * height), int(bottom * height) + x_min, x_max = min(x0, x1), max(x0, x1) + y_min, y_max = min(y0, y1), max(y0, y1) + + mask[y_min:y_max, x_min:x_max] = 255 + return Image.fromarray(mask.astype(np.uint8)) + + # XYZ Plot # Based on @mcmonkey4eva's XYZ Plot implementation here: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/blob/master/scripts/dynamic_thresholding.py def tcg_apply_override(field, boolean: bool = False): From b99fd181e1bb8bce1e100dba0db04de8f451dbeb Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Tue, 28 May 2024 23:24:34 -0700 Subject: [PATCH 44/55] Revert "try to fix completely wrong dimensions when swapping token output" This reverts commit 36934d8a50bb3b5011a799cbb01ddb2c85cb98f4. --- scripts/tcg.py | 154 ++++++++++--------------------------------------- 1 file changed, 30 insertions(+), 124 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 4b752e7..40b9d2b 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -1,7 +1,5 @@ import os, sys import logging -import numpy as np -from PIL import Image import gradio as gr import torch import random @@ -27,7 +25,6 @@ logging.basicConfig() logger.setLevel(logging.DEBUG) - """ WIP Implementation of https://arxiv.org/abs/2404.11824 Author: v0xie @@ -35,6 +32,7 @@ """ + class TCGExtensionScript(UIWrapper): def __init__(self): self.infotext_fields: list = [] @@ -54,48 +52,23 @@ def title(self) -> str: def setup_ui(self, is_img2img) -> list: with gr.Accordion('TCG', open=True): - with gr.Row(): - image_mask = gr.Image(type='pil', image_mode='L', label="Mask", interactive=False, height = 256) - with gr.Row(): - generate_mask = gr.Button("Generate Mask", elem_id="tcg_btn_generate_mask") - with gr.Row(): - with gr.Column(): - left = gr.Slider(label="Left", value=0.2, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_left") - with gr.Column(): - with gr.Row(): - top = gr.Slider(label="Top", value=0.3, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_top") - with gr.Row(): - bottom = gr.Slider(label="Bottom", value=0.7, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_bottom") - with gr.Column(): - right = gr.Slider(label="Right", value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_right") - with gr.Row(): - active = gr.Checkbox(label="Active", value=True, elem_id="tcg_active") - strength = gr.Slider(label="Strength", value=1.0, minimum=-5.0, maximum=5.0, step=0.1, elem_id="tcg_strength") - f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_margin") - f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_repl") - theta = gr.Slider(label="Conflict Threshold", value=0.01, minimum=0.0, maximum=1.0, step=0.001, elem_id="tcg_theta") - threshold = gr.Slider(label="Soft Threshold", value=0.5, minimum=0.0, maximum=1.0, step=0.01, elem_id="tcg_attn_threshold") - sharpness = gr.Slider(label="Threshold Sharpness", value=10.0, minimum=0.1, maximum=20.0, step=0.1, elem_id="tcg_sharpness") - selfguidance_scale = gr.Slider(label="Self-Guidance Scale", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_selfguidance_scale") - with gr.Row(): - start_step = gr.Slider(label="Start Step", value=0, minimum=0, maximum=100, step=1, elem_id="tcg_start_step") - end_step = gr.Slider(label="End Step", value=0, minimum=1, maximum=100, step=1, elem_id="tcg_end_step") - opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step] + active = gr.Checkbox(label="Active", value=True, elem_id="tcg_active") + strength = gr.Slider(label="Strength", value=1.0, minimum=-5.0, maximum=5.0, step=0.1, elem_id="tcg_strength") + f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_margin") + f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_repl") + theta = gr.Slider(label="Conflict Threshold", value=0.01, minimum=0.0, maximum=1.0, step=0.001, elem_id="tcg_theta") + threshold = gr.Slider(label="Soft Threshold", value=0.5, minimum=0.0, maximum=1.0, step=0.01, elem_id="tcg_attn_threshold") + sharpness = gr.Slider(label="Threshold Sharpness", value=10.0, minimum=0.1, maximum=20.0, step=0.1, elem_id="tcg_sharpness") + selfguidance_scale = gr.Slider(label="Self-Guidance Scale", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_selfguidance_scale") + opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale] for opt in opts: opt.do_not_save_to_config = True - - generate_mask.click( - create_mask, - inputs = [left, right, top, bottom], - outputs = [image_mask] - ) - return opts def get_modules(self): return module_hooks.get_modules( module_name_filter='CrossAttention') - def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step, *args, **kwargs): + def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, *args, **kwargs): self.unhook_callbacks() active = getattr(p, 'tcg_active', active) if not active: @@ -107,40 +80,24 @@ def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, thr threshold= getattr(p, 'tcg_attn_threshold', threshold) sharpness = getattr(p, 'tcg_sharpness', sharpness) selfguidance_scale = getattr(p, 'tcg_selfguidance_scale', selfguidance_scale) - start_step = getattr(p, 'tcg_start_step', start_step) - end_step = getattr(p, 'tcg_end_step', end_step) batch_size = p.batch_size height, width = p.height, p.width hw = height * width - setattr(p, 'tcg_current_step', 0) - token_count, max_length = prompt_utils.get_token_count(p.prompt, p.steps, is_positive=True) min_idx = 1 max_idx = token_count+1 token_indices = list(range(min_idx, max_idx)) def tcg_forward_hook(module, input, kwargs, output): - current_step = module.tcg_current_step - if not start_step <= current_step <= end_step: - return - # calc attn scores q_map = module.tcg_to_q_map k_map = module.tcg_to_k_map - v_map = module.tcg_to_v_map - - batch_size, seq_len, inner_dim = output.shape - heads = module.heads - # select k tokens - # k_map = k_map.transpose(-1, -2)[..., module.tcg_token_indices].transpose(-1,-2) - - # calc attn scores for q and k - out_attn_scores = get_attention_scores(q_map, k_map, dtype=q_map.dtype) # (2*B, H*W, C) - out_attn_scores = out_attn_scores.to(torch.float32) - attn_scores = out_attn_scores[..., module.tcg_token_indices] # (B, H*W, C) + k_map = k_map.transpose(-1, -2)[..., module.tcg_token_indices].transpose(-1,-2) + attn_scores = get_attention_scores(q_map, k_map, dtype=q_map.dtype) # (2*B, H*W, C) + attn_scores = attn_scores.to(torch.float32) B, HW, C = attn_scores.shape downscale_h = round((HW * (height / width)) ** 0.5) @@ -164,8 +121,7 @@ def tcg_forward_hook(module, input, kwargs, output): self_guidance = obj_appearance self_guidance = self_guidance.to(output.dtype) - self_guidance_factor = out_attn_scores.detach().clone() - #self_guidance_factor = output.detach().clone() + self_guidance_factor = output.detach().clone() self_guidance_factor[..., module.tcg_token_indices] = self_guidance attn_map = attn_map.view(attn_map.size(0), *inner_dims, attn_map.size(-1)) @@ -187,38 +143,26 @@ def tcg_forward_hook(module, input, kwargs, output): return centroids = calculate_centroid(attn_map) # (B, C, 2) + #logger.debug(centroids) displ_force = displacement_force(attn_map, centroids, region_mask_centroid, f_repl, f_margin, clamp = 10) # B C 2 # zero out displacement force displ_force = displ_force * conflicts.unsqueeze(-1) - - #displ_force *= strength - #logger.debug("Displacements: %s", displ_force) - - # apply displacements to the attn map - output_attn_map = attn_map.detach().clone() - modified_attn_map, out_centroids = apply_displacements(output_attn_map, centroids, displ_force) + logger.debug("Displacements: %s", displ_force) - # do the rest of attention - output_attn_map = output_attn_map.view(output_attn_map.size(0), -1, output_attn_map.size(-1)) - out_attn_scores[..., module.tcg_token_indices] = output_attn_map - new_output = out_attn_scores @ v_map.to(out_attn_scores.dtype) # (B, HW, C) @ (B, HW, C) -> (B, HW, C) - new_output = new_output.to(output.dtype) + # reassign the attn map + output_attn_map = output.detach().clone() + modified_attn_map, out_centroids = apply_displacements(output_attn_map[..., module.tcg_token_indices].unsqueeze(0), centroids, displ_force) + output_attn_map[..., module.tcg_token_indices] = modified_attn_map.squeeze(0) - new_output = module.to_out[0](new_output) - new_output = module.to_out[1](new_output) - - logger.debug("old: %s\nnew: %s\ndisplacements: %s\ndisplaced centroids: %s", centroids, out_centroids, displ_force, out_centroids - centroids) - - #loss = strength * torch.norm(output - output_attn_map, dim=-1) ** 2 + selfguidance_scale * self_guidance_factor - loss = output - new_output + loss = output - output_attn_map loss = normalize_map(loss) loss **= 2 - loss *= strength - loss += selfguidance_scale * self_guidance_factor + - output += loss + output += strength * loss + output += selfguidance_scale * self_guidance_factor #output = output_attn_map @@ -227,39 +171,28 @@ def tcg_to_q_hook(module, input, kwargs, output): def tcg_to_k_hook(module, input, kwargs, output): setattr(module.tcg_parent_module[0], 'tcg_to_k_map', output) - - def tcg_to_v_hook(module, input, kwargs, output): - setattr(module.tcg_parent_module[0], 'tcg_to_v_map', output) def cfg_denoised_callback(params: script_callbacks.CFGDenoisedParams): - for module in self.get_modules(): - setattr(module, 'tcg_current_step', p.tcg_current_step) - p.tcg_current_step += 1 + pass script_callbacks.on_cfg_denoised(cfg_denoised_callback) - mask_H, mask_W = image_mask.size - temp_region_mask = torch.from_numpy(np.array(image_mask)).unsqueeze(-1).unsqueeze(0).to(torch.float32).to(shared.device) # (1, H, W, 1) - temp_region_mask = temp_region_mask.repeat(batch_size, 1, 1, 1) # (B, H, W, 1) - # temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) - #temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) - #temp_region_mask[0, 1*mask_H//8 : 7*mask_H//8 , 2*mask_W//8 : 5*mask_W//8] = 1.0 # mask the left half ish of the canvas + # TODO: Parameterize this + mask_H, mask_W = 64, 64 + temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) + temp_region_mask[0, 1*mask_H//8 : 7*mask_H//8 , 2*mask_W//8 : 5*mask_W//8] = 1.0 # mask the left half ish of the canvas for module in self.get_modules(): if not module.network_layer_name.endswith('attn2'): continue - module_hooks.modules_add_field(module, 'tcg_current_step', 0) module_hooks.modules_add_field(module, 'tcg_to_q_map', None) module_hooks.modules_add_field(module, 'tcg_to_k_map', None) - module_hooks.modules_add_field(module, 'tcg_to_v_map', None) module_hooks.modules_add_field(module, 'tcg_region_mask', temp_region_mask) module_hooks.modules_add_field(module, 'tcg_token_indices', torch.tensor(token_indices, dtype=torch.int32, device=shared.device)) module_hooks.modules_add_field(module.to_q, 'tcg_parent_module', [module]) module_hooks.modules_add_field(module.to_k, 'tcg_parent_module', [module]) - module_hooks.modules_add_field(module.to_v, 'tcg_parent_module', [module]) module_hooks.module_add_forward_hook(module.to_q, tcg_to_q_hook, with_kwargs=True) module_hooks.module_add_forward_hook(module.to_k, tcg_to_k_hook, with_kwargs=True) - module_hooks.module_add_forward_hook(module.to_v, tcg_to_v_hook, with_kwargs=True) module_hooks.module_add_forward_hook(module, tcg_forward_hook, with_kwargs=True) def postprocess_batch(self, p, *args, **kwargs): @@ -270,17 +203,13 @@ def unhook_callbacks(self) -> None: for module in self.get_modules(): module_hooks.remove_module_forward_hook(module.to_q, 'tcg_to_q_hook') module_hooks.remove_module_forward_hook(module.to_k, 'tcg_to_k_hook') - module_hooks.remove_module_forward_hook(module.to_v, 'tcg_to_v_hook') module_hooks.remove_module_forward_hook(module, 'tcg_forward_hook') - module_hooks.modules_remove_field(module, 'tcg_current_step') module_hooks.modules_remove_field(module, 'tcg_to_q_map') module_hooks.modules_remove_field(module, 'tcg_to_k_map') - module_hooks.modules_remove_field(module, 'tcg_to_v_map') module_hooks.modules_remove_field(module, 'tcg_region_mask') module_hooks.modules_remove_field(module, 'tcg_token_indices') module_hooks.modules_remove_field(module.to_q, 'tcg_parent_module') module_hooks.modules_remove_field(module.to_k, 'tcg_parent_module') - module_hooks.modules_remove_field(module.to_v, 'tcg_parent_module') def before_process(self, p, *args, **kwargs): pass @@ -988,29 +917,6 @@ def _png_batch(attnmap, name, title): _png(copied_map[..., c].unsqueeze(-1), img_idx+ofs+i, f'Displacement Forces Channel {c} Step {i}') #_png(copied_map, img_idx+i, f'Displacement Forces Step {i}') - -def create_mask(left, right, top, bottom, width=256, height=256): - """ Create a PIL.Image mask for the region bounded by the given normalized coordinates - Arguments: - left: float - The left coordinate of the region - right: float - The right coordinate of the region - top: float - The top coordinate of the region - bottom: float - The bottom coordinate of the region - width: int - The width of the mask - height: int - The height of the mask - Returns: - PIL.Image - The mask image - """ - mask = np.zeros((height, width), dtype=np.uint8) - x0, x1 = int(left * width), int(right * width) - y0, y1 = int(top * height), int(bottom * height) - x_min, x_max = min(x0, x1), max(x0, x1) - y_min, y_max = min(y0, y1), max(y0, y1) - - mask[y_min:y_max, x_min:x_max] = 255 - return Image.fromarray(mask.astype(np.uint8)) - - # XYZ Plot # Based on @mcmonkey4eva's XYZ Plot implementation here: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/blob/master/scripts/dynamic_thresholding.py def tcg_apply_override(field, boolean: bool = False): From 498ba271cfb0ef5152471d758a4219b7d7b3b053 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 29 May 2024 17:05:10 -0700 Subject: [PATCH 45/55] try to fix dims again --- scripts/tcg.py | 110 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 88 insertions(+), 22 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 40b9d2b..7de7c4c 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -4,6 +4,7 @@ import torch import random import torch.nn.functional as F +from einops import rearrange if os.environ.get('INCANT_DEBUG', None): # suppress excess logging @@ -92,39 +93,52 @@ def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, thr def tcg_forward_hook(module, input, kwargs, output): # calc attn scores - q_map = module.tcg_to_q_map - k_map = module.tcg_to_k_map + q_map = module.tcg_to_q_map # B, HW, inner_dim + k_map = module.tcg_to_k_map # B, C, inner_dim + v_map = module.tcg_to_v_map # B, C, inner_dim + + # q_map = prepare_attn_map(q_map, module.heads) + # k_map = prepare_attn_map(k_map, module.heads) + # v_map = prepare_attn_map(v_map, module.heads) + + attn_scores = q_map @ k_map.transpose(-1, -2) + attn_scores *= module.scale + #channel_dim = q_map.shape[-1] + #attn_scores /= (channel_dim ** 0.5) + attn_scores = attn_scores.softmax(dim=-1).to(device=shared.device, dtype=q_map.dtype) + # select k tokens - k_map = k_map.transpose(-1, -2)[..., module.tcg_token_indices].transpose(-1,-2) - attn_scores = get_attention_scores(q_map, k_map, dtype=q_map.dtype) # (2*B, H*W, C) + # k_map = k_map.transpose(-1, -2)[..., module.tcg_token_indices].transpose(-1,-2) + #attn_scores = get_attention_scores(q_map, k_map, dtype=q_map.dtype) # (2*B, H*W, C) attn_scores = attn_scores.to(torch.float32) B, HW, C = attn_scores.shape downscale_h = round((HW * (height / width)) ** 0.5) - attn_scores = attn_scores.view(2, attn_scores.size(0)//2, downscale_h, HW//downscale_h, attn_scores.size(-1)).mean(dim=0) # (2*B, HW, C) -> (B, H, W, C) + # attn_scores = attn_scores.view(2, attn_scores.size(0)//2, downscale_h, HW//downscale_h, attn_scores.size(-1)).mean(dim=0) # (2*B, HW, C) -> (B, H, W, C) + attn_scores = attn_scores.view(attn_scores.size(0), downscale_h, HW//downscale_h, attn_scores.size(-1)) # (2*B, HW, C) -> (B, H, W, C) # slice attn map - attn_map = attn_scores.detach().clone() # (B, H, W, K) where K is the subset of tokens + attn_map = attn_scores[..., module.tcg_token_indices].detach().clone() # (B, H, W, K) where K is the subset of tokens # threshold it # also represents object shape - attn_map = soft_threshold(attn_map, threshold=threshold, sharpness=sharpness) + attn_map = soft_threshold(attn_map, threshold=threshold, sharpness=sharpness) # B H W C # self-guidance - inner_dims = attn_map.shape[1:-1] - attn_map = attn_map.view(attn_map.size(0), -1, attn_map.size(-1)) + # inner_dims = attn_map.shape[1:-1] + # attn_map = attn_map.view(attn_map.size(0), -1, attn_map.size(-1)) - shape_sum = torch.sum(attn_map, dim=1) # (B, HW) + # shape_sum = torch.sum(attn_map, dim=1) # (B, HW) - obj_appearance = shape_sum * attn_map - obj_appearance /= shape_sum + # obj_appearance = shape_sum * attn_map + # obj_appearance /= shape_sum - self_guidance = obj_appearance - self_guidance = self_guidance.to(output.dtype) - self_guidance_factor = output.detach().clone() - self_guidance_factor[..., module.tcg_token_indices] = self_guidance + # self_guidance = obj_appearance + # self_guidance = self_guidance.to(output.dtype) + # self_guidance_factor = output.detach().clone() + # self_guidance_factor[..., module.tcg_token_indices] = self_guidance - attn_map = attn_map.view(attn_map.size(0), *inner_dims, attn_map.size(-1)) + # attn_map = attn_map.view(attn_map.size(0), *inner_dims, attn_map.size(-1)) # region mask region_mask = module.tcg_region_mask @@ -151,18 +165,23 @@ def tcg_forward_hook(module, input, kwargs, output): displ_force = displ_force * conflicts.unsqueeze(-1) logger.debug("Displacements: %s", displ_force) - # reassign the attn map - output_attn_map = output.detach().clone() - modified_attn_map, out_centroids = apply_displacements(output_attn_map[..., module.tcg_token_indices].unsqueeze(0), centroids, displ_force) + # modify the attn map + output_attn_map = attn_scores.detach().clone() # B H W C + + modified_attn_map, out_centroids = apply_displacements(output_attn_map[..., module.tcg_token_indices], centroids, displ_force) output_attn_map[..., module.tcg_token_indices] = modified_attn_map.squeeze(0) + output_attn_map = output_attn_map.view(B, -1, C) # B HW C + # output_attn_map = output_attn_map.permute(0, 2, 1) # B C HW + output_attn_map = output_attn_map @ v_map + loss = output - output_attn_map loss = normalize_map(loss) loss **= 2 output += strength * loss - output += selfguidance_scale * self_guidance_factor + # output += selfguidance_scale * self_guidance_factor #output = output_attn_map @@ -172,6 +191,9 @@ def tcg_to_q_hook(module, input, kwargs, output): def tcg_to_k_hook(module, input, kwargs, output): setattr(module.tcg_parent_module[0], 'tcg_to_k_map', output) + def tcg_to_v_hook(module, input, kwargs, output): + setattr(module.tcg_parent_module[0], 'tcg_to_v_map', output) + def cfg_denoised_callback(params: script_callbacks.CFGDenoisedParams): pass @@ -187,12 +209,15 @@ def cfg_denoised_callback(params: script_callbacks.CFGDenoisedParams): continue module_hooks.modules_add_field(module, 'tcg_to_q_map', None) module_hooks.modules_add_field(module, 'tcg_to_k_map', None) + module_hooks.modules_add_field(module, 'tcg_to_v_map', None) module_hooks.modules_add_field(module, 'tcg_region_mask', temp_region_mask) module_hooks.modules_add_field(module, 'tcg_token_indices', torch.tensor(token_indices, dtype=torch.int32, device=shared.device)) module_hooks.modules_add_field(module.to_q, 'tcg_parent_module', [module]) module_hooks.modules_add_field(module.to_k, 'tcg_parent_module', [module]) + module_hooks.modules_add_field(module.to_v, 'tcg_parent_module', [module]) module_hooks.module_add_forward_hook(module.to_q, tcg_to_q_hook, with_kwargs=True) module_hooks.module_add_forward_hook(module.to_k, tcg_to_k_hook, with_kwargs=True) + module_hooks.module_add_forward_hook(module.to_v, tcg_to_v_hook, with_kwargs=True) module_hooks.module_add_forward_hook(module, tcg_forward_hook, with_kwargs=True) def postprocess_batch(self, p, *args, **kwargs): @@ -203,13 +228,16 @@ def unhook_callbacks(self) -> None: for module in self.get_modules(): module_hooks.remove_module_forward_hook(module.to_q, 'tcg_to_q_hook') module_hooks.remove_module_forward_hook(module.to_k, 'tcg_to_k_hook') + module_hooks.remove_module_forward_hook(module.to_v, 'tcg_to_v_hook') module_hooks.remove_module_forward_hook(module, 'tcg_forward_hook') module_hooks.modules_remove_field(module, 'tcg_to_q_map') module_hooks.modules_remove_field(module, 'tcg_to_k_map') + module_hooks.modules_remove_field(module, 'tcg_to_v_map') module_hooks.modules_remove_field(module, 'tcg_region_mask') module_hooks.modules_remove_field(module, 'tcg_token_indices') module_hooks.modules_remove_field(module.to_q, 'tcg_parent_module') module_hooks.modules_remove_field(module.to_k, 'tcg_parent_module') + module_hooks.modules_remove_field(module.to_v, 'tcg_parent_module') def before_process(self, p, *args, **kwargs): pass @@ -931,4 +959,42 @@ def fun(p, x, xs): if not hasattr(p, "tcg_active"): p.tcg_active = True setattr(p, field, x) - return fun \ No newline at end of file + return fun + + +def prepare_attn_map(to_k_map, heads): + to_k_map = head_to_batch_dim(to_k_map, heads) + to_k_map = average_over_head_dim(to_k_map, heads) + to_k_map = torch.stack([to_k_map[0], to_k_map[0]], dim=0) + return to_k_map + + +# based on diffusers/models/attention_processor.py Attention head_to_batch_dim +def head_to_batch_dim(x, heads, out_dim=3): + head_size = heads + if x.ndim == 3: + + batch_size, seq_len, dim = x.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) + x = x.permute(0, 2, 1, 3) + if out_dim == 3: + x = x.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) + return x + + +# based on diffusers/models/attention_processor.py Attention batch_to_head_dim +def batch_to_head_dim(x, heads): + head_size = heads + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // head_size, head_size, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return x + + +def average_over_head_dim(x, heads): + x = rearrange(x, '(b h) s t -> b h s t', h=heads).mean(1) + return x + From 5f69577177a3f6a524bde54c2819989ee3dc3165 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 29 May 2024 17:13:29 -0700 Subject: [PATCH 46/55] restore image mask element --- scripts/tcg.py | 98 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 81 insertions(+), 17 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 7de7c4c..9277ff1 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -1,5 +1,7 @@ import os, sys import logging +import numpy as np +from PIL import Image import gradio as gr import torch import random @@ -26,6 +28,7 @@ logging.basicConfig() logger.setLevel(logging.DEBUG) + """ WIP Implementation of https://arxiv.org/abs/2404.11824 Author: v0xie @@ -53,23 +56,48 @@ def title(self) -> str: def setup_ui(self, is_img2img) -> list: with gr.Accordion('TCG', open=True): - active = gr.Checkbox(label="Active", value=True, elem_id="tcg_active") - strength = gr.Slider(label="Strength", value=1.0, minimum=-5.0, maximum=5.0, step=0.1, elem_id="tcg_strength") - f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_margin") - f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_repl") - theta = gr.Slider(label="Conflict Threshold", value=0.01, minimum=0.0, maximum=1.0, step=0.001, elem_id="tcg_theta") - threshold = gr.Slider(label="Soft Threshold", value=0.5, minimum=0.0, maximum=1.0, step=0.01, elem_id="tcg_attn_threshold") - sharpness = gr.Slider(label="Threshold Sharpness", value=10.0, minimum=0.1, maximum=20.0, step=0.1, elem_id="tcg_sharpness") - selfguidance_scale = gr.Slider(label="Self-Guidance Scale", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_selfguidance_scale") - opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale] + with gr.Row(): + image_mask = gr.Image(type='pil', image_mode='L', label="Mask", interactive=False, height = 256) + with gr.Row(): + generate_mask = gr.Button("Generate Mask", elem_id="tcg_btn_generate_mask") + with gr.Row(): + with gr.Column(): + left = gr.Slider(label="Left", value=0.2, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_left") + with gr.Column(): + with gr.Row(): + top = gr.Slider(label="Top", value=0.3, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_top") + with gr.Row(): + bottom = gr.Slider(label="Bottom", value=0.7, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_bottom") + with gr.Column(): + right = gr.Slider(label="Right", value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="tcg_mask_right") + with gr.Row(): + active = gr.Checkbox(label="Active", value=True, elem_id="tcg_active") + strength = gr.Slider(label="Strength", value=1.0, minimum=-5.0, maximum=5.0, step=0.1, elem_id="tcg_strength") + f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_margin") + f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_repl") + theta = gr.Slider(label="Conflict Threshold", value=0.01, minimum=0.0, maximum=1.0, step=0.001, elem_id="tcg_theta") + threshold = gr.Slider(label="Soft Threshold", value=0.5, minimum=0.0, maximum=1.0, step=0.01, elem_id="tcg_attn_threshold") + sharpness = gr.Slider(label="Threshold Sharpness", value=10.0, minimum=0.1, maximum=20.0, step=0.1, elem_id="tcg_sharpness") + selfguidance_scale = gr.Slider(label="Self-Guidance Scale", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_selfguidance_scale") + with gr.Row(): + start_step = gr.Slider(label="Start Step", value=0, minimum=0, maximum=100, step=1, elem_id="tcg_start_step") + end_step = gr.Slider(label="End Step", value=0, minimum=1, maximum=100, step=1, elem_id="tcg_end_step") + opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step] for opt in opts: opt.do_not_save_to_config = True + + generate_mask.click( + create_mask, + inputs = [left, right, top, bottom], + outputs = [image_mask] + ) + return opts def get_modules(self): return module_hooks.get_modules( module_name_filter='CrossAttention') - def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, *args, **kwargs): + def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step, *args, **kwargs): self.unhook_callbacks() active = getattr(p, 'tcg_active', active) if not active: @@ -81,17 +109,24 @@ def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, thr threshold= getattr(p, 'tcg_attn_threshold', threshold) sharpness = getattr(p, 'tcg_sharpness', sharpness) selfguidance_scale = getattr(p, 'tcg_selfguidance_scale', selfguidance_scale) + start_step = getattr(p, 'tcg_start_step', start_step) + end_step = getattr(p, 'tcg_end_step', end_step) batch_size = p.batch_size height, width = p.height, p.width hw = height * width + setattr(p, 'tcg_current_step', 0) + token_count, max_length = prompt_utils.get_token_count(p.prompt, p.steps, is_positive=True) min_idx = 1 max_idx = token_count+1 token_indices = list(range(min_idx, max_idx)) def tcg_forward_hook(module, input, kwargs, output): + current_step = module.tcg_current_step + if not start_step <= current_step <= end_step: + return # calc attn scores q_map = module.tcg_to_q_map # B, HW, inner_dim k_map = module.tcg_to_k_map # B, C, inner_dim @@ -190,23 +225,28 @@ def tcg_to_q_hook(module, input, kwargs, output): def tcg_to_k_hook(module, input, kwargs, output): setattr(module.tcg_parent_module[0], 'tcg_to_k_map', output) - + def tcg_to_v_hook(module, input, kwargs, output): setattr(module.tcg_parent_module[0], 'tcg_to_v_map', output) - + def cfg_denoised_callback(params: script_callbacks.CFGDenoisedParams): - pass + for module in self.get_modules(): + setattr(module, 'tcg_current_step', p.tcg_current_step) + p.tcg_current_step += 1 script_callbacks.on_cfg_denoised(cfg_denoised_callback) - # TODO: Parameterize this - mask_H, mask_W = 64, 64 - temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) - temp_region_mask[0, 1*mask_H//8 : 7*mask_H//8 , 2*mask_W//8 : 5*mask_W//8] = 1.0 # mask the left half ish of the canvas + mask_H, mask_W = image_mask.size + temp_region_mask = torch.from_numpy(np.array(image_mask)).unsqueeze(-1).unsqueeze(0).to(torch.float32).to(shared.device) # (1, H, W, 1) + temp_region_mask = temp_region_mask.repeat(batch_size, 1, 1, 1) # (B, H, W, 1) + # temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) + #temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) + #temp_region_mask[0, 1*mask_H//8 : 7*mask_H//8 , 2*mask_W//8 : 5*mask_W//8] = 1.0 # mask the left half ish of the canvas for module in self.get_modules(): if not module.network_layer_name.endswith('attn2'): continue + module_hooks.modules_add_field(module, 'tcg_current_step', 0) module_hooks.modules_add_field(module, 'tcg_to_q_map', None) module_hooks.modules_add_field(module, 'tcg_to_k_map', None) module_hooks.modules_add_field(module, 'tcg_to_v_map', None) @@ -230,6 +270,7 @@ def unhook_callbacks(self) -> None: module_hooks.remove_module_forward_hook(module.to_k, 'tcg_to_k_hook') module_hooks.remove_module_forward_hook(module.to_v, 'tcg_to_v_hook') module_hooks.remove_module_forward_hook(module, 'tcg_forward_hook') + module_hooks.modules_remove_field(module, 'tcg_current_step') module_hooks.modules_remove_field(module, 'tcg_to_q_map') module_hooks.modules_remove_field(module, 'tcg_to_k_map') module_hooks.modules_remove_field(module, 'tcg_to_v_map') @@ -945,6 +986,29 @@ def _png_batch(attnmap, name, title): _png(copied_map[..., c].unsqueeze(-1), img_idx+ofs+i, f'Displacement Forces Channel {c} Step {i}') #_png(copied_map, img_idx+i, f'Displacement Forces Step {i}') + +def create_mask(left, right, top, bottom, width=256, height=256): + """ Create a PIL.Image mask for the region bounded by the given normalized coordinates + Arguments: + left: float - The left coordinate of the region + right: float - The right coordinate of the region + top: float - The top coordinate of the region + bottom: float - The bottom coordinate of the region + width: int - The width of the mask + height: int - The height of the mask + Returns: + PIL.Image - The mask image + """ + mask = np.zeros((height, width), dtype=np.uint8) + x0, x1 = int(left * width), int(right * width) + y0, y1 = int(top * height), int(bottom * height) + x_min, x_max = min(x0, x1), max(x0, x1) + y_min, y_max = min(y0, y1), max(y0, y1) + + mask[y_min:y_max, x_min:x_max] = 255 + return Image.fromarray(mask.astype(np.uint8)) + + # XYZ Plot # Based on @mcmonkey4eva's XYZ Plot implementation here: https://github.com/mcmonkeyprojects/sd-dynamic-thresholding/blob/master/scripts/dynamic_thresholding.py def tcg_apply_override(field, boolean: bool = False): From 408fc994eeee2bbcb0a4653fe10911e8f22315e1 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 29 May 2024 18:17:22 -0700 Subject: [PATCH 47/55] attempt to fix dims --- scripts/tcg.py | 64 +++++++++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 9277ff1..cd74669 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -79,10 +79,11 @@ def setup_ui(self, is_img2img) -> list: threshold = gr.Slider(label="Soft Threshold", value=0.5, minimum=0.0, maximum=1.0, step=0.01, elem_id="tcg_attn_threshold") sharpness = gr.Slider(label="Threshold Sharpness", value=10.0, minimum=0.1, maximum=20.0, step=0.1, elem_id="tcg_sharpness") selfguidance_scale = gr.Slider(label="Self-Guidance Scale", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_selfguidance_scale") + clamp = gr.Slider(label="Force Clamp", value=1.0, minimum=0.0, maximum=20.0, step=1.0, elem_id="tcg_clamp") with gr.Row(): start_step = gr.Slider(label="Start Step", value=0, minimum=0, maximum=100, step=1, elem_id="tcg_start_step") end_step = gr.Slider(label="End Step", value=0, minimum=1, maximum=100, step=1, elem_id="tcg_end_step") - opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step] + opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step, clamp] for opt in opts: opt.do_not_save_to_config = True @@ -97,7 +98,7 @@ def setup_ui(self, is_img2img) -> list: def get_modules(self): return module_hooks.get_modules( module_name_filter='CrossAttention') - def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step, *args, **kwargs): + def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step, clamp, *args, **kwargs): self.unhook_callbacks() active = getattr(p, 'tcg_active', active) if not active: @@ -111,6 +112,7 @@ def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, thr selfguidance_scale = getattr(p, 'tcg_selfguidance_scale', selfguidance_scale) start_step = getattr(p, 'tcg_start_step', start_step) end_step = getattr(p, 'tcg_end_step', end_step) + clamp = getattr(p, 'tcg_clamp', clamp) batch_size = p.batch_size height, width = p.height, p.width @@ -132,14 +134,10 @@ def tcg_forward_hook(module, input, kwargs, output): k_map = module.tcg_to_k_map # B, C, inner_dim v_map = module.tcg_to_v_map # B, C, inner_dim - # q_map = prepare_attn_map(q_map, module.heads) - # k_map = prepare_attn_map(k_map, module.heads) - # v_map = prepare_attn_map(v_map, module.heads) - attn_scores = q_map @ k_map.transpose(-1, -2) - attn_scores *= module.scale - #channel_dim = q_map.shape[-1] - #attn_scores /= (channel_dim ** 0.5) + #attn_scores *= module.scale + channel_dim = q_map.shape[-1] + attn_scores /= (channel_dim ** 0.5) attn_scores = attn_scores.softmax(dim=-1).to(device=shared.device, dtype=q_map.dtype) # select k tokens @@ -160,20 +158,21 @@ def tcg_forward_hook(module, input, kwargs, output): attn_map = soft_threshold(attn_map, threshold=threshold, sharpness=sharpness) # B H W C # self-guidance - # inner_dims = attn_map.shape[1:-1] - # attn_map = attn_map.view(attn_map.size(0), -1, attn_map.size(-1)) + inner_dims = attn_map.shape[1:-1] + attn_map = attn_map.view(attn_map.size(0), -1, attn_map.size(-1)) + + shape_sum = torch.sum(attn_map, dim=1, keepdim=True) # (B, HW) - # shape_sum = torch.sum(attn_map, dim=1) # (B, HW) + obj_appearance = shape_sum * attn_map + obj_appearance /= shape_sum + torch.finfo(torch.float32).eps - # obj_appearance = shape_sum * attn_map - # obj_appearance /= shape_sum + self_guidance = obj_appearance - # self_guidance = obj_appearance - # self_guidance = self_guidance.to(output.dtype) - # self_guidance_factor = output.detach().clone() - # self_guidance_factor[..., module.tcg_token_indices] = self_guidance + self_guidance_factor = attn_scores.detach().clone() + self_guidance_factor = self_guidance_factor.view(B, -1, C) + self_guidance_factor[..., module.tcg_token_indices] = self_guidance - # attn_map = attn_map.view(attn_map.size(0), *inner_dims, attn_map.size(-1)) + attn_map = attn_map.view(attn_map.size(0), *inner_dims, attn_map.size(-1)) # region mask region_mask = module.tcg_region_mask @@ -194,29 +193,39 @@ def tcg_forward_hook(module, input, kwargs, output): centroids = calculate_centroid(attn_map) # (B, C, 2) #logger.debug(centroids) - displ_force = displacement_force(attn_map, centroids, region_mask_centroid, f_repl, f_margin, clamp = 10) # B C 2 + displ_force = displacement_force(attn_map, centroids, region_mask_centroid, f_repl, f_margin, clamp = clamp) # B C 2 # zero out displacement force displ_force = displ_force * conflicts.unsqueeze(-1) logger.debug("Displacements: %s", displ_force) - # modify the attn map + # apply displacements output_attn_map = attn_scores.detach().clone() # B H W C + modified_attn_map, out_centroids = apply_displacements( + output_attn_map[..., module.tcg_token_indices], + #output_attn_map[..., module.tcg_token_indices], + centroids, + displ_force + ) # B H W C - modified_attn_map, out_centroids = apply_displacements(output_attn_map[..., module.tcg_token_indices], centroids, displ_force) - output_attn_map[..., module.tcg_token_indices] = modified_attn_map.squeeze(0) + # zero out area in region for all conflict tokens + modified_attn_map = modified_attn_map * (1 - region_mask) + output_attn_map[..., module.tcg_token_indices] = modified_attn_map.squeeze(0) output_attn_map = output_attn_map.view(B, -1, C) # B HW C + # output_attn_map = output_attn_map.permute(0, 2, 1) # B C HW + orig_attn_map = attn_scores.view(B, -1, C) # B HW C + orig_attn_map = orig_attn_map @ v_map + output_attn_map = output_attn_map @ v_map + self_guidance_factor = self_guidance_factor @ v_map loss = output - output_attn_map - loss = normalize_map(loss) - loss **= 2 - + #loss **= 2 output += strength * loss - # output += selfguidance_scale * self_guidance_factor + output += selfguidance_scale * self_guidance_factor #output = output_attn_map @@ -238,6 +247,7 @@ def cfg_denoised_callback(params: script_callbacks.CFGDenoisedParams): mask_H, mask_W = image_mask.size temp_region_mask = torch.from_numpy(np.array(image_mask)).unsqueeze(-1).unsqueeze(0).to(torch.float32).to(shared.device) # (1, H, W, 1) + temp_region_mask /= 255.0 temp_region_mask = temp_region_mask.repeat(batch_size, 1, 1, 1) # (B, H, W, 1) # temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) #temp_region_mask = torch.zeros((batch_size, mask_H, mask_W, 1), dtype=torch.float32, device=shared.device) # (B, H, W) From de93eb43e8b581d6db0bc09913d8156aa11f30a8 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 30 May 2024 13:03:39 -0700 Subject: [PATCH 48/55] reimplement sdp for sanity check --- scripts/tcg.py | 84 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 56 insertions(+), 28 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index cd74669..33bc6cc 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -82,7 +82,7 @@ def setup_ui(self, is_img2img) -> list: clamp = gr.Slider(label="Force Clamp", value=1.0, minimum=0.0, maximum=20.0, step=1.0, elem_id="tcg_clamp") with gr.Row(): start_step = gr.Slider(label="Start Step", value=0, minimum=0, maximum=100, step=1, elem_id="tcg_start_step") - end_step = gr.Slider(label="End Step", value=0, minimum=1, maximum=100, step=1, elem_id="tcg_end_step") + end_step = gr.Slider(label="End Step", value=100, minimum=1, maximum=100, step=1, elem_id="tcg_end_step") opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step, clamp] for opt in opts: opt.do_not_save_to_config = True @@ -129,16 +129,36 @@ def tcg_forward_hook(module, input, kwargs, output): current_step = module.tcg_current_step if not start_step <= current_step <= end_step: return + if not len(token_indices) > 0: + return + + heads = module.heads # calc attn scores - q_map = module.tcg_to_q_map # B, HW, inner_dim - k_map = module.tcg_to_k_map # B, C, inner_dim - v_map = module.tcg_to_v_map # B, C, inner_dim + q_map = head_to_batch_dim(module.tcg_to_q_map, heads) # B, heads, HW, inner_dim + k_map = head_to_batch_dim(module.tcg_to_k_map, heads) # B, heads, C, inner_dim + v_map = head_to_batch_dim(module.tcg_to_v_map, heads) # B, heads, C, inner_dim + batch_size, _, seq_len, inner_dim = q_map.shape + # attention attn_scores = q_map @ k_map.transpose(-1, -2) - #attn_scores *= module.scale - channel_dim = q_map.shape[-1] - attn_scores /= (channel_dim ** 0.5) - attn_scores = attn_scores.softmax(dim=-1).to(device=shared.device, dtype=q_map.dtype) + + # attn_scores *= module.scale # same as dividing by channel dim + #channel_dim = q_map.shape[-1] + #attn_scores /= (channel_dim ** 0.5) + + hidden_states = torch.nn.functional.scaled_dot_product_attention( + q_map, k_map, v_map, attn_mask=None, dropout_p=0.0, is_causal=False + ) + # attn_scores = attn_scores.softmax(dim=-1).to(device=shared.device, dtype=q_map.dtype) + #orig_attn_map = attn_scores @ v_map + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * inner_dim) + #hidden_states = batch_to_head_dim(attn_scores, heads) # B, HW, C + + hidden_states = module.to_out[0](hidden_states) + # dropout + hidden_states = module.to_out[1](hidden_states) + + return hidden_states # select k tokens # k_map = k_map.transpose(-1, -2)[..., module.tcg_token_indices].transpose(-1,-2) @@ -221,8 +241,12 @@ def tcg_forward_hook(module, input, kwargs, output): output_attn_map = output_attn_map @ v_map self_guidance_factor = self_guidance_factor @ v_map - loss = output - output_attn_map - #loss **= 2 + loss = output_attn_map + #loss = output - output_attn_map + return loss + + loss = normalize_map(loss) + loss **= 2 output += strength * loss output += selfguidance_scale * self_guidance_factor @@ -1044,28 +1068,32 @@ def prepare_attn_map(to_k_map, heads): # based on diffusers/models/attention_processor.py Attention head_to_batch_dim -def head_to_batch_dim(x, heads, out_dim=3): - head_size = heads - if x.ndim == 3: - - batch_size, seq_len, dim = x.shape - extra_dim = 1 - else: - batch_size, extra_dim, seq_len, dim = x.shape - x = x.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size) - x = x.permute(0, 2, 1, 3) - if out_dim == 3: - x = x.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size) - return x +def head_to_batch_dim(x, heads): + """ Rearrange the tensor to have the head dimension as the batch dimension + Arguments: + x: torch.Tensor - The tensor to rearrange (B, HW, D), where D is heads * dim + heads: int - The number of heads + Returns: + torch.Tensor - The rearranged tensor (B, heads, HW, D // heads) + """ + B, HW, D = x.shape + x = x.view(B, HW, heads, D // heads).transpose(1, 2) + return x # based on diffusers/models/attention_processor.py Attention batch_to_head_dim def batch_to_head_dim(x, heads): - head_size = heads - batch_size, seq_len, dim = x.shape - x = x.reshape(batch_size // head_size, head_size, seq_len, dim) - x = x.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return x + """ Rearrange the tensor to have the batch dimension as the head dimension + Arguments: + x: torch.Tensor - The tensor to rearrange (B, heads, HW, D // heads) + heads: int - The number of heads + Returns: + torch.Tensor - The rearranged tensor (B, HW, D) + """ + B, head_dim, HW, D = x.shape + x = x.transpose(1, 2) # (B, HW, heads, D // heads) + x = x.reshape(B, HW, head_dim * D) + return x def average_over_head_dim(x, heads): From 1b3366845433802f03d7fdc453ab284c3799ae91 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 30 May 2024 13:10:14 -0700 Subject: [PATCH 49/55] usesoftmax --- scripts/tcg.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 33bc6cc..ae8c915 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -132,30 +132,25 @@ def tcg_forward_hook(module, input, kwargs, output): if not len(token_indices) > 0: return + # sdp attention from sd_hijack_optimizations.py heads = module.heads - # calc attn scores q_map = head_to_batch_dim(module.tcg_to_q_map, heads) # B, heads, HW, inner_dim k_map = head_to_batch_dim(module.tcg_to_k_map, heads) # B, heads, C, inner_dim v_map = head_to_batch_dim(module.tcg_to_v_map, heads) # B, heads, C, inner_dim batch_size, _, seq_len, inner_dim = q_map.shape - - # attention - attn_scores = q_map @ k_map.transpose(-1, -2) - - # attn_scores *= module.scale # same as dividing by channel dim - #channel_dim = q_map.shape[-1] - #attn_scores /= (channel_dim ** 0.5) - - hidden_states = torch.nn.functional.scaled_dot_product_attention( - q_map, k_map, v_map, attn_mask=None, dropout_p=0.0, is_causal=False - ) - # attn_scores = attn_scores.softmax(dim=-1).to(device=shared.device, dtype=q_map.dtype) - #orig_attn_map = attn_scores @ v_map + # sdp from pytorch + L, S = q_map.size(-2), k_map.size(-2) + scale_factor = module.scale + #attn_bias = torch.zeros(L, S, dtype=q_map.dtype) + attn_scores = q_map @ k_map.transpose(-1, -2) * scale_factor + #attn_scores += attn_bias + attn_scores = torch.softmax(attn_scores, dim=-1) + attn_scores @= v_map + hidden_states = attn_scores + + # to output map hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * inner_dim) - #hidden_states = batch_to_head_dim(attn_scores, heads) # B, HW, C - hidden_states = module.to_out[0](hidden_states) - # dropout hidden_states = module.to_out[1](hidden_states) return hidden_states From 583b31a7791716c249b7d860a5a3c9c4a7617ad8 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 30 May 2024 13:22:59 -0700 Subject: [PATCH 50/55] fixed soft threshold normalize over wrong dim --- scripts/tcg.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index ae8c915..14ca252 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -141,36 +141,38 @@ def tcg_forward_hook(module, input, kwargs, output): # sdp from pytorch L, S = q_map.size(-2), k_map.size(-2) scale_factor = module.scale - #attn_bias = torch.zeros(L, S, dtype=q_map.dtype) attn_scores = q_map @ k_map.transpose(-1, -2) * scale_factor - #attn_scores += attn_bias attn_scores = torch.softmax(attn_scores, dim=-1) - attn_scores @= v_map - hidden_states = attn_scores - # to output map - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * inner_dim) - hidden_states = module.to_out[0](hidden_states) - hidden_states = module.to_out[1](hidden_states) - return hidden_states + #return hidden_states # select k tokens # k_map = k_map.transpose(-1, -2)[..., module.tcg_token_indices].transpose(-1,-2) #attn_scores = get_attention_scores(q_map, k_map, dtype=q_map.dtype) # (2*B, H*W, C) - attn_scores = attn_scores.to(torch.float32) - B, HW, C = attn_scores.shape + # attn_scores = attn_scores.to(torch.float32) + B, _, HW, C = attn_scores.shape - downscale_h = round((HW * (height / width)) ** 0.5) # attn_scores = attn_scores.view(2, attn_scores.size(0)//2, downscale_h, HW//downscale_h, attn_scores.size(-1)).mean(dim=0) # (2*B, HW, C) -> (B, H, W, C) - attn_scores = attn_scores.view(attn_scores.size(0), downscale_h, HW//downscale_h, attn_scores.size(-1)) # (2*B, HW, C) -> (B, H, W, C) # slice attn map - attn_map = attn_scores[..., module.tcg_token_indices].detach().clone() # (B, H, W, K) where K is the subset of tokens + #attn_map = attn_scores[..., module.tcg_token_indices].detach().clone() # (B, H, W, K) where K is the subset of tokens # threshold it # also represents object shape - attn_map = soft_threshold(attn_map, threshold=threshold, sharpness=sharpness) # B H W C + downscale_h = round((HW * (height / width)) ** 0.5) + attn_scores = attn_scores.view(attn_scores.size(0) * heads, downscale_h, HW//downscale_h, attn_scores.size(-1)) # (B, heads, HW, C) -> (B*heads, H, W, C) + attn_map = soft_threshold(attn_scores, threshold=threshold, sharpness=sharpness) # B H W C + + attn_map = attn_map.view(B, heads, HW, attn_map.size(-1)) # (B, heads, H, W, C) -> (B,heads, HW, C) + + # to output map + hidden_states = attn_map @ v_map + #hidden_states = attn_scores @ v_map + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * inner_dim) + hidden_states = module.to_out[0](hidden_states) + hidden_states = module.to_out[1](hidden_states) + return hidden_states # self-guidance inner_dims = attn_map.shape[1:-1] @@ -182,11 +184,9 @@ def tcg_forward_hook(module, input, kwargs, output): obj_appearance /= shape_sum + torch.finfo(torch.float32).eps self_guidance = obj_appearance - self_guidance_factor = attn_scores.detach().clone() self_guidance_factor = self_guidance_factor.view(B, -1, C) self_guidance_factor[..., module.tcg_token_indices] = self_guidance - attn_map = attn_map.view(attn_map.size(0), *inner_dims, attn_map.size(-1)) # region mask @@ -394,8 +394,8 @@ def _normalize_map(attnmap): """ B, H, W, C = attnmap.shape flattened_attnmap = attnmap.view(attnmap.shape[0], H*W, attnmap.shape[-1]).transpose(-1, -2) # B, C, H*W - min_val = torch.min(flattened_attnmap, dim=-1).values.unsqueeze(-1) # (B, C, 1) - max_val = torch.max(flattened_attnmap, dim=-1).values.unsqueeze(-1) # (B, C, 1) + min_val = torch.min(flattened_attnmap, dim=1).values.unsqueeze(1) # (B, 1, C) + max_val = torch.max(flattened_attnmap, dim=1).values.unsqueeze(1) # (B, 1, C) normalized_attn = (flattened_attnmap - min_val) / ((max_val - min_val) + torch.finfo(attnmap.dtype).eps) normalized_attn = normalized_attn.view(B, C, H*W).transpose(-1, -2) # B, H*W, C normalized_attn = normalized_attn.view(B, H, W, C) From a24f8162e006781683c6ac6dcd2ced5efc6d5799 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 30 May 2024 13:42:03 -0700 Subject: [PATCH 51/55] fix soft threshodl again? --- scripts/tcg.py | 47 ++++++++++++++++++++++------------------------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 14ca252..8af2c3c 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -127,6 +127,7 @@ def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, thr def tcg_forward_hook(module, input, kwargs, output): current_step = module.tcg_current_step + tokens = module.tcg_token_indices if not start_step <= current_step <= end_step: return if not len(token_indices) > 0: @@ -139,35 +140,31 @@ def tcg_forward_hook(module, input, kwargs, output): v_map = head_to_batch_dim(module.tcg_to_v_map, heads) # B, heads, C, inner_dim batch_size, _, seq_len, inner_dim = q_map.shape # sdp from pytorch - L, S = q_map.size(-2), k_map.size(-2) scale_factor = module.scale attn_scores = q_map @ k_map.transpose(-1, -2) * scale_factor - attn_scores = torch.softmax(attn_scores, dim=-1) + attn_scores = torch.softmax(attn_scores, dim=-1) # (B, heads, HW, C) + # dims + B, heads, HW, C = attn_scores.shape + downscale_h = round((HW * (height / width)) ** 0.5) + H, W = downscale_h, HW // downscale_h + K = len(tokens) # number of tokens - #return hidden_states - - # select k tokens - # k_map = k_map.transpose(-1, -2)[..., module.tcg_token_indices].transpose(-1,-2) - #attn_scores = get_attention_scores(q_map, k_map, dtype=q_map.dtype) # (2*B, H*W, C) - # attn_scores = attn_scores.to(torch.float32) - B, _, HW, C = attn_scores.shape - - # attn_scores = attn_scores.view(2, attn_scores.size(0)//2, downscale_h, HW//downscale_h, attn_scores.size(-1)).mean(dim=0) # (2*B, HW, C) -> (B, H, W, C) - - # slice attn map - #attn_map = attn_scores[..., module.tcg_token_indices].detach().clone() # (B, H, W, K) where K is the subset of tokens + # reshape + attn_scores = attn_scores.view(B * heads, H, W, C) # (B, heads, HW, C) -> (B*heads, H, W, C) - # threshold it - # also represents object shape - downscale_h = round((HW * (height / width)) ** 0.5) - attn_scores = attn_scores.view(attn_scores.size(0) * heads, downscale_h, HW//downscale_h, attn_scores.size(-1)) # (B, heads, HW, C) -> (B*heads, H, W, C) - attn_map = soft_threshold(attn_scores, threshold=threshold, sharpness=sharpness) # B H W C + # slice attn_scores into attn_map to operate on target tokens + attn_map = attn_scores[..., tokens].detach().clone() # (..., K) where K is the subset of tokens - attn_map = attn_map.view(B, heads, HW, attn_map.size(-1)) # (B, heads, H, W, C) -> (B,heads, HW, C) + # threshold, also represents object shape + attn_map = soft_threshold(attn_map, threshold=threshold, sharpness=sharpness) # B H W C # to output map - hidden_states = attn_map @ v_map + attn_map = attn_map.view(B*heads, H, W, K) # (B, heads, H, W, C) -> (B,heads, HW, C) + attn_scores[..., tokens] = attn_map + + attn_scores = attn_scores.view(B, heads, H*W, C) # (B, heads, HW, C) -> (B*heads, HW, C) + hidden_states = attn_scores @ v_map #hidden_states = attn_scores @ v_map hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * inner_dim) hidden_states = module.to_out[0](hidden_states) @@ -388,17 +385,17 @@ def soft_threshold(attention_map, threshold=0.5, sharpness=10): def _normalize_map(attnmap): """ Normalize the attention map over the channel dimension Arguments: - attnmap: torch.Tensor - The attention map to normalize. Shape: (B, H, W, C) or (B, HW, C) + attnmap: torch.Tensor - The attention map to normalize. Shape: (B, H, W, C) Returns: torch.Tensor - The attention map normalized to (0, 1). Shape: (B, H, W, C) """ B, H, W, C = attnmap.shape - flattened_attnmap = attnmap.view(attnmap.shape[0], H*W, attnmap.shape[-1]).transpose(-1, -2) # B, C, H*W + flattened_attnmap = attnmap.view(attnmap.shape[0], H*W, attnmap.shape[-1]) # B, H*W, C min_val = torch.min(flattened_attnmap, dim=1).values.unsqueeze(1) # (B, 1, C) max_val = torch.max(flattened_attnmap, dim=1).values.unsqueeze(1) # (B, 1, C) normalized_attn = (flattened_attnmap - min_val) / ((max_val - min_val) + torch.finfo(attnmap.dtype).eps) - normalized_attn = normalized_attn.view(B, C, H*W).transpose(-1, -2) # B, H*W, C - normalized_attn = normalized_attn.view(B, H, W, C) + normalized_attn = normalized_attn.view(B, H, W, C) # B, H*W, C + #normalized_attn = normalized_attn.view(B, H, W, C) return normalized_attn threshold = max(0.0, min(1.0, threshold)) normalized_attn = _normalize_map(attention_map) From 2d4bf1dcc27183949fa80f8d7979806ce2b029e9 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 30 May 2024 14:34:38 -0700 Subject: [PATCH 52/55] rework loss term --- scripts/tcg.py | 100 +++++++++++++++++++++++++++---------------------- 1 file changed, 56 insertions(+), 44 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 8af2c3c..9fa6a59 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -156,35 +156,23 @@ def tcg_forward_hook(module, input, kwargs, output): # slice attn_scores into attn_map to operate on target tokens attn_map = attn_scores[..., tokens].detach().clone() # (..., K) where K is the subset of tokens - # threshold, also represents object shape + # threshold, also represents object shape in self-guidance attn_map = soft_threshold(attn_map, threshold=threshold, sharpness=sharpness) # B H W C - # to output map - attn_map = attn_map.view(B*heads, H, W, K) # (B, heads, H, W, C) -> (B,heads, HW, C) - attn_scores[..., tokens] = attn_map - - attn_scores = attn_scores.view(B, heads, H*W, C) # (B, heads, HW, C) -> (B*heads, HW, C) - hidden_states = attn_scores @ v_map - #hidden_states = attn_scores @ v_map - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * inner_dim) - hidden_states = module.to_out[0](hidden_states) - hidden_states = module.to_out[1](hidden_states) - return hidden_states - - # self-guidance - inner_dims = attn_map.shape[1:-1] - attn_map = attn_map.view(attn_map.size(0), -1, attn_map.size(-1)) - - shape_sum = torch.sum(attn_map, dim=1, keepdim=True) # (B, HW) - - obj_appearance = shape_sum * attn_map - obj_appearance /= shape_sum + torch.finfo(torch.float32).eps - - self_guidance = obj_appearance - self_guidance_factor = attn_scores.detach().clone() - self_guidance_factor = self_guidance_factor.view(B, -1, C) - self_guidance_factor[..., module.tcg_token_indices] = self_guidance - attn_map = attn_map.view(attn_map.size(0), *inner_dims, attn_map.size(-1)) + # mean over heads + attn_map = attn_map.view(B, heads, H, W, K).mean(dim=1) # (B, heads, H, W, C) -> (B, H, W, K) + + # # self-guidance + # inner_dims = attn_map.shape[1:-1] + # attn_map = attn_map.view(attn_map.size(0), -1, attn_map.size(-1)) + # shape_sum = torch.sum(attn_map, dim=1, keepdim=True) # (B, HW) + # obj_appearance = shape_sum * attn_map + # obj_appearance /= shape_sum + torch.finfo(torch.float32).eps + # self_guidance = obj_appearance + # self_guidance_factor = attn_scores.detach().clone() + # self_guidance_factor = self_guidance_factor.view(B, -1, C) + # self_guidance_factor[..., module.tcg_token_indices] = self_guidance + # attn_map = attn_map.view(attn_map.size(0), *inner_dims, attn_map.size(-1)) # region mask region_mask = module.tcg_region_mask @@ -199,7 +187,7 @@ def tcg_forward_hook(module, input, kwargs, output): # detect conflicts and return if none conflicts = detect_conflict(attn_map, region_mask, theta) # (B, C) if not torch.any(conflicts > 0.01): - logger.debug("No conflicts detected") + # logger.debug("No conflicts detected") return centroids = calculate_centroid(attn_map) # (B, C, 2) @@ -209,12 +197,16 @@ def tcg_forward_hook(module, input, kwargs, output): # zero out displacement force displ_force = displ_force * conflicts.unsqueeze(-1) - logger.debug("Displacements: %s", displ_force) + # logger.debug("Displacements: %s", displ_force) + + # stack along dim 0 for heads + displ_force = torch.stack([displ_force] * heads, dim=1).view(B*heads, K, -1) # B* heads, C, 2 + centroids = torch.stack([centroids] * heads, dim=1).view(B*heads, K, -1) # B* heads, C, 2 # apply displacements - output_attn_map = attn_scores.detach().clone() # B H W C + output_attn_map = attn_scores[..., tokens].detach().clone() # B H W C modified_attn_map, out_centroids = apply_displacements( - output_attn_map[..., module.tcg_token_indices], + output_attn_map, #output_attn_map[..., module.tcg_token_indices], centroids, displ_force @@ -223,24 +215,43 @@ def tcg_forward_hook(module, input, kwargs, output): # zero out area in region for all conflict tokens modified_attn_map = modified_attn_map * (1 - region_mask) - output_attn_map[..., module.tcg_token_indices] = modified_attn_map.squeeze(0) - output_attn_map = output_attn_map.view(B, -1, C) # B HW C + attn_map = modified_attn_map + # output_attn_map = output_attn_map.view(B, -1, C) # B HW C + # output_attn_map = output_attn_map.view(B, -1, C) # B HW C # output_attn_map = output_attn_map.permute(0, 2, 1) # B C HW - orig_attn_map = attn_scores.view(B, -1, C) # B HW C - orig_attn_map = orig_attn_map @ v_map + # orig_attn_map = attn_scores.view(B, -1, C) # B HW C + # orig_attn_map = orig_attn_map @ v_map - output_attn_map = output_attn_map @ v_map - self_guidance_factor = self_guidance_factor @ v_map + # output_attn_map = output_attn_map @ v_map + # self_guidance_factor = self_guidance_factor @ v_map - loss = output_attn_map + #loss = output_attn_map #loss = output - output_attn_map - return loss - loss = normalize_map(loss) - loss **= 2 + # to output map + attn_map = attn_map.view(B*heads, H, W, K) # (B, heads, H, W, C) -> (B,heads, HW, C) + attn_scores[..., tokens] = attn_map + attn_scores = attn_scores.view(B, heads, H*W, C) # (B, heads, HW, C) -> (B*heads, HW, C) + hidden_states = attn_scores @ v_map + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * inner_dim) # B, HW, heads*C + hidden_states = module.to_out[0](hidden_states) + hidden_states = module.to_out[1](hidden_states) + + loss = output - hidden_states + loss = torch.norm(loss, dim=-1, keepdim=True) ** 2 + # loss = torch.sum(loss, dim=(1,2), keepdim=True) + #loss /= loss.norm(dim=1, keepdim=True) + torch.finfo(loss.dtype).eps + loss *= strength + #loss **= 2 + + return output + loss * hidden_states + #return hidden_states + + hidden_states = normalize_map(hidden_states) + hidden_states **= 2 - output += strength * loss + output += strength * hidden_states output += selfguidance_scale * self_guidance_factor #output = output_attn_map @@ -354,7 +365,7 @@ def displacement_force(attention_map, verts, target_pos, f_rep_strength, f_margi f_rep = f_clamp(repulsive_force(f_rep_strength, verts, target_pos)) f_margin = f_clamp(margin_force(f_margin_strength, H, W, verts)) - logger.debug(f"Repulsive force: {debug_coord(f_rep)}, Margin force: {debug_coord(f_margin)}") + #logger.debug(f"Repulsive force: {debug_coord(f_rep)}, Margin force: {debug_coord(f_margin)}") return f_rep + f_margin @@ -508,7 +519,8 @@ def warping_force(attention_map, verts, displacements, h, w): torch.clamp_max(s_y, 1.0, out=s_y) torch.clamp_max(s_x, 1.0, out=s_x) if torch.any(s_x < 0.99) or torch.any(s_y < 0.99): - logger.debug(f"Scaling factor: {s_x}, {s_y}") + #logger.debug(f"Scaling factor: {s_x}, {s_y}") + pass # displacements o_new = displacements - correction From a7d15800cb0a2d4b837a57192ae6b1552fe5a4a0 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 30 May 2024 15:32:17 -0700 Subject: [PATCH 53/55] possibly working --- scripts/tcg.py | 98 ++++++++++++++++++++++++-------------------------- 1 file changed, 47 insertions(+), 51 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 9fa6a59..b7298f8 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -73,17 +73,18 @@ def setup_ui(self, is_img2img) -> list: with gr.Row(): active = gr.Checkbox(label="Active", value=True, elem_id="tcg_active") strength = gr.Slider(label="Strength", value=1.0, minimum=-5.0, maximum=5.0, step=0.1, elem_id="tcg_strength") - f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_margin") - f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_f_repl") + f_margin = gr.Slider(label="Margin Force", value=1.0, minimum=0.0, maximum=5.0, step=0.1, elem_id="tcg_f_margin") + f_repl = gr.Slider(label="Repulsion Force", value=1.0, minimum=0.0, maximum=5.0, step=0.1, elem_id="tcg_f_repl") theta = gr.Slider(label="Conflict Threshold", value=0.01, minimum=0.0, maximum=1.0, step=0.001, elem_id="tcg_theta") threshold = gr.Slider(label="Soft Threshold", value=0.5, minimum=0.0, maximum=1.0, step=0.01, elem_id="tcg_attn_threshold") sharpness = gr.Slider(label="Threshold Sharpness", value=10.0, minimum=0.1, maximum=20.0, step=0.1, elem_id="tcg_sharpness") selfguidance_scale = gr.Slider(label="Self-Guidance Scale", value=1.0, minimum=-2.0, maximum=2.0, step=0.1, elem_id="tcg_selfguidance_scale") - clamp = gr.Slider(label="Force Clamp", value=1.0, minimum=0.0, maximum=20.0, step=1.0, elem_id="tcg_clamp") + region_exclusion = gr.Slider(label="Region Exclusion Scale", value=1.0, minimum=0.0, maximum=1.0, step=0.1, elem_id="tcg_region_exclusion") + clamp = gr.Slider(label="Force Clamp", value=10.0, minimum=0.0, maximum=20.0, step=1.0, elem_id="tcg_clamp") with gr.Row(): start_step = gr.Slider(label="Start Step", value=0, minimum=0, maximum=100, step=1, elem_id="tcg_start_step") end_step = gr.Slider(label="End Step", value=100, minimum=1, maximum=100, step=1, elem_id="tcg_end_step") - opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step, clamp] + opts = [active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step, clamp, region_exclusion] for opt in opts: opt.do_not_save_to_config = True @@ -98,7 +99,7 @@ def setup_ui(self, is_img2img) -> list: def get_modules(self): return module_hooks.get_modules( module_name_filter='CrossAttention') - def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step, clamp, *args, **kwargs): + def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, threshold, sharpness, selfguidance_scale, image_mask, start_step, end_step, clamp, region_exclusion, *args, **kwargs): self.unhook_callbacks() active = getattr(p, 'tcg_active', active) if not active: @@ -113,6 +114,7 @@ def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, thr start_step = getattr(p, 'tcg_start_step', start_step) end_step = getattr(p, 'tcg_end_step', end_step) clamp = getattr(p, 'tcg_clamp', clamp) + region_exclusion = getattr(p, 'tcg_region_exclusion', region_exclusion) batch_size = p.batch_size height, width = p.height, p.width @@ -156,23 +158,23 @@ def tcg_forward_hook(module, input, kwargs, output): # slice attn_scores into attn_map to operate on target tokens attn_map = attn_scores[..., tokens].detach().clone() # (..., K) where K is the subset of tokens - # threshold, also represents object shape in self-guidance - attn_map = soft_threshold(attn_map, threshold=threshold, sharpness=sharpness) # B H W C - # mean over heads attn_map = attn_map.view(B, heads, H, W, K).mean(dim=1) # (B, heads, H, W, C) -> (B, H, W, K) + # threshold, also represents object shape in self-guidance + attn_map = soft_threshold(attn_map, threshold=threshold, sharpness=sharpness) # B H W C + # # self-guidance # inner_dims = attn_map.shape[1:-1] - # attn_map = attn_map.view(attn_map.size(0), -1, attn_map.size(-1)) - # shape_sum = torch.sum(attn_map, dim=1, keepdim=True) # (B, HW) + # attn_map = attn_map.view(B, heads, HW, K) # (B, heads, H, W, K) -> (B, HW, K) + # shape_sum = torch.sum(attn_map, dim=2, keepdim=True) # (B, heads, 1, K) # obj_appearance = shape_sum * attn_map # obj_appearance /= shape_sum + torch.finfo(torch.float32).eps - # self_guidance = obj_appearance - # self_guidance_factor = attn_scores.detach().clone() - # self_guidance_factor = self_guidance_factor.view(B, -1, C) - # self_guidance_factor[..., module.tcg_token_indices] = self_guidance - # attn_map = attn_map.view(attn_map.size(0), *inner_dims, attn_map.size(-1)) + + # self_guidance = obj_appearance.detach().clone() # B HW K + # self_guidance = self_guidance.view(B*heads, H, W, K) + # attn_map = attn_map.view(B*heads, H, W, K) + # region mask region_mask = module.tcg_region_mask @@ -212,49 +214,43 @@ def tcg_forward_hook(module, input, kwargs, output): displ_force ) # B H W C - # zero out area in region for all conflict tokens - modified_attn_map = modified_attn_map * (1 - region_mask) + # region exclusion: zero out area in region for all conflict tokens + #region_exclusion_mask = 1 - region_mask + (1 - region_exclusion) + #region_exclusion_mask = torch.clamp(region_exclusion_mask, 0, 1) + # region_exclusion_mask = region_exclusion_mask.repeat(1, 1, 1, K).view(1, -1, K) # (1, HW, K) + # region_exclusion_mask *= conflicts.max(dim=0, keepdim=True).values.view(1, 1, -1) # (1, HW, K) # only use mask for regions with conflicts + # region_exclusion_mask = region_exclusion_mask.view(1, H, W, K) # (1, H, W, K) - attn_map = modified_attn_map - # output_attn_map = output_attn_map.view(B, -1, C) # B HW C - # output_attn_map = output_attn_map.view(B, -1, C) # B HW C - - # output_attn_map = output_attn_map.permute(0, 2, 1) # B C HW - # orig_attn_map = attn_scores.view(B, -1, C) # B HW C - # orig_attn_map = orig_attn_map @ v_map + modified_attn_map = modified_attn_map * (1-region_mask) + #modified_attn_map = modified_attn_map * (1 - region_mask) + #modified_attn_map = modified_attn_map * (1 - region_mask) + (modified_attn_map * region_mask * (1-region_exclusion)) - # output_attn_map = output_attn_map @ v_map - # self_guidance_factor = self_guidance_factor @ v_map - - #loss = output_attn_map - #loss = output - output_attn_map + attn_map = modified_attn_map # to output map - attn_map = attn_map.view(B*heads, H, W, K) # (B, heads, H, W, C) -> (B,heads, HW, C) - attn_scores[..., tokens] = attn_map - attn_scores = attn_scores.view(B, heads, H*W, C) # (B, heads, HW, C) -> (B*heads, HW, C) - hidden_states = attn_scores @ v_map - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * inner_dim) # B, HW, heads*C - hidden_states = module.to_out[0](hidden_states) - hidden_states = module.to_out[1](hidden_states) - - loss = output - hidden_states + tcg_attnmap = attn_scores.detach().clone() + tcg_attnmap[..., tokens] = attn_map + tcg_attnmap = tcg_attnmap.view(B, heads, HW, C) # (B, heads, H, W, C) -> (B*heads, HW, C) + + # attn_scores = attn_scores.view(B, heads, H*W, C) # (B, heads, HW, C) -> (B*heads, HW, C) + tcg_hidden_states = tcg_attnmap @ v_map + tcg_hidden_states = tcg_hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * inner_dim) # B, HW, heads*C + tcg_hidden_states = module.to_out[0](tcg_hidden_states) + tcg_hidden_states = module.to_out[1](tcg_hidden_states) + + # selfguidance_attnmap = attn_scores.detach().clone() + # selfguidance_attnmap[..., tokens] = self_guidance + # selfguidance_attnmap = selfguidance_attnmap.view(B, heads, HW, C) + # sg_hidden_states = selfguidance_attnmap @ v_map + # sg_hidden_states = sg_hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * inner_dim) # B, HW, heads*C + # sg_hidden_states = module.to_out[0](sg_hidden_states) + # sg_hidden_states = module.to_out[1](sg_hidden_states) + + loss = output - tcg_hidden_states loss = torch.norm(loss, dim=-1, keepdim=True) ** 2 - # loss = torch.sum(loss, dim=(1,2), keepdim=True) - #loss /= loss.norm(dim=1, keepdim=True) + torch.finfo(loss.dtype).eps loss *= strength - #loss **= 2 - - return output + loss * hidden_states - #return hidden_states - - hidden_states = normalize_map(hidden_states) - hidden_states **= 2 - - output += strength * hidden_states - output += selfguidance_scale * self_guidance_factor - #output = output_attn_map + return output - loss * tcg_hidden_states def tcg_to_q_hook(module, input, kwargs, output): setattr(module.tcg_parent_module[0], 'tcg_to_q_map', output) From e1f614f5bc83f13b8ccece04cd4ad23ed3080b8f Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 30 May 2024 15:46:03 -0700 Subject: [PATCH 54/55] add guard clause when strength is 0 --- scripts/tcg.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/tcg.py b/scripts/tcg.py index b7298f8..3f2e1c7 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -130,6 +130,8 @@ def before_process_batch(self, p, active, strength, f_margin, f_repl, theta, thr def tcg_forward_hook(module, input, kwargs, output): current_step = module.tcg_current_step tokens = module.tcg_token_indices + if strength == 0: + return if not start_step <= current_step <= end_step: return if not len(token_indices) > 0: From 2f90e59189a85e5a76731b27d0f586b7da08c556 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 5 Jun 2024 07:51:34 -0700 Subject: [PATCH 55/55] test loss fn --- scripts/tcg.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/scripts/tcg.py b/scripts/tcg.py index 3f2e1c7..3e7106a 100644 --- a/scripts/tcg.py +++ b/scripts/tcg.py @@ -223,7 +223,7 @@ def tcg_forward_hook(module, input, kwargs, output): # region_exclusion_mask *= conflicts.max(dim=0, keepdim=True).values.view(1, 1, -1) # (1, HW, K) # only use mask for regions with conflicts # region_exclusion_mask = region_exclusion_mask.view(1, H, W, K) # (1, H, W, K) - modified_attn_map = modified_attn_map * (1-region_mask) + modified_attn_map = region_exclusion * modified_attn_map + (1-region_exclusion) * modified_attn_map * (1-region_mask) #modified_attn_map = modified_attn_map * (1 - region_mask) #modified_attn_map = modified_attn_map * (1 - region_mask) + (modified_attn_map * region_mask * (1-region_exclusion)) @@ -250,9 +250,11 @@ def tcg_forward_hook(module, input, kwargs, output): loss = output - tcg_hidden_states loss = torch.norm(loss, dim=-1, keepdim=True) ** 2 - loss *= strength + #loss *= strength + strength_factor = max(0, 1 - strength) - return output - loss * tcg_hidden_states + return (1-strength_factor) * output + (strength * tcg_hidden_states) + #return (1-strength_factor) * output + (strength * loss * tcg_hidden_states) def tcg_to_q_hook(module, input, kwargs, output): setattr(module.tcg_parent_module[0], 'tcg_to_q_map', output)