Skip to content

Commit

Permalink
[orx-kdtree] add k-nearest neighbor search to kd-tree (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
ylegall authored Oct 10, 2021
1 parent a3bb1b2 commit b8ee4d8
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 1 deletion.
35 changes: 35 additions & 0 deletions orx-kdtree/src/demo/kotlin/DemoKNearestNeighbour01.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import org.openrndr.application
import org.openrndr.color.ColorRGBa
import org.openrndr.extra.kdtree.buildKDTree
import org.openrndr.extra.kdtree.findKNearest
import org.openrndr.extra.kdtree.vector2Mapper
import org.openrndr.math.Vector2
import org.openrndr.shape.LineSegment

fun main() {
application {

configure {
width = 1080
height = 720
}

program {
val points = MutableList(1000) {
Vector2(Math.random() * width, Math.random() * height)
}
val tree = buildKDTree(points, 2, ::vector2Mapper)

extend {
drawer.circles(points, 5.0)

val kNearest = findKNearest(tree, mouse.position, k=7, dimensions = 2, ::vector2Mapper)
drawer.fill = ColorRGBa.RED
drawer.stroke = ColorRGBa.RED
drawer.strokeWeight = 2.0
drawer.circles(kNearest, 7.0)
drawer.lineSegments(kNearest.map { LineSegment(mouse.position, it) })
}
}
}
}
49 changes: 48 additions & 1 deletion orx-kdtree/src/main/kotlin/KDTree.kt
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ private fun <T> sqrDistance(left: T, right: T, dimensions: Int, mapper: (T, Int)
fun <T> findAllNodes(root: KDTreeNode<T>): List<KDTreeNode<T>> {
val stack = Stack<KDTreeNode<T>>()
val all = ArrayList<KDTreeNode<T>>()
stack.empty()
stack.push(root)
while (!stack.isEmpty()) {
val node = stack.pop()
Expand All @@ -184,6 +183,54 @@ fun <T> findAllNodes(root: KDTreeNode<T>): List<KDTreeNode<T>> {
}


fun <T> findKNearest(
root: KDTreeNode<T>,
item: T,
k: Int,
dimensions: Int,
mapper: (T, Int) -> Double
): List<T> {
// max-heap with size k
val queue = PriorityQueue<Pair<KDTreeNode<T>, Double>>(k + 1) {
nodeA, nodeB -> compareValues(nodeB.second, nodeA.second)
}

fun nearest(node: KDTreeNode<T>?, item: T) {
if (node != null) {
val dimensionValue = mapper(item, node.dimension)
val route: Int = if (dimensionValue < node.median) {
nearest(node.children[0], item)
0
} else {
nearest(node.children[1], item)
1
}

val distance = sqrDistance(item, node.item
?: throw IllegalStateException("item is null"), dimensions, mapper)

if (queue.size < k || distance < queue.peek().second) {
queue.add(Pair(node, distance))
if (queue.size > k) {
queue.poll()
}
}

val d = abs(node.median - dimensionValue)
if (d * d < queue.peek().second || queue.size < k) {
nearest(node.children[1 - route], item)
}
}
}

nearest(root, item)

return generateSequence { queue.poll() }
.map { it.first.item }
.filterNotNull()
.toList().reversed()
}

fun <T> findNearest(root: KDTreeNode<T>, item: T, dimensions: Int, mapper: (T, Int) -> Double): T? {
var nearest = java.lang.Double.POSITIVE_INFINITY
var nearestArg: KDTreeNode<T>? = null
Expand Down

0 comments on commit b8ee4d8

Please sign in to comment.