
How to get intermediate feature in UNet #5796

Answered by KumoLiu
ukaukaaaa asked this question in Q&A

Hi @ukaukaaaa, I think you can do this by registering a forward hook on the layer whose output you want:

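```python
import torch
from monai.networks.nets import UNet

net = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(32, 64, 128, 256), strides=(2, 2, 2))

# Collects the output of every hooked layer, keyed by a name of our choosing
intermediate_outputs = {}

def forward_hook(layer_name):
    # PyTorch calls this hook right after the module's forward() returns
    def hook(module, input, output):
        # Detach so the stored tensor does not keep the autograd graph alive
        intermediate_outputs[layer_name] = output.detach()

    return hook

x = torch.randn(4, 1, 64, 64, 64)  # avoid shadowing the built-in `input`
# Hook the last block of the UNet; keep the handle so the hook can be removed later
handle = net.model[-1].register_forward_hook(forward_hook("last_layer"))
out = net(x)
print(intermediate_outputs["last_layer"].shape)
handle.remove()
```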

Hope it can help you, thanks!

Answer selected by wyli

This discussion was converted from issue #5788 on December 30, 2022 08:30.