
Commit

fix: join chunks
viacheslav-dobrynin committed Apr 7, 2023
1 parent 596b2ed commit f82828b
Showing 9 changed files with 40 additions and 45 deletions.
4 changes: 2 additions & 2 deletions build.gradle.kts
@@ -67,12 +67,12 @@ tasks.withType<Test> {
}

jmh {
includes.set(listOf("BertHyperParameterBenchmark")) // include pattern (regular expression) for benchmarks to be executed
includes.set(listOf(".*")) // include pattern (regular expression) for benchmarks to be executed
warmupIterations.set(2) // Number of warmup iterations to do
iterations.set(2) // Number of measurement iterations to do
fork.set(2) // How many times to fork a single benchmark. Use 0 to disable forking altogether
zip64.set(true) // is used for big archives (more than 65535 entries)
resultsFile.set(project.file("${project.buildDir}/reports/jmh/results.txt")) // results file
resultsFile.set(project.file("${project.buildDir}/outputs/jmh/results.txt")) // results file
}

ktlint {
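For context, the jmh block above is the configuration DSL of a JMH Gradle plugin. A minimal, self-contained sketch of a build.gradle.kts that such a block assumes — the plugin id and version numbers here are guesses, not taken from this repository:

plugins {
    kotlin("jvm") version "1.8.10"            // assumed Kotlin version
    id("me.champeau.jmh") version "0.7.1"     // assumed JMH Gradle plugin and version
}

jmh {
    includes.set(listOf(".*"))                // regex: run every discovered benchmark
    warmupIterations.set(2)
    iterations.set(2)
    fork.set(2)                               // 0 disables forking
    zip64.set(true)                           // needed once the benchmark jar exceeds 65535 entries
    resultsFile.set(project.file("${project.buildDir}/reports/jmh/results.txt"))  // example output path
}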
@@ -50,7 +50,6 @@ open class BertHyperParameterBenchmark {

private val testContents = generateWindows()

/*
@Benchmark
fun singleThreadBenchmark_100(): Array<FloatArray> {
return singleThreadBenchmark(100)
@@ -75,14 +74,12 @@
fun singleThreadBenchmark_2000(): Array<FloatArray> {
return singleThreadBenchmark(2000)
}
*/

// @Benchmark
// fun singleThreadBenchmark_5000(): Array<FloatArray> {
// return singleThreadBenchmark(5000)
// }
@Benchmark
fun singleThreadBenchmark_5000(): Array<FloatArray> {
return singleThreadBenchmark(5000)
}

/*
@Benchmark
fun singleThreadBenchmark_10_000(): Array<FloatArray> {
return singleThreadBenchmark(10_000)
@@ -106,7 +103,7 @@
@Benchmark
fun singleThreadBenchmark_50_000(): Array<FloatArray> {
return singleThreadBenchmark(50_000)
}*/
}

@Benchmark
fun multithreadedBenchmark_4_5000(): Array<FloatArray> {
@@ -120,28 +117,30 @@
}.toTypedArray()
}

private fun multithreadedBenchmark(batchSize: Int, numThreads: Int): Array<FloatArray> = runBlocking(Dispatchers.Default) {
val counter = AtomicInteger(0)
val chan = Channel<List<String>>(numThreads)
repeat(numThreads) {
launch {
val predictor = tinyModel.newPredictor()
for (data in chan) {
counter.incrementAndGet()
predictor.predict(data.toTypedArray())
private fun multithreadedBenchmark(batchSize: Int, numThreads: Int): Array<FloatArray> =
runBlocking(Dispatchers.Default) {
val counter = AtomicInteger(0)
val chan = Channel<List<String>>(numThreads)
repeat(numThreads) {
launch {
val predictor = tinyModel.newPredictor()
for (data in chan) {
counter.incrementAndGet()
predictor.predict(data.toTypedArray())
}
predictor.close()
}
predictor.close()
}
}

for (data in testContents.chunked(batchSize)) {
chan.send(data)
}
while (!chan.isEmpty) {}
chan.close()
for (data in testContents.chunked(batchSize)) {
chan.send(data)
}
while (!chan.isEmpty) {
}
chan.close()

arrayOf()
}
arrayOf()
}

private fun generateWindows(count: Int = 50_000): List<String> {
val result = mutableListOf<String>()
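The reformatted multithreadedBenchmark above fans batches out over a Channel to numThreads worker coroutines, each reusing its own predictor. A self-contained sketch of the same pattern — FakePredictor stands in for the DJL predictor, and this sketch closes the channel and joins the workers instead of busy-waiting on chan.isEmpty as the committed code does:

import java.util.concurrent.atomic.AtomicInteger
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.joinAll
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

// Stand-in for the DJL predictor used in the benchmark.
class FakePredictor {
    fun predict(batch: Array<String>): FloatArray = FloatArray(batch.size)
    fun close() {}
}

fun runBatches(contents: List<String>, batchSize: Int, numThreads: Int): Int =
    runBlocking(Dispatchers.Default) {
        val processed = AtomicInteger(0)
        val chan = Channel<List<String>>(capacity = numThreads)
        val workers = List(numThreads) {
            launch {
                val predictor = FakePredictor()   // one predictor per worker, reused across batches
                for (batch in chan) {             // suspends until a batch arrives or the channel closes
                    predictor.predict(batch.toTypedArray())
                    processed.incrementAndGet()
                }
                predictor.close()
            }
        }
        for (batch in contents.chunked(batchSize)) {
            chan.send(batch)                      // suspends while the buffer is full (backpressure)
        }
        chan.close()                              // workers drain the remaining batches, then exit the loop
        workers.joinAll()
        processed.get()
    }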
@@ -45,7 +45,7 @@ class SaveInBatchCommand(

override fun run() {
val contents = Paths.get(contentFile.path)
.bufferedReader(bufferSize = standProperties.app.fileLoadBufferSizeKb * 1024)
.bufferedReader()
.lineSequence()

val seconds = measureTimeSeconds {
1 change: 0 additions & 1 deletion src/main/kotlin/ru/itmo/stand/config/StandProperties.kt
@@ -16,7 +16,6 @@ data class StandProperties @ConstructorBinding constructor(
val basePath: String,
val bertMultiToken: BertMultiToken,
val neighboursAlgorithm: NeighboursAlgorithm,
val fileLoadBufferSizeKb: Int,
)

data class BertMultiToken(
@@ -9,6 +9,7 @@ class BertEmbeddingCalculator(
private val standProperties: StandProperties,
) {

// TODO: configure to return vector for middle token
private val predictor by lazy {
bertModelLoader.loadModel(standProperties.app.neighboursAlgorithm.bertModelType).newPredictor()
}
@@ -23,9 +23,9 @@ class VectorIndexBuilder(
private val log = LoggerFactory.getLogger(javaClass)

fun index(windowedTokensFile: File) {
log.info("Starting vector indexing")
val windowsByTokenPairs = readWindowsByTokenPairs(windowedTokensFile)

log.info("starting vector indexing")
val counter = AtomicInteger(0)
val clusterSizes = AtomicInteger(0)
val windowsCount = AtomicInteger(0)
@@ -37,11 +37,11 @@ class VectorIndexBuilder(
counter.incrementAndGet()
}

log.info("token count: ${counter.get()}")
log.info("cluster sizes: ${clusterSizes.get()}")
log.info("windows count: ${windowsCount.get()}")
log.info("mean windows per token: ${windowsCount.get().toDouble() / counter.get().toDouble()}")
log.info("mean cluster size is ${clusterSizes.get() / counter.get().toFloat()}")
log.info("Token count: ${counter.get()}")
log.info("Cluster sizes: ${clusterSizes.get()}")
log.info("Windows count: ${windowsCount.get()}")
log.info("Mean windows per token: ${windowsCount.get().toDouble() / counter.get().toDouble()}")
log.info("Mean cluster size is ${clusterSizes.get() / counter.get().toFloat()}")
}

private fun readWindowsByTokenPairs(windowedTokensFile: File) = windowedTokensFile
@@ -53,7 +53,7 @@ class VectorIndexBuilder(
val windows = tokenAndWindows[1]
.split(WINDOWS_SEPARATOR)
.filter { it.isNotBlank() }
.take(1000)
.take(1000) // TODO: configure this value
token to windows.map { it.split(WINDOW_DOC_IDS_SEPARATOR).first() }
}

@@ -64,9 +64,9 @@

val doubleEmb = embeddings.toDoubleArray()

val clusterModel = XMeans.fit(doubleEmb, 8)
val clusterModel = XMeans.fit(doubleEmb, 8) // TODO: configure this value

log.info("{} got centroids {}", token.first, clusterModel.k)
log.info("{} got {} centroids", token.first, clusterModel.k)

val centroids = clusterModel.centroids

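The index method above clusters each token's window embeddings with X-Means, which chooses the number of clusters on its own up to a given maximum (the hard-coded 8 that the new TODO wants made configurable). A small sketch with toy data, assuming the XMeans in this hunk is Smile's smile.clustering.XMeans:

import smile.clustering.XMeans

fun main() {
    // Toy "embeddings": 100 points in 4 dimensions with an arbitrary repeating structure.
    val embeddings: Array<DoubleArray> = Array(100) { i ->
        DoubleArray(4) { d -> (i % 5).toDouble() + d * 0.1 }
    }

    // kmax = 8 mirrors the hard-coded value above; X-Means may settle on fewer clusters.
    val clusterModel = XMeans.fit(embeddings, 8)

    println("chose k = ${clusterModel.k}")
    val centroids: Array<DoubleArray> = clusterModel.centroids
    println("first centroid: ${centroids.first().joinToString()}")
}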
5 changes: 2 additions & 3 deletions src/main/kotlin/ru/itmo/stand/util/Concurrency.kt
@@ -10,12 +10,11 @@ fun <T> processParallel(data: Sequence<T>, numWorkers: Int, log: Logger, action:
data
.onEachIndexed { index, _ -> if (index % 10 == 0) log.info("Elements processed: {}", index) }
.chunked(numWorkers)
.mapIndexed { index, chunk ->
.forEach { chunk ->
chunk.map {
launch {
action(it)
}
}
}.joinAll()
}
.forEach { it.joinAll() }
}
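This hunk is what the commit message refers to: instead of mapping chunks to lists of jobs and joining them in a separate forEach pass, each chunk's jobs are now launched and joined before the next chunk is taken, so at most numWorkers actions run at once. A sketch of how the whole function reads after the change, assuming the runBlocking(Dispatchers.Default) wrapper that the hunk does not show:

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.joinAll
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.slf4j.Logger
import org.slf4j.LoggerFactory

fun <T> processParallel(data: Sequence<T>, numWorkers: Int, log: Logger, action: (T) -> Unit) =
    runBlocking(Dispatchers.Default) {
        data
            .onEachIndexed { index, _ -> if (index % 10 == 0) log.info("Elements processed: {}", index) }
            .chunked(numWorkers)
            .forEach { chunk ->
                chunk.map { launch { action(it) } } // up to numWorkers jobs for this chunk
                    .joinAll()                      // wait for the whole chunk before taking the next one
            }
    }

fun main() {
    val log: Logger = LoggerFactory.getLogger("demo")
    processParallel(data = (1..25).asSequence(), numWorkers = 4, log = log) { Thread.sleep(50) }
}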
1 change: 0 additions & 1 deletion src/main/resources/application.yml
@@ -16,7 +16,6 @@ stand:
elasticsearch.host-and-port: localhost:9200
app:
base-path: "."
file-load-buffer-size-kb: 512
neighbours-algorithm:
token-batch-size: 5
bert-model-type: TINY
@@ -12,13 +12,11 @@ fun standProperties(
basePath: String = ".",
bertMultiTokenBatchSize: Int = 5,
neighboursAlgorithmBatchSize: Int = 5,
fileLoadBufferSizeMb: Int = 512,
) = StandProperties(
ElasticsearchProperties(elkHostAndPort),
ApplicationProperties(
basePath,
BertMultiToken(bertMultiTokenBatchSize),
NeighboursAlgorithm(neighboursAlgorithmBatchSize, BertModelType.BASE, 500_000),
fileLoadBufferSizeMb,
),
)
