00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019 #include "KNNClassifier.h"
00020
00021 using std::ostringstream;
00022 using namespace Marsyas;
00023
00024 KNNClassifier::KNNClassifier(mrs_string name):MarSystem("KNNClassifier",name)
00025 {
00026 prev_mode_ = "predict";
00027 addControls();
00028 }
00029
00030
00031 KNNClassifier::~KNNClassifier()
00032 {
00033 }
00034
00035
00036 MarSystem*
00037 KNNClassifier::clone() const
00038 {
00039 return new KNNClassifier(*this);
00040 }
00041
00042 void
00043 KNNClassifier::addControls()
00044 {
00045 addctrl("mrs_string/mode", "train");
00046 addctrl("mrs_natural/nLabels", 1);
00047 setctrlState("mrs_natural/nLabels", true);
00048 trainSet_.create((mrs_natural)1,(mrs_natural)1);
00049 addctrl("mrs_natural/grow", 1);
00050 addctrl("mrs_natural/k", 1);
00051 k_ = 1;
00052 addctrl("mrs_realvec/trainSet", trainSet_);
00053 addctrl("mrs_natural/nPoints", 0);
00054 addctrl("mrs_bool/done", false);
00055 addctrl("mrs_natural/nPredictions", 1);
00056 setctrlState("mrs_natural/nPredictions", true);
00057 setctrlState("mrs_bool/done", true);
00058 }
00059
00060
00061 void
00062 KNNClassifier::myUpdate(MarControlPtr sender)
00063 {
00064 (void) sender;
00065 MRSDIAG("KNNClassifier.cpp - KNNClassifier:myUpdate");
00066
00067 nPredictions_ = getctrl("mrs_natural/nPredictions")->to<mrs_natural>();
00068 setctrl("mrs_natural/onSamples", getctrl("mrs_natural/inSamples"));
00069 setctrl("mrs_natural/onObservations", (mrs_natural) nPredictions_ + 1);
00070 setctrl("mrs_real/osrate", getctrl("mrs_real/israte"));
00071
00072 inObservations_ = getctrl("mrs_natural/inObservations")->to<mrs_natural>();
00073 grow_ = getctrl("mrs_natural/grow")->to<mrs_natural>();
00074 nPoints_ = getctrl("mrs_natural/nPoints")->to<mrs_natural>();
00075 k_ = getctrl("mrs_natural/k")->to<mrs_natural>();
00076 mrs_string mode = getctrl("mrs_string/mode")->to<mrs_string>();
00077
00078 if (mode == "train")
00079 {
00080 if (inObservations_ != trainSet_.getCols())
00081 {
00082 trainSet_.stretch(1, getctrl("mrs_natural/inObservations")->to<mrs_natural>());
00083 setctrl("mrs_realvec/trainSet", trainSet_);
00084 }
00085 }
00086
00087
00088
00089 if (mode == "predict")
00090 {
00091 trainSet_.create(getctrl("mrs_realvec/trainSet")->to<mrs_realvec>().getRows(),
00092 getctrl("mrs_realvec/trainSet")->to<mrs_realvec>().getCols());
00093 trainSet_ = getctrl("mrs_realvec/trainSet")->to<mrs_realvec>();
00094 }
00095
00096
00097 if (getctrl("mrs_bool/done")->to<mrs_bool>())
00098 {
00099 setctrl("mrs_bool/done", false);
00100 setctrl("mrs_realvec/trainSet", trainSet_);
00101 }
00102 }
00103
00104
00105 void
00106 KNNClassifier::myProcess(realvec& in, realvec& out)
00107 {
00108
00109
00110 mrs_real v;
00111 mrs_string mode = getctrl("mrs_string/mode")->to<mrs_string>();
00112 mrs_real label;
00113 mrs_natural nlabels = getctrl("mrs_natural/nLabels")->to<mrs_natural>();
00114 mrs_natural prediction;
00115 int x, y;
00116 int p;
00117 mrs_natural o,t;
00118
00119 if ((prev_mode_ == "predict")&&(mode == "train"))
00120 {
00121
00122
00123 for (p = 0; p < nPoints_; p++)
00124 {
00125 for (o=0; o < inObservations_-1; o++)
00126 trainSet_(p,o) = 0.0;
00127 }
00128 nPoints_ = 0;
00129 }
00130
00131
00132 if (mode == "train")
00133 {
00134 for (t = 0; t < inSamples_; t++)
00135 {
00136 label = in(inObservations_-1, t);
00137
00138 if (nPoints_ == grow_)
00139 {
00140
00141
00142 grow_ = 2*grow_;
00143 trainSet_.stretch(grow_, inObservations_);
00144 updControl("mrs_natural/grow", grow_);
00145 }
00146
00147 for (o=0; o < inObservations_; o++)
00148 {
00149
00150 trainSet_(nPoints_,o) = in(o,t);
00151 }
00152 out(0,t) = label;
00153 out(1,t) = label;
00154
00155
00156 nPoints_= nPoints_ +1;
00157 updControl("mrs_natural/nPoints", nPoints_);
00158 }
00159 }
00160
00161
00162
00163 if (mode == "predict")
00164 {
00165
00166
00167 for (t = 0; t < inSamples_; t++)
00168 {
00169 label = in(inObservations_-1, t);
00170
00171 realvec Distance;
00172 Distance.create(nPoints_);
00173
00174 realvec kMin;
00175 kMin.create(k_,2);
00176
00177 realvec kSmallest;
00178 kSmallest.create(nlabels);
00179
00180 for (p = 0; p < nPoints_; p++)
00181 {
00182 mrs_real sum = 0;
00183 for (o=0; o < inObservations_-1; o++)
00184 {
00185 v = in(o,t);
00186 v = (v - trainSet_(p,o));
00187 sum += v*v;
00188 }
00189 Distance(p) = sum;
00190 }
00191
00192
00193
00194
00195
00196 mrs_real kmaxV = Distance(0);
00197 int kmaxI = 0;
00198
00199 for (x=0; x < k_; x++)
00200 {
00201 kMin(x, 0) = Distance(0);
00202 kMin(x, 1) = 0;
00203 }
00204
00205
00206 for (y=0; y < nPoints_; y++)
00207 {
00208
00209 if (Distance(y) < kmaxV)
00210 {
00211 mrs_real kmaxV_t = 0.0;
00212 int kmaxI_t = 1;
00213
00214 kMin(kmaxI,0) = Distance(y);
00215 kMin(kmaxI,1) = trainSet_(y, inObservations_-1);
00216
00217
00218 for (x=0; x < k_; x++)
00219 {
00220 kmaxV_t = kMin(0,0);
00221 kmaxI_t = 0;
00222 if (kMin(x,0) > kmaxV_t)
00223 {
00224 kmaxV_t = kMin(x,0);
00225 kmaxI_t = x;
00226 }
00227 }
00228 kmaxV = kmaxV_t;
00229 kmaxI = kmaxI_t;
00230 }
00231 }
00232
00233
00234
00235 for (x=0; x< k_; x++)
00236 {
00237
00238 kSmallest((int)kMin(x, 1))++;
00239 }
00240
00241
00242
00243 mrs_real max = kSmallest(0);
00244 int maxI = 0;
00245 for (x=0; x<nlabels; x++)
00246 {
00247 if (kSmallest(x) > max)
00248 {
00249 max = kSmallest(x);
00250 maxI = x;
00251 }
00252 }
00253 prediction = maxI;
00254 out(0,t) = (mrs_real)prediction;
00255 if (nPredictions_ >= 1)
00256 for (x=0; x < nPredictions_; x++)
00257 out(x,t) = kMin(x,1);
00258
00259 out(onObservations_-1,t) = label;
00260 }
00261 }
00262
00263
00264
00265 prev_mode_ = mode;
00266 }
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276