Skip to content

Commit

Permalink
Moving ANN to ML package. GradientDescent constructor is now spark pr…
Browse files Browse the repository at this point in the history
…ivate.
  • Loading branch information
avulanov committed Jul 29, 2015
1 parent 43b0ae2 commit 374bea6
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.mllib.ann
package org.apache.spark.ml.ann

import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.mllib.ann
package org.apache.spark.ml.ann

import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => brzAxpy,
sum => Bsum}
Expand Down Expand Up @@ -741,12 +741,12 @@ private[ann] class ANNUpdater extends Updater {
}

/**
* Llib-style trainer class that trains a network given the data and topology
* MLlib-style trainer class that trains a network given the data and topology
* @param topology topology of ANN
* @param inputSize input size
* @param outputSize output size
*/
class FeedForwardTrainer (topology: Topology, val inputSize: Int,
private[ml] class FeedForwardTrainer (topology: Topology, val inputSize: Int,
val outputSize: Int) extends Serializable {

// TODO: what if we need to pass random seed?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@

package org.apache.spark.ml.classification

import breeze.linalg.{argmax => Bargmax}

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed}
import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.ann.{FeedForwardTrainer, FeedForwardTopology}
import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology}
import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.DataFrame
Expand Down Expand Up @@ -88,7 +86,6 @@ with HasSeed with HasMaxIter with HasTol {
setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 1)
}


/** Label to vector converter. */
private object LabelConverter {
// TODO: Use OneHotEncoder instead
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector}
* @param gradient Gradient function to be used.
* @param updater Updater to be used to update weights after every iteration.
*/
class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater)
class GradientDescent private[spark] (private var gradient: Gradient, private var updater: Updater)
extends Optimizer with Logging {

private var stepSize: Double = 1.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.mllib.ann
package org.apache.spark.ml.ann

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
Expand Down

0 comments on commit 374bea6

Please sign in to comment.