Skip to content

Commit

Permalink
Debugged the batch normalization, inference produces actual results now
Browse files Browse the repository at this point in the history
  • Loading branch information
Tuomas Frondelius committed Aug 31, 2018
1 parent 5eb93de commit 96cf1da
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
42 changes: 28 additions & 14 deletions CNTKUNet/CNTKUNet/Models/UNet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ private static void center(Variable features, int[] ks, int in_channels, int out
}

//Convolution blocks
var block1 = ConvBlock(features, ks, in_channels, out_channels, null, null, wpath, name1, true, true, false);
var block2 = ConvBlock(block1.Output, ks, out_channels, out_channels, null, null, wpath, name2, true, true, false);
var block1 = ConvBlock(features, ks, in_channels, out_channels, null, null, wpath, name1, true, true, true);
var block2 = ConvBlock(block1.Output, ks, out_channels, out_channels, null, null, wpath, name2, true, true, true);

center = block2.Output;
}
Expand Down Expand Up @@ -342,19 +342,19 @@ private static void make_bn_pars(int out_channels, NDShape shape, out Parameter
string vn = "bn" + S[0] + "_" + S[1] + "_" + "running_var";
//Initialize new weight
w = new Parameter(new int[] { out_channels}, DataType.Float, CNTKLib.GlorotUniformInitializer(), DeviceDescriptor.CPUDevice);
float[] W = make_copies(weight_fromDisk(wpath, wn),shape);
float[] W = make_copies(weight_fromDisk(wpath, wn),shape,true);
w = weight_fromFloat(w, W, new int[] { out_channels });
//Initialize new bias
b = new Parameter(new int[] { out_channels }, DataType.Float, CNTKLib.GlorotUniformInitializer(), DeviceDescriptor.CPUDevice);
W = make_copies(weight_fromDisk(wpath, bn), shape);
W = make_copies(weight_fromDisk(wpath, bn), shape, true);
b = weight_fromFloat(b, W, new int[] { out_channels });
//Initialize new running mean
m = new Parameter(new int[] { out_channels }, DataType.Float, CNTKLib.GlorotUniformInitializer(), DeviceDescriptor.CPUDevice);
W = make_copies(weight_fromDisk(wpath, mn), shape);
W = make_copies(weight_fromDisk(wpath, mn), shape, true);
m = weight_fromFloat(m, W, new int[] { out_channels });
//Initialize new variance
v = new Parameter(new int[] { out_channels }, DataType.Float, CNTKLib.GlorotUniformInitializer(), DeviceDescriptor.CPUDevice);
W = make_copies(weight_fromDisk(wpath, vn), shape);
W = make_copies(weight_fromDisk(wpath, vn), shape, true);
v = weight_fromFloat(v, W, new int[] { out_channels });
}
else
Expand Down Expand Up @@ -400,20 +400,34 @@ private static float[] from_kernel(float[] kernel, int dim)
}

//Copy bias to match the dimensions of the input
private static float[] make_copies(float[] kernel, NDShape dims)
private static float[] make_copies(float[] kernel, NDShape dims, bool swap_axes=false)
{
int K = dims[0] * dims[1];
float[] outarray = new float[dims[0]* dims[1]* dims[2]];
//Loop over number of maps
for(int k=0; k<dims[2]; k++)
float[] outarray = new float[dims[0] * dims[1] * dims[2]];
if (swap_axes == false)
{
//Loop over spatial dimensions
for(int kk = 0; kk<K; kk++)
//Loop over number of maps
for (int k = 0; k < dims[2]; k++)
{
outarray[k*K + kk] = kernel[k];
//Loop over spatial dimensions
for (int kk = 0; kk < K; kk++)
{
outarray[k * K + kk] = kernel[k];
}
}
}
else
{
//Loop over number of maps
for (int k = 0; k < K; k++)
{
//Loop over spatial dimensions
for (int kk = 0; kk < dims[2]; kk++)
{
outarray[k*dims[2] + kk] = kernel[kk];
}
}
}

return outarray;
}

Expand Down
3 changes: 2 additions & 1 deletion CNTKUNet/CNTKUNet/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

using CNTKUNet.Models;


namespace CNTKUNet
{
class Program
Expand All @@ -22,7 +23,7 @@ class Program
static void Main()
{
//Path to weights
string wpath = "c:\\users\\jfrondel\\Desktop\\GITS\\UNetE3bn.h5";
string wpath = "c:\\users\\jfrondel\\Desktop\\GITS\\UNetE3bnf.h5";

//Path to test image
string impath = "c:\\users\\jfrondel\\desktop\\GITS\\sample.png";
Expand Down

0 comments on commit 96cf1da

Please sign in to comment.