-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathKNNClassifier.cpp
More file actions
73 lines (66 loc) · 1.53 KB
/
Copy pathKNNClassifier.cpp
File metadata and controls
73 lines (66 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include "KNNClassifier.h"
#include <limits>
#include <queue>
using namespace arma;
using namespace std;
// Store the full data set and the neighbour count used by classify().
//
// @param X  feature matrix; classify() and test() treat each row as one sample
// @param Y  label matrix whose row i belongs to row i of X (presumably one-hot
//           label rows, since classify() sums them as votes — confirm)
// @param k  number of nearest neighbours that vote in classify()
KNNClassifier::KNNClassifier(mat X, mat Y, int k)
{
    this->num_nbrs = k;
    this->input_y = Y;
    this->input_x = X;
}
// Turn the first row of m into a 1 x n_cols one-hot row vector, in place.
//
// The column holding the largest value in row 0 becomes 1 and every other
// column becomes 0. On ties, the lowest column index wins. Rows beyond the
// first are discarded. An empty matrix becomes an all-zero 1 x n_cols row,
// because there is nothing to pick.
//
// @param m  matrix to overwrite (only row 0 is read)
void oneHot(mat &m)
{
    const uword cols = m.n_cols;
    if (m.n_rows == 0 || cols == 0)
    {
        // Nothing to rank; writing m(0, 0) here would be out of bounds.
        m = zeros<mat>(1, cols);
        return;
    }

    // Seed with the first element rather than 0, so rows whose values are all
    // negative still yield their true argmax. Use double (mat's element type)
    // so values are not truncated to float precision.
    uword best_index = 0;
    double best_value = m(0, 0);
    for (uword t = 1; t < cols; ++t)
    {
        if (best_value < m(0, t))  // strict '<' keeps the first of equal maxima
        {
            best_value = m(0, t);
            best_index = t;
        }
    }

    m = zeros<mat>(1, cols);
    m(0, best_index) = 1;
}
// Predict the label row for one instance by k-nearest-neighbour voting.
//
// The training set is the first floor(n_rows * ttr) rows of input_x/input_y,
// with ttr clamped to [0, 1]. The count is computed the same way test()
// computes its first test row (a float product truncated to an integer), so
// no row is used for both training and testing.
//
// Distances are squared Euclidean, which ranks neighbours identically to the
// Euclidean distance. The num_nbrs closest training rows each add their label
// row to a vote total, and oneHot() reduces the totals to a single winner.
//
// @param instance  row vector with the same number of columns as input_x
// @param ttr       fraction of rows used for training
// @return 1 x input_y.n_cols row produced by oneHot(). If num_nbrs <= 0 or
//         there is no training data, no votes are cast and oneHot() receives
//         an all-zero row.
mat KNNClassifier::classify(mat instance,float ttr)
{
    // Clamp to [0, 1]. The negated comparison also maps NaN to 0. An
    // unclamped ttr > 1 would read rows past the end of input_x.
    float frac = ttr;
    if (!(frac > 0.0f))
        frac = 0.0f;
    else if (frac > 1.0f)
        frac = 1.0f;

    // Truncate to a whole row count. Comparing `i < n_rows*ttr` directly would
    // also include row floor(n_rows*ttr) when the product is fractional, and
    // test() evaluates that row as test data (train/test leakage).
    uword train_count = static_cast<uword>(input_x.n_rows * frac);
    if (train_count > input_x.n_rows)  // guard against float rounding for huge n_rows
        train_count = input_x.n_rows;

    // With k == 0 the heap would stay empty and top() would be undefined,
    // so a non-positive neighbour count skips the search entirely.
    const uword k = num_nbrs > 0 ? static_cast<uword>(num_nbrs) : 0;

    // Bounded heap holding the k best candidates seen so far.
    // NOTE(review): this assumes PointDistance's operator< (declared in
    // KNNClassifier.h) orders by distance, so top() is the farthest kept
    // neighbour. Confirm in the header.
    priority_queue<PointDistance> nbrs;
    for (uword i = 0; k > 0 && i < train_count; ++i)
    {
        PointDistance p(static_cast<int>(i), accu(square(input_x.row(i) - instance)));
        if (nbrs.size() < k)
            nbrs.push(p);
        else if (nbrs.top().distance > p.distance)
        {
            nbrs.pop();
            nbrs.push(p);
        }
    }

    // Sum the neighbours' label rows. Dividing by k would not change the
    // argmax, so the raw counts go straight to oneHot().
    mat result = zeros<mat>(1, input_y.n_cols);
    while (!nbrs.empty())
    {
        result += input_y.row(nbrs.top().index);
        nbrs.pop();
    }
    oneHot(result);
    return result;
}
double KNNClassifier::test(float ttr)
{
int mishap = 0;
for(int i=input_x.n_rows*ttr;i<input_x.n_rows;i++)
{
mat error = classify(input_x.row(i),ttr)-input_y.row(i);
if(accu(square(error))!=0)
mishap++;
}
return 100.0 - mishap*100.0/(input_x.n_rows*(1-ttr));
}
// Replace the stored feature matrix, which classify() and test() read.
// NOTE(review): input_y is left unchanged, so X should presumably keep the
// same row count and order as the existing labels. Confirm at call sites.
void KNNClassifier::setInputX(mat X)
{
    this->input_x = X;
}
// Return a copy of the stored feature matrix.
mat KNNClassifier::getInputX()
{
    return this->input_x;
}