Commit

No trace tensors
monorimet committed Jun 13, 2024
1 parent ea85c86 commit e9fee1a
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions src/diffusers/models/attention_processor.py
@@ -1141,7 +1141,7 @@ def __call__(
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
)
trace_tensor("attn_out", hidden_states[0,0,0,0])
#trace_tensor("attn_out", hidden_states[0,0,0,0])

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -1152,7 +1152,7 @@ def __call__(
hidden_states[:, residual.shape[1] :],
)
hidden_states_cl = hidden_states.clone()
trace_tensor("attn_out", hidden_states_cl[0,0,0])
#trace_tensor("attn_out", hidden_states_cl[0,0,0])
# linear proj
hidden_states = attn.to_out[0](hidden_states_cl)
# dropout
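
For reference, the calls being disabled in the first two hunks follow a simple keyed-trace pattern: a tensor (or a single element of it) is handed to trace_tensor under a string key so it can be inspected later. The import of trace_tensor is not shown in this diff, so the sketch below uses a hypothetical stand-in plus a hypothetical TRACE_ATTN environment switch, illustrating how such calls can be gated by a flag rather than commented in and out by hand.

import os

import torch

TRACE_ENABLED = os.environ.get("TRACE_ATTN", "0") == "1"  # hypothetical debug switch


def trace_tensor(key: str, t: torch.Tensor) -> None:
    # Hypothetical stand-in for the tracing helper used in this file:
    # print the key, the shape, and the first few values being traced.
    print(f"[trace] {key}: shape={tuple(t.shape)} values={t.flatten()[:4].tolist()}")


def maybe_trace(key: str, t: torch.Tensor) -> None:
    # Gating behind an environment flag avoids editing the source to toggle traces.
    if TRACE_ENABLED:
        trace_tensor(key, t)


# Usage mirroring the diff: trace a single element of the attention output.
hidden_states = torch.randn(2, 8, 64, 64)
maybe_trace("attn_out", hidden_states[0, 0, 0, 0])
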
@@ -1221,13 +1221,9 @@ def __call__(
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
trace_tensor("query", query)
trace_tensor("key", key)
trace_tensor("value", value)
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, is_causal=False
)
trace_tensor("attn_out", hidden_states[:,:,:50])

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -1597,10 +1593,6 @@ def __call__(
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
trace_tensor("query", query)
trace_tensor("key", key)
trace_tensor("value", value)
trace_tensor("attn_out", hidden_states[:,:,:50])
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
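
The unchanged context around the removed traces is the standard multi-head attention path: reshape the projected query/key/value to [batch, heads, seq, head_dim], call F.scaled_dot_product_attention, then merge the heads back and match the query dtype. Below is a minimal, self-contained sketch of that round trip; the shapes and the random projections are arbitrary stand-ins, not the processor's actual to_q/to_k/to_v outputs.

import torch
import torch.nn.functional as F

batch_size, seq_len, heads, head_dim = 2, 16, 8, 64
inner_dim = heads * head_dim

# Stand-ins for the processor's projected query/key/value, shape [B, S, H*D].
query = torch.randn(batch_size, seq_len, inner_dim)
key = torch.randn(batch_size, seq_len, inner_dim)
value = torch.randn(batch_size, seq_len, inner_dim)

# [B, S, H*D] -> [B, H, S, D], as in the diff's context lines.
query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# Same call as in the diff; the last hunk additionally passes attn_mask.
hidden_states = F.scaled_dot_product_attention(
    query, key, value, dropout_p=0.0, is_causal=False
)

# [B, H, S, D] -> [B, S, H*D], then cast back to the query dtype.
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
print(hidden_states.shape)  # torch.Size([2, 16, 512])
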
