
Inference from Checkpoints #12

Open
ghost opened this issue Jun 14, 2021 · 5 comments

Comments


ghost commented Jun 14, 2021

Hi and thank you very much for your really helpful code!
I am trying to test my trained model, but I am running into problems with the inference.py file.
I specified a checkpoint stored in the ckpt folder, but I get a "KeyError: 'param'".
Could you please elaborate on how to use the --model_path flag? (And in general, it would be useful to have a quick overview on how to use the inference.py file.)
Thank you very much in advance and best regards.


ghost commented Jun 22, 2021

I browsed through the history of the project a little and found that my problem is a result of the refactoring of the inference code. I will update the code and prepare a pull request, if you don't mind.


minhkids commented Jun 1, 2022

Yeah, I have the same problem with "KeyError: 'param'". Could you also update the README with instructions on how to use inference, please?


minhkids commented Jun 2, 2022

@bowphs, could you show me how to fix this issue?


ghost commented Jun 2, 2022

If I remember correctly, the problem is the *_run scripts, which do not save the model properly: during inference, you try to load the model, but the expected keys do not exist in the checkpoint.
A quick fix would be to save the model manually in your *_run script. For example, in the WTM_run.py script, you could add something like:

save_name = f'./ckpt/WTM_{taskname}_tp{n_topic}_{dist}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'

checkpoint = {
    "net": model.wae.state_dict(),
    "optimizer": model.optimizer.state_dict(),
    "epoch": num_epochs,
    # hyperparameters that inference.py reads back from the "param" entry
    "param": {
        "bow_dim": voc_size,
        "n_topic": n_topic,
        "taskname": taskname,
        "dist": dist,
        "dropout": dropout
    }
}
torch.save(checkpoint, save_name)
print(f"Successfully saved model. Model name: {save_name}.")


minhkids commented Jun 7, 2022

Thanks for your reply. I just tried your code and it still isn't working for me. The other *_run files don't save a checkpoint at all, so I don't know how to handle those files to make them work. I would be really thankful if you could point me to the repo where you fixed this issue. @bowphs
