Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Primary documents #87

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package ru.itmo.stand.service.impl.neighbours

import org.springframework.stereotype.Service
import ru.itmo.stand.service.preprocessing.StopWordRemover
import ru.itmo.stand.service.preprocessing.TextCleaner
import ru.itmo.stand.service.preprocessing.Tokenizer

@Service
class TokensPipelineExecutor(
private val stopWordRemover: StopWordRemover,
private val textCleaner: TextCleaner,
private val tokenizer: Tokenizer,
) {

fun execute(content: String): List<String> {
val cleanedContent = textCleaner.preprocess(content)
val tokens = tokenizer.preprocess(cleanedContent)
return stopWordRemover.preprocess(tokens)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import ru.itmo.stand.service.preprocessing.Tokenizer
import ru.itmo.stand.util.Window

@Service
class PreprocessingPipelineExecutor(
class WindowsPipelineExecutor(
private val standProperties: StandProperties,
private val contextSplitter: ContextSplitter,
private val stopWordRemover: StopWordRemover,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class InvertedIndexBuilder(
documentEmbeddingRepository.findByDocId(docId).embedding
}
NeighboursDocument(
token = contextualizedEmbedding.tokenWithEmbeddingId.split(ContextualizedEmbedding.TOKEN_AND_EMBEDDING_ID_SEPARATOR).first(),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

выглядит кривовато, что мы одно поле парсим и в тот же объект складываем результат.
как proof of concept норм
но думаю, если решим мержить, то можно что-нибудь придумать с этим

tokenWithEmbeddingId = contextualizedEmbedding.tokenWithEmbeddingId,
docId = docId,
score = documentEmbedding.dot(contextualizedEmbedding.embedding),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ package ru.itmo.stand.service.impl.neighbours.indexing
import io.github.oshai.KotlinLogging
import org.springframework.stereotype.Service
import ru.itmo.stand.config.StandProperties
import ru.itmo.stand.service.impl.neighbours.PreprocessingPipelineExecutor
import ru.itmo.stand.service.impl.neighbours.WindowsPipelineExecutor
import ru.itmo.stand.service.model.Document
import ru.itmo.stand.util.Window
import ru.itmo.stand.util.createPath
import java.io.File

@Service
class WindowedTokenCreator(
private val preprocessingPipelineExecutor: PreprocessingPipelineExecutor,
private val windowsPipelineExecutor: WindowsPipelineExecutor,
private val standProperties: StandProperties,
) {

Expand Down Expand Up @@ -78,7 +78,7 @@ class WindowedTokenCreator(
}

fun create(document: Document): List<Window> {
return preprocessingPipelineExecutor.execute(document.content)
return windowsPipelineExecutor.execute(document.content)
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,35 @@ package ru.itmo.stand.service.impl.neighbours.search

import org.springframework.stereotype.Service
import ru.itmo.stand.service.bert.BertEmbeddingCalculator
import ru.itmo.stand.service.impl.neighbours.PreprocessingPipelineExecutor
import ru.itmo.stand.service.impl.neighbours.TokensPipelineExecutor
import ru.itmo.stand.service.impl.neighbours.WindowsPipelineExecutor
import ru.itmo.stand.storage.embedding.ContextualizedEmbeddingRepository
import ru.itmo.stand.storage.lucene.repository.neighbours.InvertedIndex

@Service
class NeighboursSearcher(
private val contextualizedEmbeddingRepository: ContextualizedEmbeddingRepository,
private val preprocessingPipelineExecutor: PreprocessingPipelineExecutor,
private val windowsPipelineExecutor: WindowsPipelineExecutor,
private val bertEmbeddingCalculator: BertEmbeddingCalculator,
private val invertedIndex: InvertedIndex,
private val tokensPipelineExecutor: TokensPipelineExecutor,
) {

fun search(query: String): List<String> {
val windows = preprocessingPipelineExecutor.execute(query)
val tokens = tokensPipelineExecutor.execute(query)

val primaryDocuments = invertedIndex.findByTokens(tokens)

val windows = windowsPipelineExecutor.execute(query)
val embeddings = bertEmbeddingCalculator.calculate(windows.map { it.toTranslatorInput() }.toTypedArray())

return embeddings.flatMap { embedding -> contextualizedEmbeddingRepository.findByVector(embedding.toTypedArray()) }
.let { contextualizedEmbeddings ->
val tokenWithEmbeddingIds = contextualizedEmbeddings.map { it.tokenWithEmbeddingId }
invertedIndex.findByTokenWithEmbeddingIds(tokenWithEmbeddingIds).groupingBy { it.docId }
val secondaryDocuments = invertedIndex.findByTokenWithEmbeddingIds(tokenWithEmbeddingIds)

sequenceOf(primaryDocuments, secondaryDocuments).flatten()
.groupingBy { it.docId }
.foldTo(HashMap(), 0f) { acc, doc -> acc + doc.score }
}.entries
.sortedByDescending { (_, score) -> score }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ru.itmo.stand.storage.lucene.model.neighbours

data class NeighboursDocument(
val token: String,
val tokenWithEmbeddingId: String,
val docId: String,
val score: Float,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.apache.lucene.document.StringField
import org.apache.lucene.index.ConcurrentMergeScheduler
import org.apache.lucene.index.IndexWriterConfig
import org.apache.lucene.index.Term
import org.apache.lucene.search.BooleanQuery
import org.apache.lucene.search.TermQuery
import org.springframework.stereotype.Repository
import ru.itmo.stand.config.StandProperties
Expand All @@ -29,6 +30,7 @@ class InvertedIndex(private val standProperties: StandProperties) : LuceneReposi

fun save(entity: NeighboursDocument) {
val document = Document()
document.add(StringField(NeighboursDocument::token.name, entity.token, YES))
document.add(StringField(NeighboursDocument::tokenWithEmbeddingId.name, entity.tokenWithEmbeddingId, YES))
document.add(StringField(NeighboursDocument::docId.name, entity.docId, YES))
document.add(StringField(NeighboursDocument::score.name, entity.score.toString(), YES))
Expand All @@ -39,14 +41,27 @@ class InvertedIndex(private val standProperties: StandProperties) : LuceneReposi
entities.forEach { save(it) }
}

fun findByTokens(tokens: Collection<String>): Sequence<NeighboursDocument> {
Copy link
Collaborator

@viacheslav-dobrynin viacheslav-dobrynin May 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

тут еще можно буст квери использовать, как вот в этой функции Артема:
image
наверное зря мы туда засунули analyze, можно обобщить функцию BoW query

val query = booleanQuery(tokens) { token ->
TermQuery(Term(NeighboursDocument::token.name, token))
}

return search(query)
}

fun findByTokenWithEmbeddingIds(tokenWithEmbeddingIds: Collection<String>): Sequence<NeighboursDocument> {
val query = booleanQuery(tokenWithEmbeddingIds) { tokenWithEmbeddingId ->
TermQuery(Term(NeighboursDocument::tokenWithEmbeddingId.name, tokenWithEmbeddingId))
}

return search(query)
}

private fun search(query: BooleanQuery): Sequence<NeighboursDocument> {
return searcher.searchAll(query)
.map {
NeighboursDocument(
it.get(NeighboursDocument::token.name),
it.get(NeighboursDocument::tokenWithEmbeddingId.name),
it.get(NeighboursDocument::docId.name),
it.get(NeighboursDocument::score.name).toFloat(),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package ru.itmo.stand.fixtures

import ru.itmo.stand.service.impl.neighbours.PreprocessingPipelineExecutor
import ru.itmo.stand.service.impl.neighbours.WindowsPipelineExecutor
import ru.itmo.stand.service.preprocessing.ContextSplitter
import ru.itmo.stand.service.preprocessing.StopWordRemover

fun preprocessingPipelineExecutor(): PreprocessingPipelineExecutor = PreprocessingPipelineExecutor(
fun preprocessingPipelineExecutor(): WindowsPipelineExecutor = WindowsPipelineExecutor(
standProperties(),
ContextSplitter(),
StopWordRemover(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.junit.jupiter.api.Test
import ru.itmo.stand.fixtures.preprocessingPipelineExecutor
import ru.itmo.stand.util.Window

class PreprocessingPipelineExecutorTest {
class WindowsPipelineExecutorTest {

private val preprocessingPipelineExecutor = preprocessingPipelineExecutor()

Expand Down