The k-d Tree: Splitting Space to Search It in O(log n)

May 26, 2026 · 30 min read
dsa · algorithms · interview-prep · data-structures
TL;DR
  • k-d trees split k-dimensional space with axis-aligned hyperplanes, giving O(log n) average nearest neighbor search without checking every point
  • Build time is O(n log n): find the median at each level, cycle through axes, recurse on each half — linear median-finding keeps each level O(n)
  • Nearest neighbor backtracking descends to the likely leaf, then checks sibling subtrees only when the splitting hyperplane falls inside the current best-distance hypersphere
  • Range queries run in O(√n + m) in 2D via the recurrence Q(n) = 2Q(n/4) + O(1), generalizing to O(n^(1-1/k) + m) in k dimensions
  • Deletion is non-trivial: replace the deleted node with the minimum along its splitting axis from the right subtree; if the node has only a left subtree, move that subtree to the right first
  • The curse of dimensionality makes k-d trees degrade as k grows: past roughly 15 to 20 dimensions, switch to HNSW or another approximate nearest neighbor method

Your map app knows which coffee shops are within 500 meters of you. A game engine knows which objects are close enough to collide. A machine learning library classifies a new point by finding its five nearest labeled neighbors. The common thread: partition space recursively so you never have to check every point. Google Maps is not looping over every Starbucks on Earth. That would be O(n) per query, and the VP of Engineering would have a lot of questions.

The k-d tree is the canonical structure for this. Built in 1975 by Jon Louis Bentley for his CACM paper "Multidimensional Binary Search Trees Used for Associative Searching," it is a binary tree where every node is a k-dimensional point and every internal node partitions space along one coordinate axis. It finds nearest neighbors in O(log n) average time and reports all points inside a rectangular region in O(n^(1-1/k) + m) time, where m is the output count.

Reach for a k-d tree when your data lives in low-dimensional coordinate space (roughly fewer than 20 dimensions) and you need repeated nearest neighbor or range queries against a set of points that changes rarely.

Drake meme: rejecting O(nlogn) in favor of O(n!) because it saves one line of code

This is fine for a homework assignment. It is not fine when your location service is handling 50 million queries per second.


What the Tree Is Doing to Space

Imagine a 2D point set scattered across a plane. The k-d tree splits that plane with an axis-aligned hyperplane, puts all points on one side into a left subtree, and all points on the other into a right subtree. Then it recurses, cycling through axes level by level.

Each internal node is both a point in the dataset and a splitting decision. The node's coordinate along the chosen axis is the boundary. Points with a smaller value go left, larger values go right, equal values typically go right.

Level 0 (split on X):           (7, 2)
                               /      \
Level 1 (split on Y):      (5, 4)    (9, 6)
                           /    \     /
Level 2 (split on X):  (2, 3) (4, 7) (8, 1)

The plane gets recursively carved into rectangles, and each leaf corresponds to one cell of that partition. When you search for the nearest neighbor of a query point, you descend to the cell the query point would live in, then backtrack up, checking sibling cells only when they might contain a closer point.

k-d tree space partitioning: 6 points divided by axis-aligned splits at x=7, y=4, y=6 into colored rectangular cells

The same 6-point tree from the diagram above. Blue line = level-0 X split. Amber lines = level-1 Y splits. Each colored region is a cell the search algorithm can prune entirely.

The Memory Layout

A node in a k-d tree stores three things:

  1. The k-dimensional point (an array of coordinates).
  2. The splitting dimension (which axis this node divides on).
  3. Left and right child pointers (or indices into a flat array).

Node {
  point:    [x₀, x₁, ..., x_{k-1}]   // k floats
  axis:     int                         // splitting dimension
  left:     *Node
  right:    *Node
}

k-d tree node layout showing the four fields and pointer arrows to left and right child nodes

Each node stores the point, which axis it splits on, and two child pointers. The flat-array variant puts left at 2i+1 and right at 2i+2, which the hardware prefetcher actually appreciates.

For a pointer-based implementation, each node is heap-allocated. A flat-array layout is more cache-friendly: a traversal from root to leaf stays inside one array rather than chasing scattered pointers. In the implicit layout (the same one a binary heap uses), the left child of node i lives at 2i+1 and the right at 2i+2, which works best when the tree is balanced so the array has few empty slots. Ray tracers such as pbrt use compact 8-byte nodes where the below-split child is always stored at nodeNum+1, so only the other child's index is stored explicitly. Eight such nodes fit in a single 64-byte cache line.
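As a rough illustration of the implicit 2i+1 / 2i+2 layout, here is a minimal sketch that stores a median-split tree in a plain Python list. The names `build_flat` and `place` are made up for this example, and the splitting axis is not stored at all: it is implied by the depth of the slot.

```python
from typing import Optional

Point = tuple[float, ...]

def build_flat(points: list[Point]) -> list[Optional[Point]]:
    """Store a median-split k-d tree in heap order: children of slot i are 2i+1 and 2i+2."""
    k = len(points[0]) if points else 0
    slots: list[Optional[Point]] = []

    def place(subset: list[Point], index: int, depth: int) -> None:
        if not subset:
            return
        if index >= len(slots):
            # Grow the array; slots that never get a point stay None.
            slots.extend([None] * (index + 1 - len(slots)))
        ordered = sorted(subset, key=lambda p: p[depth % k])
        mid = len(ordered) // 2
        slots[index] = ordered[mid]
        place(ordered[:mid], 2 * index + 1, depth + 1)
        place(ordered[mid + 1:], 2 * index + 2, depth + 1)

    place(points, 0, 0)
    return slots

flat = build_flat([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
print(flat)
```

For other point counts the array can contain `None` gaps, which is the price of not storing child indices.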

Splitting Strategies

Three common strategies for choosing where to split:

Cycle through axes. Level d splits on dimension d % k. Simple and predictable. Works well when data is roughly uniformly distributed.

Maximum spread. Pick the dimension with the highest variance at each node. Produces better-balanced trees on skewed data. O(kn) extra work per level during construction.

Sliding midpoint. Split at the spatial midpoint of the bounding box for the current cell, but if all points are on one side, slide the split to the nearest point. This is what scipy uses. It never leaves a cell empty, and it holds up well in practice across diverse data distributions, per Maneewongvatana and Mount's 1999 analysis.
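The first two strategies differ only in how the axis is chosen at each node. A small sketch, using hypothetical helper names `cycle_axis` and `max_spread_axis` and the standard library's population variance as the spread measure:

```python
from statistics import pvariance

Point = tuple[float, ...]

def cycle_axis(depth: int, k: int) -> int:
    """Cycling strategy: level d splits on dimension d % k."""
    return depth % k

def max_spread_axis(points: list[Point]) -> int:
    """Maximum-spread strategy: pick the axis whose coordinates vary the most."""
    k = len(points[0])
    return max(range(k), key=lambda axis: pvariance([p[axis] for p in points]))

points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
print(cycle_axis(depth=3, k=2), max_spread_axis(points))
```

Computing the variance touches every coordinate of every point in the cell, which is where the O(kn) per-level cost comes from.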


The Core Operations

Operation               | Average case        | Worst case          | Space
Build (from n points)   | O(n log n)          | O(n log n)          | O(n)
Insert                  | O(log n)            | O(n)                | O(1)
Delete                  | O(log n)            | O(n)                | O(log n) stack
Exact point search      | O(log n)            | O(n)                | O(log n) stack
k-NN search             | O(log n)            | O(n)                | O(log n) stack
Range query (d-dim)     | O(n^(1-1/d) + m)    | O(n^(1-1/d) + m)    | O(log n) stack

Build: Why O(n log n)

To build a balanced k-d tree, find the median of the points along the chosen axis and make it the root. Recurse on the left and right halves. The total work is O(n log n) because you do O(n) median-finding work at each of O(log n) levels of the recursion.

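Worst-case O(n) median finding requires the median-of-medians algorithm (BFPRT); quickselect reaches the same bound in expectation with far less code. A simpler alternative presorts the points once along each of the k axes and keeps those orders intact while partitioning, so each median is an O(1) lookup and the whole build costs O(kn log n). Re-sorting the points at every node is simpler still, but costs O(n log² n).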

Insert: Why O(log n) Average

Insert follows the same descent as a BST search. At each node, compare the new point's coordinate along that node's splitting axis, go left or right, and continue until you hit a null child. Insert there. On a balanced tree, depth is O(log n), so insertion is O(log n). A tree built purely by sequential insertion is never rebalanced, so its depth can reach O(n) in adversarial cases such as points arriving in sorted order.
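A minimal sketch of that descent, assuming a small `Node` dataclass defined just for this example and the "ties go right" convention described earlier:

```python
from dataclasses import dataclass
from typing import Optional

Point = tuple[float, ...]

@dataclass
class Node:
    point: Point
    axis: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None

def insert(root: Optional[Node], point: Point, depth: int = 0) -> Node:
    """Insert point and return the root of the (possibly new) subtree."""
    if root is None:
        return Node(point, depth % len(point))
    if point[root.axis] < root.point[root.axis]:
        root.left = insert(root.left, point, depth + 1)
    else:  # equal values go right
        root.right = insert(root.right, point, depth + 1)
    return root

root = None
for p in [(7, 2), (5, 4), (9, 6), (2, 3), (4, 7), (8, 1)]:
    root = insert(root, p)
```

Inserting in this order reproduces the example tree from earlier in the post; inserting the same points sorted by x would instead produce a much deeper tree.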

Delete: Why Nobody Talks About Delete

You cannot simply swap in a child the way you do in a BST. The splitting dimension alternates. If node N splits on the X axis, its children split on Y. If you were to replace N with its right child (a Y-splitter), the invariant breaks for N's entire left subtree. Every tutorial mentions insertion in one paragraph. Then they get to deletion and suddenly remember they have somewhere else to be.

The correct algorithm:

  1. If the node is a leaf, remove it.
  2. If the node has a right subtree: find the node in the right subtree with the minimum value along the current splitting axis. Replace the target node's point with that minimum. Recursively delete the minimum from the right subtree.
  3. If the node has only a left subtree: move the left subtree to become the right subtree, then proceed as in step 2.

Step 3 looks odd because it is odd. Equal values go right by convention. If you extracted the maximum from the left subtree, any equal-valued point remaining there would violate the ordering invariant. Moving left to right makes the entire old left subtree eligible for the find-minimum procedure.

The find-minimum operation is the expensive part. At a node that splits on the target axis, the minimum cannot be in the right subtree, so only the left side needs searching; at every other node, both subtrees must be searched. That makes find-minimum O(n) in the worst case, and deletion with it, though deletion runs in O(log n) on average on balanced trees with well-distributed data.
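The three steps above can be sketched as follows. The `Node` dataclass, `find_min`, `delete`, and `sample_tree` are names invented for this example, and the code assumes the strict "smaller goes left, equal goes right" invariant.

```python
from dataclasses import dataclass
from typing import Optional

Point = tuple[float, ...]

@dataclass
class Node:
    point: Point
    axis: int
    left: Optional["Node"] = None
    right: Optional["Node"] = None

def find_min(node: Optional[Node], axis: int) -> Optional[Node]:
    """Return the node with the smallest coordinate along axis in this subtree."""
    if node is None:
        return None
    candidates = [node, find_min(node.left, axis)]
    if node.axis != axis:  # the right side can only be skipped when this node splits on axis
        candidates.append(find_min(node.right, axis))
    return min((c for c in candidates if c is not None), key=lambda c: c.point[axis])

def delete(node: Optional[Node], point: Point) -> Optional[Node]:
    """Delete one occurrence of point and return the new subtree root."""
    if node is None:
        return None
    if node.point == point:
        if node.left is None and node.right is None:
            return None                              # step 1: leaf
        if node.right is None:
            node.right, node.left = node.left, None  # step 3: move left to right
        successor = find_min(node.right, node.axis)  # step 2: minimum on this axis
        node.point = successor.point
        node.right = delete(node.right, successor.point)
    elif point[node.axis] < node.point[node.axis]:
        node.left = delete(node.left, point)
    else:
        node.right = delete(node.right, point)
    return node

def sample_tree() -> Node:
    """The six-point example tree, built by hand."""
    return Node((7, 2), 0,
                Node((5, 4), 1, Node((2, 3), 0), Node((4, 7), 0)),
                Node((9, 6), 1, Node((8, 1), 0)))

root = delete(sample_tree(), (7, 2))
print(root.point)
```

Deleting the root pulls up (8, 1), the point with the smallest x in the right subtree, and leaves the left subtree untouched.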

A common production alternative: mark deleted nodes as tombstones and periodically rebuild. This is the data-structure equivalent of clearing your inbox by labeling everything "archive." It keeps insertion and deletion at O(log n) amortized, at the cost of occasional O(n log n) rebuilds that you schedule for 3 AM and hope nobody notices.

Nearest Neighbor: The Backtracking Algorithm

  1. Traverse from root to a leaf, following the same comparisons as an insert. This is the "likely" cell for the query point.
  2. Record the leaf as the current best candidate. Compute the distance.
  3. Unwind the recursion. At each ancestor node, check: does the hypersphere of radius current_best_distance centered on the query point intersect the splitting hyperplane?
  4. If yes: the other side of the split could contain a closer point. Descend into the sibling subtree and recurse.
  5. If no: prune. Nothing on the other side can beat the current best.

distance_to_hyperplane = |query[axis] - node.point[axis]|

if distance_to_hyperplane < current_best_distance:
    search sibling subtree
else:
    prune

The average case is O(log n) because pruning eliminates most of the tree. The worst case is O(n): imagine all points on a circle and the query at the center. The hypersphere intersects every split plane, so no branches are pruned. The structure degrades gracefully into a very expensive linear scan. Congratulations, you've reinvented the for loop with extra steps.

k-d tree nearest neighbor search: query point at (9,2), traversal path shown in green, left half of tree pruned because the hypersphere does not cross x=7

Query (9,2): the algorithm descends right to (9,6), then to (8,1) which becomes the best at distance sqrt(2) ≈ 1.41. On the way back up, the distance to the x=7 split plane is 2, which exceeds 1.41. The entire left subtree is pruned without a single comparison.

The k-NN variant keeps a max-heap of size k tracking the k closest candidates so far (this k is the neighbor count, not the number of dimensions). Once the heap is full, the pruning condition becomes: prune if distance_to_hyperplane >= max(k_best_distances).
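Here is one way the heap-based variant might look. Python's `heapq` is a min-heap, so this sketch stores negated distances to get max-heap behavior; the tuple-based node shape and the name `k_nearest` are assumptions of the example.

```python
import heapq
import math

Point = tuple[float, ...]

def build(points: list[Point], depth: int = 0):
    """Median-split build returning (point, axis, left, right) tuples."""
    if not points:
        return None
    axis = depth % len(points[0])
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2
    return (ordered[mid], axis,
            build(ordered[:mid], depth + 1),
            build(ordered[mid + 1:], depth + 1))

def k_nearest(root, target: Point, k: int) -> list[tuple[float, Point]]:
    """Return up to k (distance, point) pairs, closest first."""
    heap: list[tuple[float, Point]] = []  # entries are (-distance, point)

    def visit(node) -> None:
        if node is None:
            return
        point, axis, left, right = node
        dist = math.dist(point, target)
        if len(heap) < k:
            heapq.heappush(heap, (-dist, point))
        elif dist < -heap[0][0]:          # closer than the current k-th best
            heapq.heapreplace(heap, (-dist, point))
        diff = target[axis] - point[axis]
        near, far = (left, right) if diff < 0 else (right, left)
        visit(near)
        # Cross the plane only if the heap is not full yet, or the plane is
        # closer than the current k-th best distance.
        if len(heap) < k or abs(diff) < -heap[0][0]:
            visit(far)

    visit(root)
    return sorted((-neg, point) for neg, point in heap)

tree = build([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
result = k_nearest(tree, (9, 2), k=2)
print([point for _, point in result])
```

Until the heap holds k candidates there is no meaningful "worst of the best" distance, so the far side is always searched.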

Range Query: Why O(sqrt(n)) in 2D

The 2D range search runs in O(sqrt(n) + m) time, and the proof is a recurrence.

Let Q(n) be the number of cells, in a subtree with n points, that a single edge of the query rectangle passes through. These are the nodes that matter: cells entirely inside the rectangle are charged to the m term, and cells entirely outside are pruned. Take a vertical edge and descend two levels (splitting first on X, then on Y). The X split leaves the edge on one side only, and the Y split divides that side into two cells, both of which the edge may cross. So of the four grandchild cells, each holding about n/4 points, the edge crosses at most two. The recurrence is:

Q(n) = 2·Q(n/4) + O(1)

Solving: the recursion has depth h = log₄(n), and level i contributes 2^i subproblems of O(1) work each. The bottom level dominates with 2^h = 2^(log₄ n) = n^(1/2) = sqrt(n) terms, so Q(n) = O(sqrt(n)). A rectangle has four edges, so the query visits at most 4·Q(n) = O(sqrt(n)) boundary cells.

The m term is the output: you add O(1) per reported point.
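For readers who prefer to see the 2D recurrence unrolled, here is a sketch of the expansion. The constant c stands in for the O(1) term and h for the recursion depth; both symbols are introduced only for this derivation.

```latex
\begin{aligned}
Q(n) &= 2\,Q(n/4) + c \\
     &= 4\,Q(n/16) + 2c + c \\
     &= 2^{h}\,Q(n/4^{h}) + c\sum_{i=0}^{h-1} 2^{i} \\
     &= 2^{h}\,Q(1) + c\,(2^{h} - 1) \qquad \text{with } h = \log_4 n \\
     &= O\!\left(2^{\log_4 n}\right) = O\!\left(n^{\log_4 2}\right) = O\!\left(\sqrt{n}\right)
\end{aligned}
```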

In d dimensions, after descending d levels (one per axis), the subtree size drops by a factor of 2^d, while an axis-aligned boundary face can cross up to 2^(d-1) of the 2^d resulting cells. The general recurrence is Q(n) = 2^(d-1)·Q(n/2^d) + O(1), which solves to O(n^(1-1/d)) by the Master Theorem, so a query costs O(n^(1-1/d) + m).


The Curse of Dimensionality

k-d trees degrade in high dimensions, and it happens fast. The rule of thumb: the structure is useful when n >> 2^k, where k is the number of dimensions.

Chart showing minimum required points growing exponentially as dimensions increase from k=1 to k=20, where at k=20 you need more than a million points before the tree beats a linear scan

At k=3, you need a few dozen points before the k-d tree pays off. At k=20, you need over a million. Most embedding models run at 512 to 1536 dimensions. Do the math, and then switch to HNSW.

At k = 20, that means n >> 2^20 ≈ 1,000,000 points before the tree beats a linear scan. The underlying reason is distance concentration: in high dimensions, the nearest neighbor is often only slightly closer than the farthest point, so the backtracking hypersphere has to grow large before it stops intersecting split planes. Almost nothing gets pruned. You've built a very elaborate way to visit every node in the tree.
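The "nearest is barely closer than farthest" effect is easy to observe. This is a small experiment sketch, assuming uniformly random points in the unit cube and a hypothetical helper name `nearest_to_farthest_ratio`; real datasets will behave differently in detail.

```python
import math
import random

def nearest_to_farthest_ratio(dim: int, n: int = 500, seed: int = 0) -> float:
    """Ratio of the nearest to the farthest distance from one random query point."""
    rng = random.Random(seed)
    query = [rng.random() for _ in range(dim)]
    dists = [math.dist(query, [rng.random() for _ in range(dim)]) for _ in range(n)]
    return min(dists) / max(dists)

for dim in (2, 20, 200):
    print(dim, round(nearest_to_farthest_ratio(dim), 3))
```

A ratio near 0 means the nearest neighbor stands out and pruning works; a ratio creeping toward 1 means the search sphere crosses most split planes.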

For d > 20, approximate nearest neighbor structures win: FLANN, HNSW (hierarchical navigable small world graphs), or randomized k-d forests all keep query time sublinear at the cost of returning approximate rather than exact nearest neighbors. Ball trees partition space into hyperspheres rather than axis-aligned slabs, making them more effective when data has low intrinsic dimension but high ambient dimension.


One Structure, Every Language

The implementation below builds a k-d tree on 2D points and supports nearest neighbor search. Points are stored as simple coordinate pairs and the tree uses a pointer/reference style.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional
import math


@dataclass
class KDNode:
    point: list[float]
    axis: int
    left: Optional["KDNode"] = field(default=None, repr=False)
    right: Optional["KDNode"] = field(default=None, repr=False)


def build(points: list[list[float]], depth: int = 0) -> Optional[KDNode]:
    if not points:
        return None
    k = len(points[0])
    axis = depth % k  # cycle through the axes level by level
    # Sort a copy so the caller's list is left untouched.
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2
    return KDNode(
        point=points[mid],
        axis=axis,
        left=build(points[:mid], depth + 1),
        right=build(points[mid + 1:], depth + 1),
    )


def nearest_neighbor(
    root: Optional[KDNode],
    target: list[float],
    best: list[float] | None = None,
    best_dist: float = float("inf"),
) -> tuple[list[float] | None, float]:
    if root is None:
        return best, best_dist
    dist = math.dist(root.point, target)
    if dist < best_dist:
        best, best_dist = root.point, dist
    axis = root.axis
    diff = target[axis] - root.point[axis]
    # Descend first into the side of the split that contains the target.
    near, far = (root.left, root.right) if diff <= 0 else (root.right, root.left)
    best, best_dist = nearest_neighbor(near, target, best, best_dist)
    # Cross the plane only if it is closer than the best distance so far.
    if abs(diff) < best_dist:
        best, best_dist = nearest_neighbor(far, target, best, best_dist)
    return best, best_dist


# Example
points = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
tree = build(points)
neighbor, distance = nearest_neighbor(tree, [9, 2])
print(neighbor, distance)  # [8, 1] 1.4142...
```

What Problems a k-d Tree Solves

Nearest neighbor search. Given a set of n points and a query point, find the closest. The k-d tree was built for this. Applications: k-NN classifiers, point cloud registration (ICP), feature matching in computer vision, recommendation systems using embedding distance.

k-Nearest neighbors. Find the k closest points. Same algorithm, keep a max-heap of size k.

Range queries. Find all points inside an axis-aligned bounding box or hypersphere. Applications: spatial databases, collision detection broad phase in physics engines, map bounding-box queries.

Fixed-radius search. Find all points within distance r of a query. A special case of the range query, with a sphere in place of the box.

Approximate nearest neighbor. With early termination (limit the number of leaf nodes searched), you get an approximate result much faster. FLANN builds randomized k-d forests for this.


Recognizing the Pattern

The Signals

You probably want a k-d tree when you see:

  • Points in low-dimensional coordinate space (2D, 3D, or up to roughly 10-15 dimensions).
  • Repeated "find nearest" or "find all within distance r" queries against a static or slowly changing point set.
  • The number of points n is much larger than 2^k (otherwise linear scan beats the tree).
  • The query points are arbitrary, not pre-known.

Worked Example 1: Closest Pair of Points

Problem. Given n 2D points, find the pair with the minimum distance between them.

Naive: O(n²) pairwise comparison.

Why k-d tree fits. Build the tree in O(n log n). For each point, query its nearest neighbor (excluding itself) in O(log n) on average. Total: O(n log n) expected. The signals: proximity in 2D space, a static point set, repeated closest-point queries.

(The optimal divide-and-conquer algorithm for closest pair is O(n log n) without a k-d tree, but the k-d tree makes the implementation straightforward and handles the k-NN variant naturally.)
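A sketch of the k-d tree approach to closest pair. Each point is stored with its index so a query can skip itself; the names `nearest_other` and `closest_pair` and the tuple node shape are assumptions of this example.

```python
import math

Point = tuple[float, ...]

def build(items: list[tuple[int, Point]], depth: int = 0):
    """items are (index, point) pairs; nodes are (index, point, axis, left, right)."""
    if not items:
        return None
    axis = depth % len(items[0][1])
    ordered = sorted(items, key=lambda item: item[1][axis])
    mid = len(ordered) // 2
    index, point = ordered[mid]
    return (index, point, axis,
            build(ordered[:mid], depth + 1),
            build(ordered[mid + 1:], depth + 1))

def nearest_other(node, skip: int, target: Point,
                  best: tuple[float, int] = (math.inf, -1)) -> tuple[float, int]:
    """Return (distance, index) of the closest stored point whose index is not skip."""
    if node is None:
        return best
    index, point, axis, left, right = node
    if index != skip:
        best = min(best, (math.dist(point, target), index))
    diff = target[axis] - point[axis]
    near, far = (left, right) if diff < 0 else (right, left)
    best = nearest_other(near, skip, target, best)
    if abs(diff) < best[0]:  # the far side could still hold something closer
        best = nearest_other(far, skip, target, best)
    return best

def closest_pair(points: list[Point]) -> tuple[float, int, int]:
    """Return (distance, i, j) with i < j for the closest pair of points."""
    tree = build(list(enumerate(points)))
    best = (math.inf, -1, -1)
    for i, p in enumerate(points):
        d, j = nearest_other(tree, i, p)
        best = min(best, (d, min(i, j), max(i, j)))
    return best

points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
distance, i, j = closest_pair(points)
print(points[i], points[j], distance)
```

Skipping by index rather than by coordinates means duplicate points are still reported as a pair at distance zero.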

Worked Example 2: KNN Classifier

Problem. You have 100,000 labeled 3D embeddings. Given a new query embedding, classify it by the majority label among its 5 nearest labeled neighbors.

Naive: O(n) per query, scanning all 100,000 embeddings.

Why k-d tree fits. Build in O(n log n). Each query runs in O(log n) average. For 1,000 queries, that is a one-time O(n log n) build plus roughly 1,000 · O(log n) of query work, versus 1,000 · O(n) for repeated scans. The signals: fixed set of points (the training set), repeated queries against it, low-dimensional embeddings, exact neighbors needed. When the embedding dimension exceeds roughly 15, switch to HNSW or FAISS. At 3D the k-d tree is the right choice.


Quick Recap

  • A k-d tree is a binary tree where each node is a k-dimensional point and each internal node defines an axis-aligned split.
  • Build by finding the median at each level, cycling through axes. O(n log n) with linear-time median finding.
  • Nearest neighbor uses backtracking: descend to the likely cell, then check siblings only when the splitting hyperplane is closer than the current best.
  • Range search is O(sqrt(n) + m) in 2D via the recurrence Q(n) = 2Q(n/4) + O(1). General: O(n^(1-1/d) + m).
  • Deletion is tricky: move left subtrees right, then replace with the right subtree's minimum along the current splitting axis.
  • The structure degrades in high dimensions. If k > 15, use approximate methods.
  • Splitting strategies: cycling is simple, max-spread handles skew, sliding midpoint is best in practice.


For more on the structures that k-d trees build on, the posts on the binary search tree invariant and how hash tables actually achieve O(1) are worth reading alongside this one.
