k>1 in nearest test
This commit is contained in:
parent
db9419620b
commit
44e73f6083
1 changed files with 17 additions and 9 deletions
|
@ -56,15 +56,23 @@ def extractFeatures(label):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
arr = extractFeatures("cat") + extractFeatures("dog")
|
||||
test_label = arr[0][1]
|
||||
test_feature = arr[0][0]
|
||||
labels = map(lambda a: a[1], arr)[1:]
|
||||
features = map(lambda a: a[0], arr)[1:]
|
||||
cats = extractFeatures("cat")
|
||||
dogs = extractFeatures("dog")
|
||||
|
||||
test_count = 5
|
||||
|
||||
test_data = dogs[:test_count] + cats[:test_count]
|
||||
test_labels = map(lambda a: a[1], test_data)
|
||||
test_features = map(lambda a: a[0], test_data)
|
||||
|
||||
data = cats[test_count:] + dogs[test_count:]
|
||||
labels = map(lambda a: a[1], data)
|
||||
features = map(lambda a: a[0], data)
|
||||
|
||||
tree = KDTree(features)
|
||||
d, i = tree.query(test_feature)
|
||||
|
||||
|
||||
print test_label + " is predicted to be a " + labels[i]
|
||||
for t in xrange(0, test_count * 2):
|
||||
d, i = tree.query(test_features[t], k=2)
|
||||
for j in xrange(0, len(i)):
|
||||
print test_labels[t] + " is predicted to be a " + labels[i[j]] + " j: " + str(i[j]) + " d: " + str(d[j])
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue