Biboroku

Searching for Nearest-Neighbors between Two Coordinate Catalogs

Written by Taro Sato, on . Tagged: stats astro

Say I have two catalogs of points, each in two-dimensional space. For each object in one catalog, I want to find the nearest object(s) in the other catalog. I could compute the distance between every pair of objects across the two catalogs, keep the pairs within a search radius, and possibly sort them by distance, but the cost of that grows as the product of the two catalog sizes. A k-d tree is much more efficient, and SciPy provides classes that make this fairly easy.
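For comparison, the brute-force approach can be sketched with NumPy broadcasting. This is a minimal sketch and not the code used in this post. The helper names `brute_force_pairs` and `kdtree_pairs` are hypothetical, and the sketch assumes a SciPy version in which `cKDTree` provides `query_ball_tree` (recent releases do). Both helpers return the set of (i, j) index pairs with separation at most r, so each can be checked against the other.

```python
import numpy as np
from scipy.spatial import cKDTree


def brute_force_pairs(d1, d2, r):
    """Hypothetical helper: naive approach that tests every cross-catalog pair."""
    # Broadcasting builds an (N1, N2) array of squared separations, so
    # memory and time both grow as N1 * N2.
    sq = ((d1[:, None, :] - d2[None, :, :]) ** 2).sum(axis=-1)
    i, j = np.nonzero(sq <= r * r)
    return set(zip(i.tolist(), j.tolist()))


def kdtree_pairs(d1, d2, r):
    """Hypothetical helper: the same pairs found with one k-d tree per catalog."""
    neighbors = cKDTree(d1).query_ball_tree(cKDTree(d2), r)
    return {(i, j) for i, js in enumerate(neighbors) for j in js}


rng = np.random.default_rng(0)
a = rng.random((400, 2))
b = rng.random((200, 2))
print(brute_force_pairs(a, b, 0.05) == kdtree_pairs(a, b, 0.05))
```

At the catalog sizes used below (4000 and 2000 points), the brute-force separation array already holds 8 million entries. This is why the tree-based approach pays off as catalogs grow.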

Here is a Python code example:

#!/usr/bin/env python
from cProfile import Profile
import numpy as np
from numpy.random import random
from scipy.spatial import KDTree, cKDTree


def getnnidx(d1, d2, r):
    t1 = KDTree(d1)
    t2 = KDTree(d2)
    idx = t1.query_ball_tree(t2, r)  # one list of d2 indices per d1 point
    return idx


def cgetnnidx(d1, d2, r, k=5):
    t = cKDTree(d2)
    d, idx = t.query(d1, k=k, eps=0, p=2, distance_upper_bound=r)  # empty slots hold len(d2)
    return idx


def main():
    # number of points
    np1 = 4000
    np2 = 2000

    # search radius
    r = 0.01

    # prepare coordinates; the constructor of KDTree (and cKDTree)
    # expects an array of shape (N, 2) in the form:
    #
    #   data = [[x0, y0], [x1, y1], ... , [xN, yN]]
    #
    d1 = random((np1, 2))
    d2 = random((np2, 2))

    # profile the pure-Python KDTree and the compiled cKDTree versions
    p = Profile()

    result = p.runcall(getnnidx, d1.copy(), d2.copy(), r)
    p.print_stats()

    p.clear()

    result = p.runcall(cgetnnidx, d1.copy(), d2.copy(), r)
    p.print_stats()


if __name__ == '__main__':
    main()

When I run this script, I get the following output:

       706932 function calls (692470 primitive calls) in 1.663 CPU seconds

 Ordered by: standard name

 ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   1720    0.001    0.000    0.004    0.000 fromnumeric.py:1067(nonzero)
      2    0.000    0.000    0.000    0.000 fromnumeric.py:1147(shape)
  42293    0.046    0.000    0.224    0.000 fromnumeric.py:1314(sum)
      2    0.000    0.000    0.000    0.000 fromnumeric.py:1708(amax)
      2    0.000    0.000    0.000    0.000 fromnumeric.py:1769(amin)
    860    0.001    0.000    0.002    0.000 fromnumeric.py:664(argmax)
  14462    0.008    0.000    0.025    0.000 function_base.py:776(copy)
      2    0.000    0.000    0.045    0.022 kdtree.py:113(__init__)
    862    0.001    0.000    0.001    0.000 kdtree.py:139(__init__)
    860    0.001    0.000    0.001    0.000 kdtree.py:143(__init__)
 1722/2    0.033    0.000    0.045    0.022 kdtree.py:150(__build)
  42293    0.183    0.000    0.992    0.000 kdtree.py:22(minkowski_distance)
  12744    0.085    0.000    0.104    0.000 kdtree.py:36(__init__)
      1    0.002    0.002    1.618    1.618 kdtree.py:437(query_ball_tree)
12743/1    0.270    0.000    1.616    1.616 kdtree.py:462(traverse_checking)
   6371    0.031    0.000    0.157    0.000 kdtree.py:49(split)
  42293    0.372    0.000    0.682    0.000 kdtree.py:7(minkowski_distance_p)
  12743    0.133    0.000    0.430    0.000 kdtree.py:72(min_distance_rectangle)
   7385    0.052    0.000    0.219    0.000 kdtree.py:76(max_distance_rectangle)
 169174    0.095    0.000    0.213    0.000 numeric.py:201(asarray)
      1    0.000    0.000    1.663    1.663 xmatch.py:8(getnnidx)
  57063    0.036    0.000    0.036    0.000 {isinstance}
   5164    0.001    0.000    0.001    0.000 {len}
    860    0.001    0.000    0.001    0.000 {method 'argmax' of 'numpy.ndarray' objects}
  25488    0.019    0.000    0.019    0.000 {method 'astype' of 'numpy.ndarray' objects}
      1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
      2    0.000    0.000    0.000    0.000 {method 'max' of 'numpy.ndarray' objects}
      2    0.000    0.000    0.000    0.000 {method 'min' of 'numpy.ndarray' objects}
   1720    0.003    0.000    0.003    0.000 {method 'nonzero' of 'numpy.ndarray' objects}
  42293    0.146    0.000    0.146    0.000 {method 'sum' of 'numpy.ndarray' objects}
  22165    0.008    0.000    0.008    0.000 {method 'tolist' of 'numpy.ndarray' objects}
      2    0.000    0.000    0.000    0.000 {numpy.core.multiarray.arange}
 183636    0.135    0.000    0.135    0.000 {numpy.core.multiarray.array}
      1    0.000    0.000    0.000    0.000 {range}


       32 function calls in 0.008 CPU seconds

 Ordered by: standard name

 ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      3    0.000    0.000    0.000    0.000 fromnumeric.py:107(reshape)
      3    0.000    0.000    0.000    0.000 fromnumeric.py:1147(shape)
      1    0.000    0.000    0.000    0.000 fromnumeric.py:1708(amax)
      1    0.000    0.000    0.000    0.000 fromnumeric.py:1769(amin)
      1    0.000    0.000    0.000    0.000 fromnumeric.py:1865(prod)
      1    0.000    0.000    0.000    0.000 fromnumeric.py:32(_wrapit)
      2    0.000    0.000    0.000    0.000 numeric.py:201(asarray)
      4    0.000    0.000    0.000    0.000 numeric.py:314(ascontiguousarray)
      1    0.000    0.000    0.008    0.008 xmatch.py:15(cgetnnidx)
      1    0.000    0.000    0.000    0.000 {getattr}
      1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
      1    0.000    0.000    0.000    0.000 {method 'max' of 'numpy.ndarray' objects}
      1    0.000    0.000    0.000    0.000 {method 'min' of 'numpy.ndarray' objects}
      1    0.000    0.000    0.000    0.000 {method 'prod' of 'numpy.ndarray' objects}
      1    0.007    0.007    0.007    0.007 {method 'query' of 'scipy.spatial.ckdtree.cKDTree' objects}
      3    0.000    0.000    0.000    0.000 {method 'reshape' of 'numpy.ndarray' objects}
      6    0.000    0.000    0.000    0.000 {numpy.core.multiarray.array}

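From this output, cKDTree is clearly much faster (1.663 vs. 0.008 CPU seconds!). That is not surprising, since cKDTree is a compiled implementation of essentially the same algorithm, whereas KDTree is pure Python. There are two caveats. First, cProfile adds overhead to every Python function call, which inflates the KDTree number (about 700,000 calls against 32). Second, the two functions do slightly different work: two trees and a full radius search, versus one tree and at most k neighbors per object.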
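To compare the two approaches without cProfile's per-call overhead, you could time them directly with `time.perf_counter`. This is a sketch under some assumptions. The helper names (`time_call`, `ball_tree`, `knn_query`) are hypothetical, the sizes mirror the script above, and the numbers will depend on the machine and the SciPy version. In recent SciPy releases, `KDTree` is implemented on top of the compiled `cKDTree`, so the gap seen above should shrink considerably there.

```python
import time

import numpy as np
from scipy.spatial import KDTree, cKDTree


def time_call(func, *args, repeat=3):
    """Best wall-clock time in seconds over `repeat` runs, measured without cProfile."""
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        func(*args)
        best = min(best, time.perf_counter() - start)
    return best


def ball_tree(d1, d2, r):
    # Mirrors getnnidx: one tree per catalog, then all matches within r.
    return KDTree(d1).query_ball_tree(KDTree(d2), r)


def knn_query(d1, d2, r, k=5):
    # Mirrors cgetnnidx: one tree on d2, then up to k neighbors within r.
    return cKDTree(d2).query(d1, k=k, distance_upper_bound=r)[1]


rng = np.random.default_rng(0)
d1 = rng.random((4000, 2))
d2 = rng.random((2000, 2))
t_ball = time_call(ball_tree, d1, d2, 0.01)
t_knn = time_call(knn_query, d1, d2, 0.01)
print(f"KDTree.query_ball_tree: {t_ball:.4f} s")
print(f"cKDTree.query:          {t_knn:.4f} s")
```

Taking the best of a few runs reduces noise from other processes. The standard-library `timeit` module would be an equally reasonable choice.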

The two functions use different query methods, so their outputs differ. KDTree.query_ball_tree returns:

[[],
 [216, 1767],
 [317],
 ...,
 [],
 [367],
 [1465, 1899]]

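which is a list containing, for each object in the first catalog, the indices of all objects in the second catalog within the search radius (not necessarily sorted by distance). If no object lies within the search radius, the corresponding list is simply empty.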
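Each inner list holds every match within the radius, so picking the single closest counterpart takes one more step. This is a minimal sketch, assuming a SciPy version in which `cKDTree` has `query_ball_tree`. The helper `closest_within_radius` is hypothetical and returns -1 where there is no match.

```python
import numpy as np
from scipy.spatial import cKDTree


def closest_within_radius(d1, d2, r):
    """Hypothetical helper: index of the closest d2 point within r of each d1 point, or -1."""
    neighbors = cKDTree(d1).query_ball_tree(cKDTree(d2), r)
    best = np.full(len(d1), -1, dtype=int)
    for i, js in enumerate(neighbors):
        if js:  # an empty list means nothing lies within r
            sep = np.linalg.norm(d2[js] - d1[i], axis=1)
            best[i] = js[int(np.argmin(sep))]
    return best


rng = np.random.default_rng(1)
a = rng.random((500, 2))
b = rng.random((300, 2))
best = closest_within_radius(a, b, 0.03)
print((best >= 0).sum(), "of", len(a), "objects have a match")
```

Unlike a capped k-nearest query, this approach never drops matches. The cost is that crowded regions produce long lists.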

cKDTree.query, on the other hand, returns a tuple of distances and indices. The index array (which is what cgetnnidx returns) looks something like

[[2000 2000 2000 2000 2000]
 [ 216 1767 2000 2000 2000]
 [ 317 2000 2000 2000 2000]
 ...,
 [2000 2000 2000 2000 2000]
 [ 367 2000 2000 2000 2000]
 [1465 1899 2000 2000 2000]]

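which is an array of shape (N1, k) holding the indices of the k = 5 nearest neighbors, sorted by increasing distance. Where fewer than k neighbors lie within the search radius, the empty slots are filled with the number of points in the second catalog (here 2000) as a placeholder, and the corresponding distances are set to infinity. The argument k sets the number of nearest neighbors to retain. Any neighbors beyond the k-th are silently dropped, even if they lie within the radius.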
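Because the padded array carries a placeholder instead of ragged lists, a little post-processing recovers per-object match lists. The same step can flag objects whose k slots are all filled, since those might have more neighbors inside the radius than were returned. This is a sketch with a hypothetical helper `padded_to_lists`. It relies only on `cKDTree.query` reporting missing neighbors with index `len(d2)` and an infinite distance.

```python
import numpy as np
from scipy.spatial import cKDTree


def padded_to_lists(idx, n):
    """Hypothetical helper: drop the placeholder index n from each row of query output."""
    return [row[row < n].tolist() for row in idx]


rng = np.random.default_rng(0)
d1 = rng.random((4000, 2))
d2 = rng.random((2000, 2))
r, k = 0.01, 5

dist, idx = cKDTree(d2).query(d1, k=k, distance_upper_bound=r)
n = len(d2)
matches = padded_to_lists(idx, n)

# Missing neighbors carry index n and an infinite distance.
assert np.all(np.isinf(dist[idx == n]))

# If the k-th slot is filled, more neighbors within r may have been dropped.
saturated = idx[:, -1] < n
print(sum(len(m) > 0 for m in matches), "objects matched;",
      saturated.sum(), "may need a larger k")
```

For objects that are not saturated, these lists should agree (up to ordering) with what query_ball_tree returns.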