k>1 in nearest test

This commit is contained in:
Jeena 2013-10-15 02:04:20 +02:00
parent db9419620b
commit 44e73f6083

View file

@ -56,15 +56,23 @@ def extractFeatures(label):
if __name__ == "__main__":
arr = extractFeatures("cat") + extractFeatures("dog")
test_label = arr[0][1]
test_feature = arr[0][0]
labels = map(lambda a: a[1], arr)[1:]
features = map(lambda a: a[0], arr)[1:]
cats = extractFeatures("cat")
dogs = extractFeatures("dog")
test_count = 5
test_data = dogs[:test_count] + cats[:test_count]
test_labels = map(lambda a: a[1], test_data)
test_features = map(lambda a: a[0], test_data)
data = cats[test_count:] + dogs[test_count:]
labels = map(lambda a: a[1], data)
features = map(lambda a: a[0], data)
tree = KDTree(features)
d, i = tree.query(test_feature)
print test_label + " is predicted to be a " + labels[i]
for t in xrange(0, test_count * 2):
d, i = tree.query(test_features[t], k=2)
for j in xrange(0, len(i)):
print test_labels[t] + " is predicted to be a " + labels[i[j]] + " j: " + str(i[j]) + " d: " + str(d[j])