Searching for Nearest Neighbors between Two Coordinate Catalogs
Say I have two catalogs of points, each in two-dimensional space. For each object in one catalog, I want to find the nearest object(s) in the other catalog. I could compute the distance for every pair of objects (one from each catalog), keep the pairs within a search radius, and possibly sort them by distance, but the number of distance computations grows with the product of the two catalog sizes. A k-d tree is much more efficient, and SciPy provides classes that make this sort of search fairly easy.
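For comparison, the brute-force approach can be sketched with NumPy broadcasting. This is a minimal illustration only (brute_force_nnidx is a hypothetical helper, not a library function); it builds the full distance matrix between the two catalogs, so both time and memory grow with the product of their sizes:

```python
import numpy as np


def brute_force_nnidx(d1, d2, r):
    """For each point in d1 (shape (N1, 2)), return a list of the indices
    of all points in d2 (shape (N2, 2)) within Euclidean distance r."""
    # Broadcasting (N1, 1, 2) - (1, N2, 2) gives all pairwise differences
    diff = d1[:, np.newaxis, :] - d2[np.newaxis, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))  # (N1, N2) distance matrix
    # Row by row, keep the columns (d2 indices) inside the search radius
    return [np.nonzero(row <= r)[0].tolist() for row in dist]


d1 = np.array([[0.0, 0.0], [1.0, 1.0]])
d2 = np.array([[0.0, 0.005], [0.5, 0.5], [1.0, 0.995]])
print(brute_force_nnidx(d1, d2, 0.01))  # [[0], [2]]
```

With catalogs of 4000 and 2000 points, the intermediate diff array alone holds 8 million coordinate pairs (16 million doubles, about 128 MB), which is one reason a tree-based search scales better.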
Here is a Python example that does this with SciPy's two k-d tree classes, KDTree and cKDTree, and profiles both:
```python
#!/usr/bin/env python2.6
from cProfile import Profile
import numpy as np
from numpy.random import random
from scipy.spatial import KDTree, cKDTree


def getnnidx(d1, d2, r):
    t1 = KDTree(d1)  # tree over the first catalog
    t2 = KDTree(d2)  # tree over the second catalog
    idx = t1.query_ball_tree(t2, r)  # all d2 points within r of each d1 point
    return idx


def cgetnnidx(d1, d2, r, k=5):
    """For each point in d1, return the indices of up to k nearest
    points in d2 within distance r, using the C implementation cKDTree."""
    t = cKDTree(d2)
    # eps=0: exact search; p=2: Euclidean distance
    d, idx = t.query(d1, k=k, eps=0, p=2, distance_upper_bound=r)
    return idx


def main():
    # number of points
    np1 = 4000
    np2 = 2000
    # search radius
    r = 0.01
    # prepare coordinates; the input data for the constructor of
    # KDTree needs to be in the form:
    #
    # data = [[x0, y0], [x1, y1], ... , [xN, yN]]
    #
    d1 = np.empty((np1, 2))
    d2 = np.empty((np2, 2))
    d1[:, 0] = random(np1)
    d1[:, 1] = random(np1)
    d2[:, 0] = random(np2)
    d2[:, 1] = random(np2)
    # profile the two KDTree implementations
    p = Profile()
    result = p.runcall(getnnidx, d1.copy(), d2.copy(), r)
    p.print_stats()
    p.clear()
    result = p.runcall(cgetnnidx, d1.copy(), d2.copy(), r)
    p.print_stats()


if __name__ == '__main__':
    main()
```
When I run this script, I get the following output:
```
         706932 function calls (692470 primitive calls) in 1.663 CPU seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1720    0.001    0.000    0.004    0.000 fromnumeric.py:1067(nonzero)
        2    0.000    0.000    0.000    0.000 fromnumeric.py:1147(shape)
    42293    0.046    0.000    0.224    0.000 fromnumeric.py:1314(sum)
        2    0.000    0.000    0.000    0.000 fromnumeric.py:1708(amax)
        2    0.000    0.000    0.000    0.000 fromnumeric.py:1769(amin)
      860    0.001    0.000    0.002    0.000 fromnumeric.py:664(argmax)
    14462    0.008    0.000    0.025    0.000 function_base.py:776(copy)
        2    0.000    0.000    0.045    0.022 kdtree.py:113(__init__)
      862    0.001    0.000    0.001    0.000 kdtree.py:139(__init__)
      860    0.001    0.000    0.001    0.000 kdtree.py:143(__init__)
   1722/2    0.033    0.000    0.045    0.022 kdtree.py:150(__build)
    42293    0.183    0.000    0.992    0.000 kdtree.py:22(minkowski_distance)
    12744    0.085    0.000    0.104    0.000 kdtree.py:36(__init__)
        1    0.002    0.002    1.618    1.618 kdtree.py:437(query_ball_tree)
  12743/1    0.270    0.000    1.616    1.616 kdtree.py:462(traverse_checking)
     6371    0.031    0.000    0.157    0.000 kdtree.py:49(split)
    42293    0.372    0.000    0.682    0.000 kdtree.py:7(minkowski_distance_p)
    12743    0.133    0.000    0.430    0.000 kdtree.py:72(min_distance_rectangle)
     7385    0.052    0.000    0.219    0.000 kdtree.py:76(max_distance_rectangle)
   169174    0.095    0.000    0.213    0.000 numeric.py:201(asarray)
        1    0.000    0.000    1.663    1.663 xmatch.py:8(getnnidx)
    57063    0.036    0.000    0.036    0.000 {isinstance}
     5164    0.001    0.000    0.001    0.000 {len}
      860    0.001    0.000    0.001    0.000 {method 'argmax' of 'numpy.ndarray' objects}
    25488    0.019    0.000    0.019    0.000 {method 'astype' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        2    0.000    0.000    0.000    0.000 {method 'max' of 'numpy.ndarray' objects}
        2    0.000    0.000    0.000    0.000 {method 'min' of 'numpy.ndarray' objects}
     1720    0.003    0.000    0.003    0.000 {method 'nonzero' of 'numpy.ndarray' objects}
    42293    0.146    0.000    0.146    0.000 {method 'sum' of 'numpy.ndarray' objects}
    22165    0.008    0.000    0.008    0.000 {method 'tolist' of 'numpy.ndarray' objects}
        2    0.000    0.000    0.000    0.000 {numpy.core.multiarray.arange}
   183636    0.135    0.000    0.135    0.000 {numpy.core.multiarray.array}
        1    0.000    0.000    0.000    0.000 {range}


         32 function calls in 0.008 CPU seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        3    0.000    0.000    0.000    0.000 fromnumeric.py:107(reshape)
        3    0.000    0.000    0.000    0.000 fromnumeric.py:1147(shape)
        1    0.000    0.000    0.000    0.000 fromnumeric.py:1708(amax)
        1    0.000    0.000    0.000    0.000 fromnumeric.py:1769(amin)
        1    0.000    0.000    0.000    0.000 fromnumeric.py:1865(prod)
        1    0.000    0.000    0.000    0.000 fromnumeric.py:32(_wrapit)
        2    0.000    0.000    0.000    0.000 numeric.py:201(asarray)
        4    0.000    0.000    0.000    0.000 numeric.py:314(ascontiguousarray)
        1    0.000    0.000    0.008    0.008 xmatch.py:15(cgetnnidx)
        1    0.000    0.000    0.000    0.000 {getattr}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        1    0.000    0.000    0.000    0.000 {method 'max' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.000    0.000 {method 'min' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.000    0.000 {method 'prod' of 'numpy.ndarray' objects}
        1    0.007    0.007    0.007    0.007 {method 'query' of 'scipy.spatial.ckdtree.cKDTree' objects}
        3    0.000    0.000    0.000    0.000 {method 'reshape' of 'numpy.ndarray' objects}
        6    0.000    0.000    0.000    0.000 {numpy.core.multiarray.array}
```
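From this it is clear that cKDTree is much faster (1.663 vs. 0.008 CPU seconds!). That is not surprising: in the SciPy version used here, KDTree is implemented in pure Python (hence all the kdtree.py entries in the first profile), while cKDTree is a C implementation of basically the same thing. In more recent SciPy releases, KDTree is built on the same compiled code as cKDTree, so the gap should largely disappear.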
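Note that getnnidx and cgetnnidx do not do identical work: getnnidx builds two trees and returns every neighbor within r, while cgetnnidx builds one tree and returns at most k nearest neighbors. cKDTree also has a query_ball_tree method, so a closer like-for-like variant is possible. A sketch, assuming a SciPy version whose cKDTree provides query_ball_tree (cgetnnidx_ball is a hypothetical name, and the data are random as in the script):

```python
import numpy as np
from scipy.spatial import KDTree, cKDTree


def cgetnnidx_ball(d1, d2, r):
    """Like getnnidx, but built on cKDTree: for each point in d1, the
    list of indices of all points in d2 within distance r."""
    t1 = cKDTree(d1)
    t2 = cKDTree(d2)
    return t1.query_ball_tree(t2, r)


rng = np.random.default_rng(0)
d1 = rng.random((4000, 2))
d2 = rng.random((2000, 2))
r = 0.01

py_result = KDTree(d1).query_ball_tree(KDTree(d2), r)
c_result = cgetnnidx_ball(d1, d2, r)

# The order within each neighbor list is not guaranteed to match,
# so compare the lists after sorting
same = all(sorted(a) == sorted(b) for a, b in zip(py_result, c_result))
print(same)
```

Profiling cgetnnidx_ball with the same Profile calls as in main() would isolate the implementation difference from the difference in what is being computed.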
The two functions call different methods, so their outputs are different too. KDTree.query_ball_tree returns:

```
[[],
 [216, 1767],
 [317],
 ...,
 [],
 [367],
 [1465, 1899]]
```

which is a list with one entry per object in d1; each entry is a list of the indices of all objects in d2 within the search radius. If no object is found within the search radius, the entry is simply an empty list.
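For cross-matching it is often handy to flatten this list of lists into two parallel index arrays, one entry per matched pair. A minimal sketch (ball_lists_to_pairs is a hypothetical helper, and the input is a toy list with the same structure, not the output above):

```python
import numpy as np


def ball_lists_to_pairs(idx):
    """Flatten the list of lists returned by query_ball_tree into two
    parallel integer arrays (i1, i2): d1[i1[m]] and d2[i2[m]] are the
    m-th matched pair."""
    counts = np.array([len(neighbors) for neighbors in idx], dtype=int)
    # Repeat each d1 index once per neighbor it has in d2
    i1 = np.repeat(np.arange(len(idx)), counts)
    # Concatenate all neighbor lists in the same order
    i2 = np.array([j for neighbors in idx for j in neighbors], dtype=int)
    return i1, i2


# Toy input with the same structure as query_ball_tree's output
idx = [[], [3, 7], [1], [], [0, 2]]
i1, i2 = ball_lists_to_pairs(idx)
print(i1.tolist())  # [1, 1, 2, 4, 4]
print(i2.tolist())  # [3, 7, 1, 0, 2]

# Objects in d1 with no counterpart within the search radius
unmatched = [i for i, neighbors in enumerate(idx) if not neighbors]
print(unmatched)  # [0, 3]
```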
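cKDTree.query, on the other hand, returns a tuple of distances and indices. The index array, which is what cgetnnidx returns, looks something like

```
[[2000 2000 2000 2000 2000]
 [ 216 1767 2000 2000 2000]
 [ 317 2000 2000 2000 2000]
 ...,
 [2000 2000 2000 2000 2000]
 [ 367 2000 2000 2000 2000]
 [1465 1899 2000 2000 2000]]
```

which holds, for each object in d1, the indices of its five nearest neighbors in d2, ordered by increasing distance. When fewer than five neighbors lie within the search radius, the missing entries are filled with the number of points in d2 (here 2000) as a placeholder, and the corresponding distances are inf. The number of nearest neighbors to retain is set by the input argument k.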
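Because of the placeholder, the raw index array usually needs masking before use. A minimal sketch on random data sized like the script's catalogs (the -1 "no match" marker and the variable names are choices made for this illustration, not part of SciPy):

```python
import numpy as np
from scipy.spatial import cKDTree

# Random catalogs, sized as in the script above
rng = np.random.default_rng(0)
d1 = rng.random((4000, 2))
d2 = rng.random((2000, 2))
r = 0.01
k = 5

t = cKDTree(d2)
dist, idx = t.query(d1, k=k, distance_upper_bound=r)

# Missing neighbors have index len(d2) and distance inf
valid = idx < len(d2)
nfound = valid.sum(axis=1)  # neighbors found per d1 object, between 0 and k

# Neighbors are sorted by distance, so column 0 holds the closest match;
# -1 marks "no match" (a convention chosen for this sketch)
best = np.where(valid[:, 0], idx[:, 0], -1)

# Per-object lists of in-radius neighbors, like query_ball_tree but capped at k
neighbors = [row[ok].tolist() for row, ok in zip(idx, valid)]
```

Setting k=1 is the usual choice when only the single best match per object is wanted; a larger k only matters if several counterparts within r are of interest.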