Skip to content

Commit

Permalink
Added --channel_multiplier options to convert_weight.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rosinality committed Feb 27, 2020
1 parent bb857e1 commit c7c8fc7
Showing 1 changed file with 46 additions and 26 deletions.
72 changes: 46 additions & 26 deletions convert_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,19 @@ def convert_modconv(vars, source_name, target_name, flip=False):

def convert_conv(vars, source_name, target_name, bias=True, start=0):
weight = vars[source_name + '/weight'].value().eval()

dic = {'weight': weight.transpose((3, 2, 0, 1))}

if bias:
dic['bias'] = vars[source_name + '/bias'].value().eval()

dic_torch = {}

dic_torch[target_name + f'.{start}.weight'] = torch.from_numpy(dic['weight'])

if bias:
dic_torch[target_name + f'.{start + 1}.bias'] = torch.from_numpy(dic['bias'])

return dic_torch


Expand Down Expand Up @@ -101,26 +101,39 @@ def update(state_dict, new):
raise ValueError(f'Shape mismatch: {v.shape} vs {state_dict[k].shape}')

state_dict[k] = v


def discriminator_fill_statedict(statedict, vars, size):
log_size = int(math.log(size, 2))

update(statedict, convert_conv(vars, f'{size}x{size}/FromRGB', 'convs.0'))

conv_i = 1

for i in range(log_size - 2, 0, -1):
reso = 4 * 2 ** i
update(statedict, convert_conv(vars, f'{reso}x{reso}/Conv0', f'convs.{conv_i}.conv1'))
update(statedict, convert_conv(vars, f'{reso}x{reso}/Conv1_down', f'convs.{conv_i}.conv2', start=1))
update(statedict, convert_conv(vars, f'{reso}x{reso}/Skip', f'convs.{conv_i}.skip', start=1, bias=False))
update(
statedict,
convert_conv(vars, f'{reso}x{reso}/Conv0', f'convs.{conv_i}.conv1'),
)
update(
statedict,
convert_conv(
vars, f'{reso}x{reso}/Conv1_down', f'convs.{conv_i}.conv2', start=1
),
)
update(
statedict,
convert_conv(
vars, f'{reso}x{reso}/Skip', f'convs.{conv_i}.skip', start=1, bias=False
),
)
conv_i += 1

update(statedict, convert_conv(vars, f'4x4/Conv', 'final_conv'))
update(statedict, convert_dense(vars, f'4x4/Dense0', 'final_linear.0'))
update(statedict, convert_dense(vars, f'Output', 'final_linear.1'))

return statedict


Expand Down Expand Up @@ -166,9 +179,7 @@ def fill_statedict(state_dict, vars, size):
update(
state_dict,
convert_modconv(
vars,
f'G_synthesis/{reso}x{reso}/Conv1',
f'convs.{conv_i + 1}'
vars, f'G_synthesis/{reso}x{reso}/Conv1', f'convs.{conv_i + 1}'
),
)
conv_i += 2
Expand All @@ -193,6 +204,7 @@ def fill_statedict(state_dict, vars, size):
parser.add_argument('--repo', type=str, required=True)
parser.add_argument('--gen', action='store_true')
parser.add_argument('--disc', action='store_true')
parser.add_argument('--channel_multiplier', type=int, default=2)
parser.add_argument('path', metavar='PATH')

args = parser.parse_args()
Expand All @@ -209,24 +221,24 @@ def fill_statedict(state_dict, vars, size):

size = g_ema.output_shape[2]

g = Generator(size, 512, 8)
g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
state_dict = g.state_dict()
state_dict = fill_statedict(state_dict, g_ema.vars, size)

g.load_state_dict(state_dict)

latent_avg = torch.from_numpy(g_ema.vars['dlatent_avg'].value().eval())

ckpt = {'g_ema': state_dict, 'latent_avg': latent_avg}

if args.gen:
g_train = Generator(size, 512, 8)
g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
g_train_state = g_train.state_dict()
g_train_state = fill_statedict(g_train_state, generator.vars, size)
ckpt['g'] = g_train_state

if args.disc:
disc = Discriminator(size)
disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
d_state = disc.state_dict()
d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
ckpt['d'] = d_state
Expand All @@ -242,15 +254,23 @@ def fill_statedict(state_dict, vars, size):
z = np.random.RandomState(0).randn(n_sample, 512).astype('float32')

with torch.no_grad():
img_pt, _ = g([torch.from_numpy(z).to(device)], truncation=0.5, truncation_latent=latent_avg.to(device))
img_pt, _ = g(
[torch.from_numpy(z).to(device)],
truncation=0.5,
truncation_latent=latent_avg.to(device),
)

Gs_kwargs = dnnlib.EasyDict()
Gs_kwargs.randomize_noise = False
img_tf = g_ema.run(z, None, **Gs_kwargs)
img_tf = torch.from_numpy(img_tf).to(device)

img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(0.0, 1.0)
img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(
0.0, 1.0
)

img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)
utils.save_image(img_concat, name + '.png', nrow=n_sample, normalize=True, range=(-1, 1))
utils.save_image(
img_concat, name + '.png', nrow=n_sample, normalize=True, range=(-1, 1)
)

0 comments on commit c7c8fc7

Please sign in to comment.