From 44e73f60838af48334fc5c3c3d5746b0c3e1b48a Mon Sep 17 00:00:00 2001 From: Jeena Date: Tue, 15 Oct 2013 02:04:20 +0200 Subject: [PATCH] k>1 in nearest test --- find_lines.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/find_lines.py b/find_lines.py index c497a97..13c3a43 100755 --- a/find_lines.py +++ b/find_lines.py @@ -56,15 +56,23 @@ def extractFeatures(label): if __name__ == "__main__": - arr = extractFeatures("cat") + extractFeatures("dog") - test_label = arr[0][1] - test_feature = arr[0][0] - labels = map(lambda a: a[1], arr)[1:] - features = map(lambda a: a[0], arr)[1:] + cats = extractFeatures("cat") + dogs = extractFeatures("dog") + + test_count = 5 + + test_data = dogs[:test_count] + cats[:test_count] + test_labels = map(lambda a: a[1], test_data) + test_features = map(lambda a: a[0], test_data) + + data = cats[test_count:] + dogs[test_count:] + labels = map(lambda a: a[1], data) + features = map(lambda a: a[0], data) tree = KDTree(features) - d, i = tree.query(test_feature) - - - print test_label + " is predicted to be a " + labels[i] + + for t in xrange(0, test_count * 2): + d, i = tree.query(test_features[t], k=2) + for j in xrange(0, len(i)): + print test_labels[t] + " is predicted to be a " + labels[i[j]] + " j: " + str(i[j]) + " d: " + str(d[j])