Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-9273] [MLlib] Add Convolutional Neural network to Spark MLlib #7609

Closed
wants to merge 1 commit into from

Conversation

hhbyyh
Copy link
Contributor

@hhbyyh hhbyyh commented Jul 23, 2015

jira: https://issues.apache.org/jira/browse/SPARK-9273
Here's an initial implementation of Convolutional Neural network based on Spark. The basic design follows the existing optimization framework (train a batch, sum the gradient, update weights) and should have good scalability. Currently it supports SGD training.

A typical driver program looks like

val topology = new CNNTopology
    topology.addLayer(CNNLayer.buildInputLayer(new Scale(28, 28)))
    topology.addLayer(CNNLayer.buildConvLayer(6, new Scale(5, 5)))
    topology.addLayer(CNNLayer.buildSampLayer(new Scale(2, 2)))
    topology.addLayer(CNNLayer.buildConvLayer(12, new Scale(5, 5)))
    topology.addLayer(CNNLayer.buildSampLayer(new Scale(2, 2)))
    topology.addLayer(CNNLayer.buildOutputLayer(10))
val cnn: CNN = new CNN(topology).setMaxIterations(5000).setMiniBatchSize(16)
val lines = sc.textFile("dataset/train.format", 8)
val data: RDD[LabeledPoint] = ...
cnn.train(data)

I tried on Mnist and it can get error rate less than 0.03%.
I'd like to collect some opinions and suggestions first since this is a large jira and give users a prototype to try CNN on Spark. I checked the new design from #1290 and current CNN can be integrated with slight effort.

@srowen
Copy link
Member

srowen commented Jul 23, 2015

Can you just host this outside Spark? that's the default rather than add it in directly. There are also already some JIRA about neural networks.

@hhbyyh
Copy link
Contributor Author

hhbyyh commented Jul 23, 2015

@avulanov Can you please help take a look? Hope this can help exhibit the interface for CNN and hope it can integrate easily with your new design.

@SparkQA
Copy link

SparkQA commented Jul 23, 2015

Test build #38184 has finished for PR 7609 at commit 4ba2732.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
    • class CNNTopology
    • class Scale(var x: Int, var y: Int) extends Serializable
    • class InputCNNLayer extends CNNLayer

@hhbyyh
Copy link
Contributor Author

hhbyyh commented Jul 23, 2015

Close the PR as requested by @srowen.
For any one of interests, please leave your comments. Thanks. cc @mengxr @witgo

@hhbyyh hhbyyh closed this Jul 23, 2015
@mengxr
Copy link
Contributor

mengxr commented Jul 23, 2015

@hhbyyh Please ping us on the JIRA when you plan to work on some major feature. We watch all JIRAs and may know some on-going work. For 1.5, @avulanov will submit a PR for neural network but only expose interface for multilayer perceptron. We can add more layer types in 1.6. I will ping you when he sends out the PR. It would be great if you can help review. Thanks!

@avulanov
Copy link
Contributor

@hhbyyh Sure, I will take a look at your CNN code. Indeed, it would be great to merge it with the new interface.

It would be really great if you could help review the PR for neural network: #7621.

@hhbyyh
Copy link
Contributor Author

hhbyyh commented Jul 25, 2015

@mengxr Sorry for the surprise. The code was actually written based on a customer requirement and I just shared it hoping it will help.

@srowen
Copy link
Member

srowen commented Jul 25, 2015

@hhbyyh I don't think that's a good reason to propose code in a PR here. You should just host it yourself somewhere. Adding code to Spark means a long-term commitment to support it, so it can't be a gallery of code snippets.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants