the meaning of the mask in "attn.data.masked_fill_(mask.data, -float('inf'))" in the forward function of the class GlobalAttentionGeneral #93

Open
fyw1999 opened this issue Jul 8, 2021 · 2 comments

fyw1999 commented Jul 8, 2021

In the forward function of the class GlobalAttentionGeneral there is some code that I think may be wrong. Assume batch_size is 20, words_num is 18 and embedding_dim is 256, so in the second stage of generation the arguments input and context have shapes (20, 48, 64, 64) and (20, 256, 18). After `attn = attn.view(batch_size * queryL, sourceL)`, the shape of attn is (81920, 18), while mask has the same shape as captions, i.e. (20, 18). If a value in captions is 0, there is no word at that position of the sentence, and the corresponding position in mask is 1. Based on this, I think the purpose of `attn.data.masked_fill_(mask.data.bool(), -float('inf'))` is to set a value in attn to minus infinity when there is no word at the corresponding position of captions.

However, although mask has the same shape as attn after `mask = self.mask.repeat(queryL, 1)`, the positions in mask and attn no longer correspond to the same things. mask goes from (20, 18) to (81920, 18) by tiling the whole (20, 18) block along the rows, while attn goes from (20, 4096, 18) to (81920, 18) in batch-major order. As a result, mask[1][0] indicates whether there is a word at the first position of the second sentence in the batch, but attn[1][0] is the dot product of the second pixel of the first image with the first word of the first sentence in the batch. So the same position in mask and attn means different things (see the small reproduction sketch after the code below). Can anyone answer my question? Thank you very much.
```python
class GlobalAttentionGeneral(nn.Module):
    def __init__(self, idf, cdf):
        super(GlobalAttentionGeneral, self).__init__()
        self.conv_context = conv1x1(cdf, idf)
        self.sm = nn.Softmax()
        self.mask = None

    def applyMask(self, mask):
        self.mask = mask  # batch x sourceL

    def forward(self, input, context):
        """
            input: batch x idf x ih x iw (queryL=ihxiw)
            context: batch x cdf x sourceL
        """
        ih, iw = input.size(2), input.size(3)
        queryL = ih * iw
        batch_size, sourceL = context.size(0), context.size(2)

        # --> batch x queryL x idf
        target = input.view(batch_size, -1, queryL)
        targetT = torch.transpose(target, 1, 2).contiguous()
        # batch x cdf x sourceL --> batch x cdf x sourceL x 1
        sourceT = context.unsqueeze(3)
        # --> batch x idf x sourceL
        sourceT = self.conv_context(sourceT).squeeze(3)

        # Get attention
        # (batch x queryL x idf)(batch x idf x sourceL)
        # --> batch x queryL x sourceL
        attn = torch.bmm(targetT, sourceT)
        # --> batch*queryL x sourceL
        attn = attn.view(batch_size * queryL, sourceL)
        if self.mask is not None:
            # batch_size x sourceL --> batch_size*queryL x sourceL
            mask = self.mask.repeat(queryL, 1)
            attn.data.masked_fill_(mask.data, -float('inf'))
        attn = self.sm(attn)  # Eq. (2)
        # --> batch x queryL x sourceL
        attn = attn.view(batch_size, queryL, sourceL)
        # --> batch x sourceL x queryL
        attn = torch.transpose(attn, 1, 2).contiguous()

        # (batch x idf x sourceL)(batch x sourceL x queryL)
        # --> batch x idf x queryL
        weightedContext = torch.bmm(sourceT, attn)
        weightedContext = weightedContext.view(batch_size, -1, ih, iw)
        attn = attn.view(batch_size, -1, ih, iw)

        return weightedContext, attn
```
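
For what it's worth, here is a minimal sketch (toy sizes, not code from this repository) that reproduces the row-ordering mismatch described above; `repeat_interleave` is shown only as one possible way to line the mask rows up with attn's batch-major layout, not as the authors' intended fix.

```python
import torch

batch_size, queryL, sourceL = 2, 3, 4                # toy sizes instead of 20 / 4096 / 18
attn = torch.arange(batch_size * queryL * sourceL, dtype=torch.float)
attn = attn.view(batch_size, queryL, sourceL)        # batch x queryL x sourceL
attn_flat = attn.view(batch_size * queryL, sourceL)  # row r belongs to batch item r // queryL

mask = torch.tensor([[0, 0, 1, 1],                   # sample 0: last two positions are padding
                     [0, 1, 1, 1]],                  # sample 1: last three positions are padding
                    dtype=torch.bool)

repo_mask = mask.repeat(queryL, 1)                   # what the repository code does
# Row r of repo_mask belongs to batch item r % batch_size, so e.g. row 1 of attn_flat
# (still sample 0, pixel 1) gets masked with sample 1's padding pattern.

aligned_mask = mask.repeat_interleave(queryL, dim=0)  # one way to match the batch-major order
# equivalently: mask.unsqueeze(1).repeat(1, queryL, 1).view(-1, sourceL)

assert not torch.equal(repo_mask, aligned_mask)      # the two row orderings really differ
attn_flat.masked_fill_(aligned_mask, -float('inf'))  # now each row is masked with its own caption
```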

baolp commented Mar 6, 2022

I have the same question.

YibinLiu666 commented

> I have the same question.

The mask is used to eliminate the influence of images with the same label; otherwise, in contrastive learning, the distance between the text and images with the same label would be pushed apart.
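
If that reading is right, the mask would be built from class labels rather than from padded caption positions. A rough, hypothetical sketch (the helper name and usage are mine, not code from this repository) of constructing such a same-label mask might look like:

```python
import numpy as np
import torch

def build_same_label_mask(class_ids, i):
    """Hypothetical helper: mark every other image in the batch that shares
    image i's class label, so it is not treated as a negative in the loss."""
    mask = (np.asarray(class_ids) == class_ids[i]).astype(np.uint8)
    mask[i] = 0                                   # keep the matching image itself
    return torch.from_numpy(mask).bool()

class_ids = [3, 7, 3, 5]                          # toy batch of class labels
print(build_same_label_mask(class_ids, 0))        # tensor([False, False,  True, False])
```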
