diff --git a/convert_weight.py b/convert_weight.py
index 1d485ef5..21bba54e 100755
--- a/convert_weight.py
+++ b/convert_weight.py
@@ -41,19 +41,19 @@ def convert_modconv(vars, source_name, target_name, flip=False):
 
 def convert_conv(vars, source_name, target_name, bias=True, start=0):
     weight = vars[source_name + '/weight'].value().eval()
-    
+
     dic = {'weight': weight.transpose((3, 2, 0, 1))}
-    
+
     if bias:
         dic['bias'] = vars[source_name + '/bias'].value().eval()
-    
+
     dic_torch = {}
-    
+
     dic_torch[target_name + f'.{start}.weight'] = torch.from_numpy(dic['weight'])
-    
+
     if bias:
         dic_torch[target_name + f'.{start + 1}.bias'] = torch.from_numpy(dic['bias'])
-    
+
     return dic_torch
 
 
@@ -101,26 +101,39 @@ def update(state_dict, new):
             raise ValueError(f'Shape mismatch: {v.shape} vs {state_dict[k].shape}')
 
         state_dict[k] = v
-    
-    
+
+
 def discriminator_fill_statedict(statedict, vars, size):
     log_size = int(math.log(size, 2))
-    
+
     update(statedict, convert_conv(vars, f'{size}x{size}/FromRGB', 'convs.0'))
 
     conv_i = 1
 
     for i in range(log_size - 2, 0, -1):
         reso = 4 * 2 ** i
 
-        update(statedict, convert_conv(vars, f'{reso}x{reso}/Conv0', f'convs.{conv_i}.conv1'))
-        update(statedict, convert_conv(vars, f'{reso}x{reso}/Conv1_down', f'convs.{conv_i}.conv2', start=1))
-        update(statedict, convert_conv(vars, f'{reso}x{reso}/Skip', f'convs.{conv_i}.skip', start=1, bias=False))
+        update(
+            statedict,
+            convert_conv(vars, f'{reso}x{reso}/Conv0', f'convs.{conv_i}.conv1'),
+        )
+        update(
+            statedict,
+            convert_conv(
+                vars, f'{reso}x{reso}/Conv1_down', f'convs.{conv_i}.conv2', start=1
+            ),
+        )
+        update(
+            statedict,
+            convert_conv(
+                vars, f'{reso}x{reso}/Skip', f'convs.{conv_i}.skip', start=1, bias=False
+            ),
+        )
 
         conv_i += 1
-    
+
     update(statedict, convert_conv(vars, f'4x4/Conv', 'final_conv'))
     update(statedict, convert_dense(vars, f'4x4/Dense0', 'final_linear.0'))
     update(statedict, convert_dense(vars, f'Output', 'final_linear.1'))
-    
+
     return statedict
@@ -166,9 +179,7 @@ def fill_statedict(state_dict, vars, size):
         update(
             state_dict,
             convert_modconv(
-                vars,
-                f'G_synthesis/{reso}x{reso}/Conv1',
-                f'convs.{conv_i + 1}'
+                vars, f'G_synthesis/{reso}x{reso}/Conv1', f'convs.{conv_i + 1}'
             ),
         )
         conv_i += 2
@@ -193,6 +204,7 @@
     parser.add_argument('--repo', type=str, required=True)
     parser.add_argument('--gen', action='store_true')
     parser.add_argument('--disc', action='store_true')
+    parser.add_argument('--channel_multiplier', type=int, default=2)
     parser.add_argument('path', metavar='PATH')
 
     args = parser.parse_args()
@@ -209,24 +221,24 @@
 
     size = g_ema.output_shape[2]
 
-    g = Generator(size, 512, 8)
+    g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
     state_dict = g.state_dict()
     state_dict = fill_statedict(state_dict, g_ema.vars, size)
 
     g.load_state_dict(state_dict)
 
     latent_avg = torch.from_numpy(g_ema.vars['dlatent_avg'].value().eval())
-    
+
     ckpt = {'g_ema': state_dict, 'latent_avg': latent_avg}
-    
+
     if args.gen:
-        g_train = Generator(size, 512, 8)
+        g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier)
         g_train_state = g_train.state_dict()
         g_train_state = fill_statedict(g_train_state, generator.vars, size)
         ckpt['g'] = g_train_state
-    
+
     if args.disc:
-        disc = Discriminator(size)
+        disc = Discriminator(size, channel_multiplier=args.channel_multiplier)
         d_state = disc.state_dict()
         d_state = discriminator_fill_statedict(d_state, discriminator.vars, size)
         ckpt['d'] = d_state
@@ -242,15 +254,23 @@ def fill_statedict(state_dict, vars, size):
     z = np.random.RandomState(0).randn(n_sample, 512).astype('float32')
 
     with torch.no_grad():
-        img_pt, _ = g([torch.from_numpy(z).to(device)], truncation=0.5, truncation_latent=latent_avg.to(device))
+        img_pt, _ = g(
+            [torch.from_numpy(z).to(device)],
+            truncation=0.5,
+            truncation_latent=latent_avg.to(device),
+        )
 
         Gs_kwargs = dnnlib.EasyDict()
         Gs_kwargs.randomize_noise = False
 
         img_tf = g_ema.run(z, None, **Gs_kwargs)
         img_tf = torch.from_numpy(img_tf).to(device)
 
-    img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(0.0, 1.0)
+    img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(
+        0.0, 1.0
+    )
 
     img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0)
-    utils.save_image(img_concat, name + '.png', nrow=n_sample, normalize=True, range=(-1, 1))
+    utils.save_image(
+        img_concat, name + '.png', nrow=n_sample, normalize=True, range=(-1, 1)
+    )
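
Note on the kernel layout handled by convert_conv above: TensorFlow stores 2D convolution weights as (kernel_h, kernel_w, in_channels, out_channels), while PyTorch's nn.Conv2d expects (out_channels, in_channels, kernel_h, kernel_w); that reordering is exactly what transpose((3, 2, 0, 1)) performs. A minimal standalone sketch (the shapes are illustrative, not taken from the script):

    import numpy as np

    # TensorFlow layout: (kh, kw, in_ch, out_ch)
    tf_kernel = np.random.randn(3, 3, 64, 128).astype('float32')

    # PyTorch layout: (out_ch, in_ch, kh, kw)
    pt_kernel = tf_kernel.transpose((3, 2, 0, 1))
    assert pt_kernel.shape == (128, 64, 3, 3)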
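
The converted checkpoint keeps the averaged generator under 'g_ema', the W-space mean under 'latent_avg', and, when --gen/--disc are passed, the training generator and discriminator under 'g' and 'd'. A hedged usage sketch follows; the output filename, the import path of Generator, and the FFHQ size of 1024 are assumptions rather than facts from this patch, and --channel_multiplier must match the config the pickle was trained with (the default of 2 corresponds to config-f):

    # python convert_weight.py --repo /path/to/stylegan2 --gen --disc \
    #     --channel_multiplier 2 stylegan2-ffhq-config-f.pkl

    import torch

    from model import Generator  # assumed: the same Generator the script constructs

    ckpt = torch.load('stylegan2-ffhq-config-f.pt')  # assumed output filename
    g_ema = Generator(1024, 512, 8, channel_multiplier=2)  # same arguments as in the script
    g_ema.load_state_dict(ckpt['g_ema'])
    latent_avg = ckpt['latent_avg']  # dlatent_avg; used as truncation_latent above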