Skip to content

Commit

Permalink
nit
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengruifeng committed May 7, 2020
1 parent 0eb0f07 commit 31f8907
Showing 1 changed file with 25 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ private class BlockExpectationAggregator(
@transient private lazy val newMeansMat = DenseMatrix.zeros(numFeatures, k)
@transient private lazy val newCovsMat = DenseMatrix.zeros(covSize, k)
@transient private lazy val auxiliaryProbMat = DenseMatrix.zeros(blockSize, k)
@transient private lazy val auxiliaryMat = DenseMatrix.zeros(blockSize, numFeatures)
@transient private lazy val auxiliaryPDFMat = DenseMatrix.zeros(blockSize, numFeatures)
@transient private lazy val auxiliaryCovVec = Vectors.zeros(covSize).toDense

@transient private lazy val gaussians = {
Expand Down Expand Up @@ -852,20 +852,36 @@ private class BlockExpectationAggregator(
val size = matrix.numRows
require(weights.length == size)

val blas1 = BLAS.getBLAS(size)
val blas2 = BLAS.getBLAS(k)

val probMat = if (blockSize == size) auxiliaryProbMat else DenseMatrix.zeros(size, k)
require(!probMat.isTransposed)
java.util.Arrays.fill(probMat.values, EPSILON)

val mat = if (blockSize == size) auxiliaryMat else DenseMatrix.zeros(size, numFeatures)
val pdfMat = if (blockSize == size) auxiliaryPDFMat else DenseMatrix.zeros(size, numFeatures)
val probSumVec = Vectors.zeros(size).toDense
var j = 0
val blas1 = BLAS.getBLAS(size)
while (j < k) {
val pdfVec = gaussians(j).pdf(matrix, mat)
blas1.daxpy(size, bcWeights.value(j), pdfVec.values, 0, 1,
probMat.values, j * size, 1)
val pdfVec = gaussians(j).pdf(matrix, pdfMat)
val w = bcWeights.value(j)
blas1.daxpy(size, w, pdfVec.values, 1, probSumVec.values, 1)
blas1.daxpy(size, w, pdfVec.values, 0, 1, probMat.values, j * size, 1)
j += 1
}

var i = 0
while (i < size) {
val probSum = probSumVec(i)
val weight = weights(i)
blas2.dscal(k, weight / probSum, probMat.values, i, size)
blas2.daxpy(k, 1.0, probMat.values, i, size, newWeights, 0, 1)
newLogLikelihood += math.log(probSum) * weight
i += 1
}

BLAS.gemm(1.0, matrix.transpose, probMat, 1.0, newMeansMat)

// compute the cov vector for each row vector
val covVec = auxiliaryCovVec
val covVecIter = matrix match {
Expand All @@ -886,18 +902,11 @@ private class BlockExpectationAggregator(
}
}

val blas2 = BLAS.getBLAS(k)
covVecIter.zip(weights.iterator).zipWithIndex.foreach {
case ((covVec, weight), i) =>
val probSum = blas2.dasum(k, probMat.values, i, size)
blas2.dscal(k, weight / probSum, probMat.values, i, size)
blas2.daxpy(k, 1.0, probMat.values, i, size, newWeights, 0, 1)
BLAS.nativeBLAS.dger(covSize, k, 1.0, covVec.values, 0, 1,
probMat.values, i, size, newCovsMat.values, 0, covSize)
newLogLikelihood += math.log(probSum) * weight
covVecIter.zipWithIndex.foreach { case (covVec, i) =>
BLAS.nativeBLAS.dger(covSize, k, 1.0, covVec.values, 0, 1,
probMat.values, i, size, newCovsMat.values, 0, covSize)
}

BLAS.gemm(1.0, matrix.transpose, probMat, 1.0, newMeansMat)
totalCnt += size

this
Expand Down

0 comments on commit 31f8907

Please sign in to comment.