Skip to content

Commit

Permalink
修正。
Browse files Browse the repository at this point in the history
  • Loading branch information
takishim committed Feb 22, 2017
1 parent 3541d56 commit d0a3796
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 45 deletions.
9 changes: 8 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,11 @@
<groupId>to.kishimo</groupId>
<artifactId>mnist_java</artifactId>
<version>0.1</version>
</project>

<properties>
<java.version>1.8</java.version>
<maven.compiler.target>${java.version}</maven.compiler.target>
<maven.compiler.source>${java.version}</maven.compiler.source>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
</project>
100 changes: 56 additions & 44 deletions src/main/java/to/kishimo/minist/MnistDataSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import java.net.URLConnection;
import java.util.zip.GZIPInputStream;


/**
* MNISTの手書き文字の画像データを扱うクラス.
*/
Expand All @@ -18,38 +17,50 @@ public class MnistDataSet implements Serializable {
public static final String TRAIN_LABEL_FILE = "train-labels-idx1-ubyte.gz";
public static final String TEST_IMAGE_FILE = "t10k-images-idx3-ubyte.gz";
public static final String TEST_LABEL_FILE = "t10k-labels-idx1-ubyte.gz";
public static final String BASE_PATH = "./dataset/mnist/";

private static final String BASE_URL = "http://yann.lecun.com/exdb/mnist/";
private static final String BASE_PATH = "./dataset/mnist/";
private static final String SERIALIZED_FILE = "_mnist.ser";
private static final long serialVersionUID = 1L;

private String prefix;
private int numImages;
private int numDimensions;
private double[][] features;
private int[] labels;

/**
* 使用例.
*/
public static void main(String... args) throws IOException, ClassNotFoundException {
// トレーニングデータセット
MnistDataSet trainDataSet = MnistDataSet.createInstance("train", TRAIN_IMAGE_FILE, TRAIN_LABEL_FILE);
MnistDataSet trainDataSet = MnistDataSet.createInstance("train");
// 概形を表示する
trainDataSet.showImageAsText(0);
// 画像を表示する
trainDataSet.showImage(1);
// 画像を保存する
trainDataSet.saveImage(BASE_PATH, "train", 2);
trainDataSet.saveImage(MnistDataSet.BASE_PATH, 2);

// テストデータセット
MnistDataSet testDataSet = MnistDataSet.createInstance("test", TEST_IMAGE_FILE, TEST_LABEL_FILE);
MnistDataSet testDataSet = MnistDataSet.createInstance("test");
// 概形を表示する
testDataSet.showImageAsText(0);
// 画像を表示する
testDataSet.showImage(1);
// 画像を保存する
testDataSet.saveImage(BASE_PATH, "test", 2);
testDataSet.saveImage(MnistDataSet.BASE_PATH, 2);
}

/**
 * Creates a data set that carries only its prefix.
 *
 * <p>This constructor assigns the {@code prefix} field and nothing else; every other
 * field keeps its default value until it is populated elsewhere.
 *
 * @param prefix value to store in the {@code prefix} field
 */
private MnistDataSet(final String prefix) {
    this.prefix = prefix;
}

/**
 * Creates the data set for one of the two predefined splits.
 *
 * <p>Selects the image and label file constants that belong to the requested split and
 * delegates to {@code createInstance(prefix, imageFile, labelFile)}.
 *
 * @param prefix split selector; must be {@code "train"} or {@code "test"}
 * @return whatever the three-argument overload returns for the selected files
 * @throws IOException propagated from the three-argument overload
 * @throws ClassNotFoundException propagated from the three-argument overload
 * @throws IllegalArgumentException if {@code prefix} is anything other than
 *         {@code "train"} or {@code "test"}, including {@code null}
 */
public static MnistDataSet createInstance(String prefix) throws IOException, ClassNotFoundException {
    final boolean isTrain = "train".equals(prefix);

    // Guard clause: reject every value that is not one of the two known splits.
    if (!isTrain && !"test".equals(prefix)) {
        throw new IllegalArgumentException("Prefix must be 'train' or 'test'.");
    }

    final String imageFile = isTrain ? TRAIN_IMAGE_FILE : TEST_IMAGE_FILE;
    final String labelFile = isTrain ? TRAIN_LABEL_FILE : TEST_LABEL_FILE;
    return createInstance(prefix, imageFile, labelFile);
}

/**
Expand All @@ -72,17 +83,19 @@ public static MnistDataSet createInstance(String prefix, String imageFile, Strin
String serFilePath = BASE_PATH + prefix + SERIALIZED_FILE;
if (new File(serFilePath).exists()) {
System.out.println("Deserializing object from " + prefix + SERIALIZED_FILE + " ...");
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(serFilePath));
return (MnistDataSet) ois.readObject();
try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(serFilePath));) {
return (MnistDataSet) ois.readObject();
}
} else {
MnistDataSet dataSet = new MnistDataSet();
MnistDataSet dataSet = new MnistDataSet(prefix);
System.out.println("Loading feature data from " + imageFile + " ...");
dataSet.loadFeatures(imageFile);
System.out.println("Loading label data from " + labelFile + " ...");
dataSet.loadLabels(labelFile);

ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(serFilePath));
oos.writeObject(dataSet);
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(serFilePath));) {
oos.writeObject(dataSet);
}

return dataSet;
}
Expand Down Expand Up @@ -163,11 +176,10 @@ public void showImage(int index) {
/**
* 画像をファイルに保存する.
*
* @param dir ファイルを保存するディレクトリ
* @param prefix ファイル名の先頭に付けるプレフィックス
* @param index 画像の番号
* @param dir ファイルを保存するディレクトリ
* @param index 画像の番号
*/
public void saveImage(String dir, String prefix, int index) throws IOException {
public void saveImage(String dir, int index) throws IOException {
BufferedImage image = makeImage(index);
File file = new File(dir + "/" + prefix + "_" + String.format("%05d", index) + "_" + labels[index] + ".gif");
if (file.exists()) file.delete();
Expand Down Expand Up @@ -200,15 +212,16 @@ private BufferedImage makeImage(int index) {
* @param imageFile 画像データのファイル名
*/
private void loadFeatures(String imageFile) throws IOException {
DataInputStream is = new DataInputStream(new GZIPInputStream(new FileInputStream(BASE_PATH + imageFile)));
is.readInt();
numImages = is.readInt();
numDimensions = is.readInt() * is.readInt();
try (DataInputStream is = new DataInputStream(new GZIPInputStream(new FileInputStream(BASE_PATH + imageFile)));) {
is.readInt();
numImages = is.readInt();
numDimensions = is.readInt() * is.readInt();

features = new double[numImages][numDimensions];
for (int i = 0; i < numImages; i++) {
for (int j = 0; j < numDimensions; j++) {
features[i][j] = (double) is.readUnsignedByte() / 255.0;
features = new double[numImages][numDimensions];
for (int i = 0; i < numImages; i++) {
for (int j = 0; j < numDimensions; j++) {
features[i][j] = (double) is.readUnsignedByte() / 255.0;
}
}
}
}
Expand All @@ -219,14 +232,15 @@ private void loadFeatures(String imageFile) throws IOException {
* @param labelFile 正解ラベルデータのファイル名
*/
private void loadLabels(String labelFile) throws IOException {
DataInputStream is = new DataInputStream(new GZIPInputStream(new FileInputStream(BASE_PATH + labelFile)));
try (DataInputStream is = new DataInputStream(new GZIPInputStream(new FileInputStream(BASE_PATH + labelFile)));) {

is.readInt();
int numLabels = is.readInt();
is.readInt();
int numLabels = is.readInt();

labels = new int[numLabels];
for (int i = 0; i < numLabels; i++) {
labels[i] = is.readUnsignedByte();
labels = new int[numLabels];
for (int i = 0; i < numLabels; i++) {
labels[i] = is.readUnsignedByte();
}
}
}

Expand All @@ -242,21 +256,19 @@ private static void download(String baseUrl, String basePath, String fileName) t
System.out.println("Downloading " + baseUrl + fileName + " ...");
URL url = new URL(baseUrl + fileName);
URLConnection conn = url.openConnection();
InputStream in = conn.getInputStream();

File file = new File(basePath + fileName);
FileOutputStream out = new FileOutputStream(file, false);
byte[] data = new byte[1024];
while (true) {
int ret = in.read(data);
if (ret == -1) {
break;
}
try (InputStream in = conn.getInputStream();
FileOutputStream out = new FileOutputStream(file, false);) {
byte[] data = new byte[1024];
while (true) {
int ret = in.read(data);
if (ret == -1) {
break;
}

out.write(data, 0, ret);
out.write(data, 0, ret);
}
}
out.close();
in.close();
}
}
}

0 comments on commit d0a3796

Please sign in to comment.