diff --git a/find_lines.py b/find_lines.py index c497a97..13c3a43 100755 --- a/find_lines.py +++ b/find_lines.py @@ -56,15 +56,23 @@ def extractFeatures(label): if __name__ == "__main__": - arr = extractFeatures("cat") + extractFeatures("dog") - test_label = arr[0][1] - test_feature = arr[0][0] - labels = map(lambda a: a[1], arr)[1:] - features = map(lambda a: a[0], arr)[1:] + cats = extractFeatures("cat") + dogs = extractFeatures("dog") + + test_count = 5 + + test_data = dogs[:test_count] + cats[:test_count] + test_labels = map(lambda a: a[1], test_data) + test_features = map(lambda a: a[0], test_data) + + data = cats[test_count:] + dogs[test_count:] + labels = map(lambda a: a[1], data) + features = map(lambda a: a[0], data) tree = KDTree(features) - d, i = tree.query(test_feature) - - - print test_label + " is predicted to be a " + labels[i] + + for t in xrange(0, test_count * 2): + d, i = tree.query(test_features[t], k=2) + for j in xrange(0, len(i)): + print test_labels[t] + " is predicted to be a " + labels[i[j]] + " j: " + str(i[j]) + " d: " + str(d[j])