From e9fee1a4fdd23a812305cd6453fc57906e2c3ece Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Thu, 13 Jun 2024 05:36:43 -0500 Subject: [PATCH] No trace tensors --- src/diffusers/models/attention_processor.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 4544131803f8..9c38905d8600 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1141,7 +1141,7 @@ def __call__( hidden_states = hidden_states = F.scaled_dot_product_attention( query, key, value, dropout_p=0.0, is_causal=False ) - trace_tensor("attn_out", hidden_states[0,0,0,0]) + #trace_tensor("attn_out", hidden_states[0,0,0,0]) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -1152,7 +1152,7 @@ def __call__( hidden_states[:, residual.shape[1] :], ) hidden_states_cl = hidden_states.clone() - trace_tensor("attn_out", hidden_states_cl[0,0,0]) + #trace_tensor("attn_out", hidden_states_cl[0,0,0]) # linear proj hidden_states = attn.to_out[0](hidden_states_cl) # dropout @@ -1221,13 +1221,9 @@ def __call__( query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - trace_tensor("query", query) - trace_tensor("key", key) - trace_tensor("value", value) hidden_states = hidden_states = F.scaled_dot_product_attention( query, key, value, dropout_p=0.0, is_causal=False ) - trace_tensor("attn_out", hidden_states[:,:,:50]) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -1597,10 +1593,6 @@ def __call__( hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - trace_tensor("query", query) - trace_tensor("key", key) - trace_tensor("value", value) - trace_tensor("attn_out", hidden_states[:,:,:50]) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype)