Understanding the discrete reverse process #2

Open
SagiPolaczek opened this issue Aug 17, 2023 · 0 comments

SagiPolaczek commented Aug 17, 2023

Hey!

First, thank you for your contribution to the field as well as open-sourcing your code! Really appreciated!
I hope it's OK that I reach out to you here: I want to use D3PM for protein sequences (similar to what you did with an LM), but I'm struggling to understand the following point in the reverse process:

In your paper you've mentioned:

[screenshot of the reverse-process equation from the paper]
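(For readers who can't see the screenshot: I believe this is the standard D3PM parameterization of the reverse step,

$$p_\theta(x_{t-1} \mid x_t) \propto \sum_{\tilde{x}_0} q(x_{t-1}, x_t \mid \tilde{x}_0)\, \tilde{p}_\theta(\tilde{x}_0 \mid x_t),$$

i.e. the reverse step is obtained by marginalising over the predicted $\tilde{x}_0$.)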

Theoretically, I agree with this. But when it comes to implementation, it's not possible to calculate the sum in the last line directly.
While the original D3PM paper uses the mean & log scale to predict that distribution, as far as I understand, your code only considers the logits of $p_\theta(x_0 \mid x_t)$.
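For concreteness, this is roughly the reverse-step computation I would have expected given only those logits. It's just a minimal sketch under my own assumptions: `reverse_step_probs`, `Q_t` (one-step forward transition matrix) and `Q_bar_tm1` (cumulative transition matrix up to $t-1$) are hypothetical names of mine, not taken from your code.

    import torch.nn.functional as F

    def reverse_step_probs(x0_logits, x_t, Q_t, Q_bar_tm1):
        # Sketch: p_theta(x_{t-1} | x_t) is proportional to
        #   sum_{x0} q(x_{t-1}, x_t | x0) * p_theta(x0 | x_t)
        #
        # x0_logits: (B, L, K) logits of p_theta(x0 | x_t)
        # x_t:       (B, L)    corrupted tokens at step t
        # Q_t:       (K, K)    one-step matrix, Q_t[a, b] = q(x_t = b | x_{t-1} = a)
        # Q_bar_tm1: (K, K)    cumulative matrix, Q_bar_tm1[a, b] = q(x_{t-1} = b | x_0 = a)
        p_x0 = F.softmax(x0_logits, dim=-1)                # (B, L, K)

        # fact1[b, l, k] = q(x_t = observed token | x_{t-1} = k)
        fact1 = Q_t[:, x_t].permute(1, 2, 0)               # (B, L, K)

        # fact2[b, l, k] = sum_{x0} p_theta(x0 | x_t) * q(x_{t-1} = k | x0)
        fact2 = p_x0 @ Q_bar_tm1                           # (B, L, K)

        probs = fact1 * fact2                              # joint, up to normalisation
        return probs / probs.sum(dim=-1, keepdim=True)     # normalise over x_{t-1}

As far as I understand, this is also how the reference D3PM implementation combines the two factors at sampling time, so I expected to find something similar here.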

More specifically, I looked at:
MLMDiffusionTransformer.forward():

        # Transformer encoder over the (corrupted) token embeddings
        sequence_output = self.encoder(embed, encoder_attention_mask=attn_mask)[0]
        # LM head: per-position logits over the vocabulary
        prediction_scores = self.cls(sequence_output)

        out = {
            "logits": prediction_scores,
            "sequence_output": sequence_output,
            "embeds": token_embed,
        }

AND
MLMDiffusion.forward():

        # Corrupt the clean tokens x_0 into x_t according to the noise schedule
        corrupt_ids, corrupt_mask = (
            self.noise_schedule.corrupt(input_ids, t, corrupt_mask)
        )

        # Denoising network predicts per-position logits over x_0 from the corrupted sequence
        model_output = self.network(
            corrupt_ids,
            t,
            attn_mask,
        )
        logits = model_output['logits']
        hiddens = model_output['sequence_output']

        # Cross-entropy between the predicted x_0 logits and the clean tokens
        loss_fct = nn.CrossEntropyLoss(reduction='none')  # -100 index = padding token
        nll = loss_fct(logits.view(-1, logits.shape[-1]), input_ids.view(-1))
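(The cross-entropy term itself looks to me like the auxiliary term in the D3PM hybrid loss, $L_\lambda = L_{vb} + \lambda\, \mathbb{E}_{q(x_0)} \mathbb{E}_{q(x_t \mid x_0)} \big[ -\log \tilde{p}_\theta(x_0 \mid x_t) \big]$, so my confusion is mainly about the sampling/reverse side rather than the training loss.)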

Am I missing something? If not, how does it match the paper?

Thanks a lot!
Sagi
