diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md
index 539cbc1b3163a..a72680d52a26c 100644
--- a/docs/mllib-dimensionality-reduction.md
+++ b/docs/mllib-dimensionality-reduction.md
@@ -76,13 +76,14 @@ Refer to the [`SingularValueDecomposition` Java docs](api/java/org/apache/spark/
The same code applies to `IndexedRowMatrix` if `U` is defined as an
`IndexedRowMatrix`.
+
+
+Refer to the [`SingularValueDecomposition` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.SingularValueDecomposition) for details on the API.
-In order to run the above application, follow the instructions
-provided in the [Self-Contained
-Applications](quick-start.html#self-contained-applications) section of the Spark
-quick-start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
+{% include_example python/mllib/svd_example.py %}
+The same code applies to `IndexedRowMatrix` if `U` is defined as an
+`IndexedRowMatrix`.
@@ -118,17 +119,21 @@ Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feat
The following code demonstrates how to compute principal components on a `RowMatrix`
and use them to project the vectors into a low-dimensional space.
-The number of columns should be small, e.g, less than 1000.
Refer to the [`RowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) for details on the API.
{% include_example java/org/apache/spark/examples/mllib/JavaPCAExample.java %}
-
-In order to run the above application, follow the instructions
-provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
-section of the Spark
-quick-start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
+
+
+The following code demonstrates how to compute principal components on a `RowMatrix`
+and use them to project the vectors into a low-dimensional space.
+
+Refer to the [`RowMatrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.RowMatrix) for details on the API.
+
+{% include_example python/mllib/pca_rowmatrix_example.py %}
+
+
+
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java
index 3077f557ef886..0a7dc621e1110 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java
@@ -18,7 +18,8 @@
package org.apache.spark.examples.mllib;
// $example on$
-import java.util.LinkedList;
+import java.util.Arrays;
+import java.util.List;
// $example off$
import org.apache.spark.SparkConf;
@@ -39,21 +40,25 @@ public class JavaPCAExample {
public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("PCA Example");
SparkContext sc = new SparkContext(conf);
+ JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc);
// $example on$
- double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}};
- LinkedList rowsList = new LinkedList<>();
- for (int i = 0; i < array.length; i++) {
- Vector currentRow = Vectors.dense(array[i]);
- rowsList.add(currentRow);
- }
- JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList);
+ List data = Arrays.asList(
+ Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}),
+ Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+ Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+ );
+
+ JavaRDD rows = jsc.parallelize(data);
// Create a RowMatrix from JavaRDD.
RowMatrix mat = new RowMatrix(rows.rdd());
- // Compute the top 3 principal components.
- Matrix pc = mat.computePrincipalComponents(3);
+ // Compute the top 4 principal components.
+ // Principal components are stored in a local dense matrix.
+ Matrix pc = mat.computePrincipalComponents(4);
+
+ // Project the rows to the linear space spanned by the top 4 principal components.
RowMatrix projected = mat.multiply(pc);
// $example off$
Vector[] collectPartitions = (Vector[])projected.rows().collect();
@@ -61,6 +66,6 @@ public static void main(String[] args) {
for (Vector vector : collectPartitions) {
System.out.println("\t" + vector);
}
- sc.stop();
+ jsc.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java
index 3730e60f68803..802be3960a337 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java
@@ -18,7 +18,8 @@
package org.apache.spark.examples.mllib;
// $example on$
-import java.util.LinkedList;
+import java.util.Arrays;
+import java.util.List;
// $example off$
import org.apache.spark.SparkConf;
@@ -43,22 +44,22 @@ public static void main(String[] args) {
JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc);
// $example on$
- double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}};
- LinkedList rowsList = new LinkedList<>();
- for (int i = 0; i < array.length; i++) {
- Vector currentRow = Vectors.dense(array[i]);
- rowsList.add(currentRow);
- }
- JavaRDD rows = jsc.parallelize(rowsList);
+ List data = Arrays.asList(
+ Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}),
+ Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+ Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+ );
+
+ JavaRDD rows = jsc.parallelize(data);
// Create a RowMatrix from JavaRDD.
RowMatrix mat = new RowMatrix(rows.rdd());
- // Compute the top 3 singular values and corresponding singular vectors.
- SingularValueDecomposition svd = mat.computeSVD(3, true, 1.0E-9d);
- RowMatrix U = svd.U();
- Vector s = svd.s();
- Matrix V = svd.V();
+ // Compute the top 5 singular values and corresponding singular vectors.
+ SingularValueDecomposition svd = mat.computeSVD(5, true, 1.0E-9d);
+ RowMatrix U = svd.U(); // The U factor is a RowMatrix.
+ Vector s = svd.s(); // The singular values are stored in a local dense vector.
+ Matrix V = svd.V(); // The V factor is a local dense matrix.
// $example off$
Vector[] collectPartitions = (Vector[]) U.rows().collect();
System.out.println("U factor is:");
diff --git a/examples/src/main/python/mllib/pca_rowmatrix_example.py b/examples/src/main/python/mllib/pca_rowmatrix_example.py
new file mode 100644
index 0000000000000..49b9b1bbe08e9
--- /dev/null
+++ b/examples/src/main/python/mllib/pca_rowmatrix_example.py
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import SparkContext
+# $example on$
+from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.linalg.distributed import RowMatrix
+# $example off$
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="PythonPCAOnRowMatrixExample")
+
+ # $example on$
+ rows = sc.parallelize([
+ Vectors.sparse(5, {1: 1.0, 3: 7.0}),
+ Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+ Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+ ])
+
+ mat = RowMatrix(rows)
+ # Compute the top 4 principal components.
+ # Principal components are stored in a local dense matrix.
+ pc = mat.computePrincipalComponents(4)
+
+ # Project the rows to the linear space spanned by the top 4 principal components.
+ projected = mat.multiply(pc)
+ # $example off$
+ collected = projected.rows.collect()
+ print("Projected Row Matrix of principal component:")
+ for vector in collected:
+ print(vector)
+ sc.stop()
diff --git a/examples/src/main/python/mllib/svd_example.py b/examples/src/main/python/mllib/svd_example.py
new file mode 100644
index 0000000000000..5b220fdb3fd67
--- /dev/null
+++ b/examples/src/main/python/mllib/svd_example.py
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import SparkContext
+# $example on$
+from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.linalg.distributed import RowMatrix
+# $example off$
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="PythonSVDExample")
+
+ # $example on$
+ rows = sc.parallelize([
+ Vectors.sparse(5, {1: 1.0, 3: 7.0}),
+ Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+ Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+ ])
+
+ mat = RowMatrix(rows)
+
+ # Compute the top 5 singular values and corresponding singular vectors.
+ svd = mat.computeSVD(5, computeU=True)
+ U = svd.U # The U factor is a RowMatrix.
+ s = svd.s # The singular values are stored in a local dense vector.
+ V = svd.V # The V factor is a local dense matrix.
+ # $example off$
+ collected = U.rows.collect()
+ print("U factor is:")
+ for vector in collected:
+ print(vector)
+ print("Singular values are: %s" % s)
+ print("V factor is:\n%s" % V)
+ sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala
index a137ba2a2f9d3..da43a8d9c7e80 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala
@@ -39,9 +39,9 @@ object PCAOnRowMatrixExample {
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
- val dataRDD = sc.parallelize(data, 2)
+ val rows = sc.parallelize(data)
- val mat: RowMatrix = new RowMatrix(dataRDD)
+ val mat: RowMatrix = new RowMatrix(rows)
// Compute the top 4 principal components.
// Principal components are stored in a local dense matrix.
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala
index b286a3f7b9096..769ae2a3a88b1 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala
@@ -28,6 +28,9 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix
// $example off$
+/**
+ * Example for SingularValueDecomposition.
+ */
object SVDExample {
def main(args: Array[String]): Unit = {
@@ -41,15 +44,15 @@ object SVDExample {
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
- val dataRDD = sc.parallelize(data, 2)
+ val rows = sc.parallelize(data)
- val mat: RowMatrix = new RowMatrix(dataRDD)
+ val mat: RowMatrix = new RowMatrix(rows)
// Compute the top 5 singular values and corresponding singular vectors.
val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(5, computeU = true)
val U: RowMatrix = svd.U // The U factor is a RowMatrix.
- val s: Vector = svd.s // The singular values are stored in a local dense vector.
- val V: Matrix = svd.V // The V factor is a local dense matrix.
+ val s: Vector = svd.s // The singular values are stored in a local dense vector.
+ val V: Matrix = svd.V // The V factor is a local dense matrix.
// $example off$
val collect = U.rows.collect()
println("U factor is:")
diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py
index 600655c912ca6..4cb802514be52 100644
--- a/python/pyspark/mllib/linalg/distributed.py
+++ b/python/pyspark/mllib/linalg/distributed.py
@@ -28,14 +28,13 @@
from pyspark import RDD, since
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import _convert_to_vector, Matrix, QRDecomposition
+from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition
from pyspark.mllib.stat import MultivariateStatisticalSummary
from pyspark.storagelevel import StorageLevel
-__all__ = ['DistributedMatrix', 'RowMatrix', 'IndexedRow',
- 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix',
- 'BlockMatrix']
+__all__ = ['BlockMatrix', 'CoordinateMatrix', 'DistributedMatrix', 'IndexedRow',
+ 'IndexedRowMatrix', 'MatrixEntry', 'RowMatrix', 'SingularValueDecomposition']
class DistributedMatrix(object):
@@ -301,6 +300,136 @@ def tallSkinnyQR(self, computeQ=False):
R = decomp.call("R")
return QRDecomposition(Q, R)
+ @since('2.2.0')
+ def computeSVD(self, k, computeU=False, rCond=1e-9):
+ """
+ Computes the singular value decomposition of the RowMatrix.
+
+ The given row matrix A of dimension (m X n) is decomposed into
+ U * s * V'T where
+
+ * U: (m X k) (left singular vectors) is a RowMatrix whose
+ columns are the eigenvectors of (A X A')
+ * s: DenseVector consisting of square root of the eigenvalues
+ (singular values) in descending order.
+ * v: (n X k) (right singular vectors) is a Matrix whose columns
+ are the eigenvectors of (A' X A)
+
+ For more specific details on implementation, please refer
+ the Scala documentation.
+
+ :param k: Number of leading singular values to keep (`0 < k <= n`).
+ It might return less than k if there are numerically zero singular values
+ or there are not enough Ritz values converged before the maximum number of
+ Arnoldi update iterations is reached (in case that matrix A is ill-conditioned).
+ :param computeU: Whether or not to compute U. If set to be
+ True, then U is computed by A * V * s^-1
+ :param rCond: Reciprocal condition number. All singular values
+ smaller than rCond * s[0] are treated as zero
+ where s[0] is the largest singular value.
+ :returns: :py:class:`SingularValueDecomposition`
+
+ >>> rows = sc.parallelize([[3, 1, 1], [-1, 3, 1]])
+ >>> rm = RowMatrix(rows)
+
+ >>> svd_model = rm.computeSVD(2, True)
+ >>> svd_model.U.rows.collect()
+ [DenseVector([-0.7071, 0.7071]), DenseVector([-0.7071, -0.7071])]
+ >>> svd_model.s
+ DenseVector([3.4641, 3.1623])
+ >>> svd_model.V
+ DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0)
+ """
+ j_model = self._java_matrix_wrapper.call(
+ "computeSVD", int(k), bool(computeU), float(rCond))
+ return SingularValueDecomposition(j_model)
+
+ @since('2.2.0')
+ def computePrincipalComponents(self, k):
+ """
+ Computes the k principal components of the given row matrix
+
+ .. note:: This cannot be computed on matrices with more than 65535 columns.
+
+ :param k: Number of principal components to keep.
+ :returns: :py:class:`pyspark.mllib.linalg.DenseMatrix`
+
+ >>> rows = sc.parallelize([[1, 2, 3], [2, 4, 5], [3, 6, 1]])
+ >>> rm = RowMatrix(rows)
+
+ >>> # Returns the two principal components of rm
+ >>> pca = rm.computePrincipalComponents(2)
+ >>> pca
+ DenseMatrix(3, 2, [-0.349, -0.6981, 0.6252, -0.2796, -0.5592, -0.7805], 0)
+
+ >>> # Transform into new dimensions with the greatest variance.
+ >>> rm.multiply(pca).rows.collect() # doctest: +NORMALIZE_WHITESPACE
+ [DenseVector([0.1305, -3.7394]), DenseVector([-0.3642, -6.6983]), \
+ DenseVector([-4.6102, -4.9745])]
+ """
+ return self._java_matrix_wrapper.call("computePrincipalComponents", k)
+
+ @since('2.2.0')
+ def multiply(self, matrix):
+ """
+ Multiply this matrix by a local dense matrix on the right.
+
+ :param matrix: a local dense matrix whose number of rows must match the number of columns
+ of this matrix
+ :returns: :py:class:`RowMatrix`
+
+ >>> rm = RowMatrix(sc.parallelize([[0, 1], [2, 3]]))
+ >>> rm.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect()
+ [DenseVector([2.0, 3.0]), DenseVector([6.0, 11.0])]
+ """
+ if not isinstance(matrix, DenseMatrix):
+ raise ValueError("Only multiplication with DenseMatrix "
+ "is supported.")
+ j_model = self._java_matrix_wrapper.call("multiply", matrix)
+ return RowMatrix(j_model)
+
+
+class SingularValueDecomposition(JavaModelWrapper):
+ """
+ Represents singular value decomposition (SVD) factors.
+
+ .. versionadded:: 2.2.0
+ """
+
+ @property
+ @since('2.2.0')
+ def U(self):
+ """
+ Returns a distributed matrix whose columns are the left
+ singular vectors of the SingularValueDecomposition if computeU was set to be True.
+ """
+ u = self.call("U")
+ if u is not None:
+ mat_name = u.getClass().getSimpleName()
+ if mat_name == "RowMatrix":
+ return RowMatrix(u)
+ elif mat_name == "IndexedRowMatrix":
+ return IndexedRowMatrix(u)
+ else:
+ raise TypeError("Expected RowMatrix/IndexedRowMatrix got %s" % mat_name)
+
+ @property
+ @since('2.2.0')
+ def s(self):
+ """
+ Returns a DenseVector with singular values in descending order.
+ """
+ return self.call("s")
+
+ @property
+ @since('2.2.0')
+ def V(self):
+ """
+ Returns a DenseMatrix whose columns are the right singular
+ vectors of the SingularValueDecomposition.
+ """
+ return self.call("V")
+
class IndexedRow(object):
"""
@@ -528,6 +657,68 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024):
colsPerBlock)
return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock)
+ @since('2.2.0')
+ def computeSVD(self, k, computeU=False, rCond=1e-9):
+ """
+ Computes the singular value decomposition of the IndexedRowMatrix.
+
+ The given row matrix A of dimension (m X n) is decomposed into
+ U * s * V'T where
+
+ * U: (m X k) (left singular vectors) is a IndexedRowMatrix
+ whose columns are the eigenvectors of (A X A')
+ * s: DenseVector consisting of square root of the eigenvalues
+ (singular values) in descending order.
+ * v: (n X k) (right singular vectors) is a Matrix whose columns
+ are the eigenvectors of (A' X A)
+
+ For more specific details on implementation, please refer
+ the scala documentation.
+
+ :param k: Number of leading singular values to keep (`0 < k <= n`).
+ It might return less than k if there are numerically zero singular values
+ or there are not enough Ritz values converged before the maximum number of
+ Arnoldi update iterations is reached (in case that matrix A is ill-conditioned).
+ :param computeU: Whether or not to compute U. If set to be
+ True, then U is computed by A * V * s^-1
+ :param rCond: Reciprocal condition number. All singular values
+ smaller than rCond * s[0] are treated as zero
+ where s[0] is the largest singular value.
+ :returns: SingularValueDecomposition object
+
+ >>> rows = [(0, (3, 1, 1)), (1, (-1, 3, 1))]
+ >>> irm = IndexedRowMatrix(sc.parallelize(rows))
+ >>> svd_model = irm.computeSVD(2, True)
+ >>> svd_model.U.rows.collect() # doctest: +NORMALIZE_WHITESPACE
+ [IndexedRow(0, [-0.707106781187,0.707106781187]),\
+ IndexedRow(1, [-0.707106781187,-0.707106781187])]
+ >>> svd_model.s
+ DenseVector([3.4641, 3.1623])
+ >>> svd_model.V
+ DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0)
+ """
+ j_model = self._java_matrix_wrapper.call(
+ "computeSVD", int(k), bool(computeU), float(rCond))
+ return SingularValueDecomposition(j_model)
+
+ @since('2.2.0')
+ def multiply(self, matrix):
+ """
+ Multiply this matrix by a local dense matrix on the right.
+
+ :param matrix: a local dense matrix whose number of rows must match the number of columns
+ of this matrix
+ :returns: :py:class:`IndexedRowMatrix`
+
+ >>> mat = IndexedRowMatrix(sc.parallelize([(0, (0, 1)), (1, (2, 3))]))
+ >>> mat.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect()
+ [IndexedRow(0, [2.0,3.0]), IndexedRow(1, [6.0,11.0])]
+ """
+ if not isinstance(matrix, DenseMatrix):
+ raise ValueError("Only multiplication with DenseMatrix "
+ "is supported.")
+ return IndexedRowMatrix(self._java_matrix_wrapper.call("multiply", matrix))
+
class MatrixEntry(object):
"""
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 523b3f1113317..1037bab7f1088 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -23,6 +23,7 @@
import sys
import tempfile
import array as pyarray
+from math import sqrt
from time import time, sleep
from shutil import rmtree
@@ -54,6 +55,7 @@
from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
+from pyspark.mllib.linalg.distributed import RowMatrix
from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
from pyspark.mllib.recommendation import Rating
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
@@ -1699,6 +1701,67 @@ def test_binary_term_freqs(self):
": expected " + str(expected[i]) + ", got " + str(output[i]))
+class DimensionalityReductionTests(MLlibTestCase):
+
+ denseData = [
+ Vectors.dense([0.0, 1.0, 2.0]),
+ Vectors.dense([3.0, 4.0, 5.0]),
+ Vectors.dense([6.0, 7.0, 8.0]),
+ Vectors.dense([9.0, 0.0, 1.0])
+ ]
+ sparseData = [
+ Vectors.sparse(3, [(1, 1.0), (2, 2.0)]),
+ Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]),
+ Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]),
+ Vectors.sparse(3, [(0, 9.0), (2, 1.0)])
+ ]
+
+ def assertEqualUpToSign(self, vecA, vecB):
+ eq1 = vecA - vecB
+ eq2 = vecA + vecB
+ self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6)
+
+ def test_svd(self):
+ denseMat = RowMatrix(self.sc.parallelize(self.denseData))
+ sparseMat = RowMatrix(self.sc.parallelize(self.sparseData))
+ m = 4
+ n = 3
+ for mat in [denseMat, sparseMat]:
+ for k in range(1, 4):
+ rm = mat.computeSVD(k, computeU=True)
+ self.assertEqual(rm.s.size, k)
+ self.assertEqual(rm.U.numRows(), m)
+ self.assertEqual(rm.U.numCols(), k)
+ self.assertEqual(rm.V.numRows, n)
+ self.assertEqual(rm.V.numCols, k)
+
+ # Test that U returned is None if computeU is set to False.
+ self.assertEqual(mat.computeSVD(1).U, None)
+
+ # Test that low rank matrices cannot have number of singular values
+ # greater than a limit.
+ rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1))))
+ self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1)
+
+ def test_pca(self):
+ expected_pcs = array([
+ [0.0, 1.0, 0.0],
+ [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0],
+ [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0]
+ ])
+ n = 3
+ denseMat = RowMatrix(self.sc.parallelize(self.denseData))
+ sparseMat = RowMatrix(self.sc.parallelize(self.sparseData))
+ for mat in [denseMat, sparseMat]:
+ for k in range(1, 4):
+ pcs = mat.computePrincipalComponents(k)
+ self.assertEqual(pcs.numRows, n)
+ self.assertEqual(pcs.numCols, k)
+
+ # We can just test the updated principal component for equality.
+ self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1])
+
+
if __name__ == "__main__":
from pyspark.mllib.tests import *
if not _have_scipy: