Skip to content

Commit

Permalink
初期コミット。
Browse files Browse the repository at this point in the history
  • Loading branch information
takishim committed Feb 15, 2017
1 parent c697e42 commit 45a8b0f
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 24 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
- データファイルをダウンロードする
- 特徴量データを読み込む
- 正解ラベルデータを読み込む
- オブジェクトのシリアライズ、デシリアライズ
- 文字の概形をテキストで表示する
- 画像をダイアログに表示する
- 画像をファイルに保存する
Expand Down
33 changes: 25 additions & 8 deletions src/main/java/to/kishimo/minist/ImageDataSet.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package to.kishimo.minist;

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.*;
import java.util.zip.GZIPInputStream;

/**
* MNISTの手書き文字の画像データを扱うクラス.
*/
public class ImageDataSet {
public class ImageDataSet implements Serializable {
private static final long serialVersionUID = 1L;
private String fileName = "";
private int numImages;
private int numDimensions;
Expand All @@ -20,7 +18,7 @@ public class ImageDataSet {
*
* @param fileName 画像データのファイル名
*/
public ImageDataSet(String fileName) throws IOException {
private ImageDataSet(String fileName) throws IOException, ClassNotFoundException {
this.fileName = fileName;

File baseDir = new File(Const.BASE_PATH);
Expand All @@ -29,15 +27,35 @@ public ImageDataSet(String fileName) throws IOException {
}

Util.download(Const.BASE_URL, Const.BASE_PATH, fileName);
}

loadFeatures();
/**
* 画像データセットのインスタンスを作成する.
*
* @param fileName 画像データのファイル名
* @return 画像データセットのインスタンス
*/
public static ImageDataSet create(String fileName) throws IOException, ClassNotFoundException {
if (new File(Const.BASE_PATH + fileName + ".ser").exists()) {
System.out.println("Deserializing feature data from " + fileName + ".ser ...");
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(Const.BASE_PATH + fileName + ".ser"));
return (ImageDataSet) ois.readObject();
} else {
System.out.println("Loading feature data from " + fileName + " ...");
ImageDataSet imageDataSet = new ImageDataSet(fileName);
imageDataSet.loadFeatures();
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(Const.BASE_PATH + fileName + ".ser"));
oos.writeObject(imageDataSet);
return imageDataSet;
}
}

/**
* 画像数を取得する.
*
* @return 画像数
*/

public int getNumImages() {
return numImages;
}
Expand All @@ -64,7 +82,6 @@ public double[][] getFeatures() {
* 特徴量データを読み込む.
*/
private void loadFeatures() throws IOException {
System.out.println("Loading feature data from " + fileName + " ...");
DataInputStream is = new DataInputStream(new GZIPInputStream(new FileInputStream(Const.BASE_PATH + fileName)));
is.readInt();
numImages = is.readInt();
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/to/kishimo/minist/ImageViewer.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ public class ImageViewer {
* @param imageFile 画像ファイル名
* @param labelFile ラベルファイル名
*/
public ImageViewer(String imageFile, String labelFile) throws IOException {
ImageDataSet imageData = new ImageDataSet(imageFile);
public ImageViewer(String imageFile, String labelFile) throws IOException, ClassNotFoundException {
ImageDataSet imageData = ImageDataSet.create(imageFile);
images = imageData.getFeatures();

LabelDataSet labelData = new LabelDataSet(labelFile);
LabelDataSet labelData = LabelDataSet.create(labelFile);
labels = labelData.getLabels();
}

Expand Down
32 changes: 24 additions & 8 deletions src/main/java/to/kishimo/minist/LabelDataSet.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package to.kishimo.minist;

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.*;
import java.util.zip.GZIPInputStream;

/**
* MNISTの手書き文字の正解ラベルデータを扱うクラス.
*/
public class LabelDataSet {
public class LabelDataSet implements Serializable {
private static final long serialVersionUID = 1L;
private String fileName = "";
private int numLabels;
private int[] labels;
Expand All @@ -19,7 +17,7 @@ public class LabelDataSet {
*
* @param fileName 正解ラベルデータのファイル名
*/
public LabelDataSet(String fileName) throws IOException {
private LabelDataSet(String fileName) throws IOException {
this.fileName = fileName;

File baseDir = new File(Const.BASE_PATH);
Expand All @@ -28,8 +26,27 @@ public LabelDataSet(String fileName) throws IOException {
}

Util.download(Const.BASE_URL, Const.BASE_PATH, fileName);
}

loadLabels();
/**
* 正解ラベルデータセットのインスタンスを作成する.
*
* @param fileName 正解ラベルデータのファイル名
* @return 正解ラベルデータセットのインスタンス
*/
public static LabelDataSet create(String fileName) throws IOException, ClassNotFoundException {
if (new File(Const.BASE_PATH + fileName + ".ser").exists()) {
System.out.println("Deserializing label data from " + fileName + ".ser ...");
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(Const.BASE_PATH + fileName + ".ser"));
return (LabelDataSet) ois.readObject();
} else {
System.out.println("Loading label data from " + fileName + " ...");
LabelDataSet labelDataSet = new LabelDataSet(fileName);
labelDataSet.loadLabels();
ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(Const.BASE_PATH + fileName + ".ser"));
oos.writeObject(labelDataSet);
return labelDataSet;
}
}

/**
Expand All @@ -54,7 +71,6 @@ public int[] getLabels() {
* 正解ラベルデータを読み込む.
*/
private void loadLabels() throws IOException {
System.out.println("Loading label data from " + fileName + " ...");
DataInputStream is = new DataInputStream(new GZIPInputStream(new FileInputStream(Const.BASE_PATH + fileName)));

is.readInt();
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/to/kishimo/minist/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
* 使用例.
*/
public class Main {
public static void main(String... args) throws IOException {
public static void main(String... args) throws IOException, ClassNotFoundException {
// トレーニングデータセット
ImageDataSet trainImages = new ImageDataSet(Const.TRAIN_IMAGE_FILE);
LabelDataSet trainLabels = new LabelDataSet(Const.TRAIN_LABEL_FILE);
ImageDataSet trainImages = ImageDataSet.create(Const.TRAIN_IMAGE_FILE);
LabelDataSet trainLabels = LabelDataSet.create(Const.TRAIN_LABEL_FILE);
ImageViewer trainViewer = new ImageViewer(trainImages.getFeatures(), trainLabels.getLabels());
// 概形を表示する
trainViewer.showImageAsText(0);
Expand All @@ -19,8 +19,8 @@ public static void main(String... args) throws IOException {
trainViewer.saveImage(Const.BASE_PATH, "train", 2);

// テストデータセット
ImageDataSet testImages = new ImageDataSet(Const.TEST_IMAGE_FILE);
LabelDataSet testLabels = new LabelDataSet(Const.TEST_LABEL_FILE);
ImageDataSet testImages = ImageDataSet.create(Const.TEST_IMAGE_FILE);
LabelDataSet testLabels = LabelDataSet.create(Const.TEST_LABEL_FILE);
ImageViewer testViewer = new ImageViewer(testImages.getFeatures(), testLabels.getLabels());
// 概形を表示する
testViewer.showImageAsText(0);
Expand Down

0 comments on commit 45a8b0f

Please sign in to comment.