00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00024 #include "OneRClassifier.h"
00025
00026 using std::ostringstream;
00027 using std::cout;
00028 using std::endl;
00029 using std::vector;
00030
00031 using namespace Marsyas;
00032
00033 OneRClassifier::OneRClassifier(const mrs_string name) : MarSystem("OneRClassifier", name)
00034 {
00035 addControls();
00036 rule_ = NULL;
00037 lastModePredict_ = false;
00038 }
00039
00040
00041 OneRClassifier::~OneRClassifier()
00042 {
00043 if(rule_ != NULL)
00044 delete rule_;
00045 }
00046
00047 MarSystem *OneRClassifier::clone() const
00048 {
00049 return new OneRClassifier(*this);
00050 }
00051
00052 void
00053 OneRClassifier::addControls()
00054 {
00055 addctrl("mrs_string/mode", "train");
00056 addctrl("mrs_natural/nClasses", 1);
00057 setctrlState("mrs_natural/nClasses", true);
00058 }
00059
00060 void
00061 OneRClassifier::myUpdate(MarControlPtr sender)
00062 {
00063 MRSDIAG("OneRClassifier.cpp - OneRClassifier:myUpdate");
00064 ctrl_onSamples_->setValue(ctrl_inSamples_, NOUPDATE);
00065 setctrl("mrs_natural/onObservations", 2);
00066 }
00067
00068 void
00069 OneRClassifier::myProcess(realvec& in, realvec& out)
00070 {
00071 cout << "OneRClassifier::myProcess" << endl;
00072 cout << "in.getCols() = " << in.getCols() << endl;
00073 cout << "in.getRows() = " << in.getRows() << endl;
00074
00075 bool trainMode = (getctrl("mrs_string/mode")->to<mrs_string>() == "train");
00076 row_.stretch(in.getRows());
00077 if (trainMode)
00078 {
00079 if(lastModePredict_ || instances_.getCols()<=0)
00080 {
00081 mrs_natural nAttributes = getctrl("mrs_natural/inObservations")->to<mrs_natural>();
00082 cout << "nAttributes = " << nAttributes << endl;
00083 instances_.Create(nAttributes);
00084 }
00085
00086 lastModePredict_ = false;
00087
00088
00089 for (mrs_natural ii=0; ii< inSamples_; ++ii)
00090 {
00091 mrs_real label = in(inObservations_-1, ii);
00092 instances_.Append(in);
00093 out(0,ii) = label;
00094 out(1,ii) = label;
00095 }
00096 }
00097 else
00098 {
00099
00100 cout << "OneRClassifier::predict" << endl;
00101 if(!lastModePredict_)
00102 {
00103
00104 mrs_natural nAttributes = getctrl("mrs_natural/inObservations")->to<mrs_natural>();
00105 cout << "BUILD nAttributes = " << nAttributes << endl;
00106 Build(nAttributes);
00107 }
00108 lastModePredict_ = true;
00109 cout << "After lastModePredict" << endl;
00110
00111
00112
00113
00114 for (mrs_natural ii=0; ii<inSamples_; ++ii)
00115 {
00116
00117 mrs_natural label = (mrs_natural)in(inObservations_-1, ii);
00118
00119
00120 in.getCol(ii,row_);
00121 mrs_natural prediction = Predict(row_);
00122 cout << "PREDICTION = " << prediction << endl;
00123 cout << "row_ " << row_ << endl;
00124
00125
00126 out(0,ii) = (mrs_real)prediction;
00127 out(1,ii) = (mrs_real)label;
00128 }
00129 }
00130
00131 }
00132
00133
00134
00135 OneRClassifier::OneRRule *OneRClassifier::newRule(mrs_natural attr, mrs_natural nClasses)
00136 {
00137
00138 vector<mrs_natural> classifications(instances_.size());
00139 vector<mrs_real> breakpoints(instances_.size());
00140 vector<mrs_natural> counts(nClasses);
00141
00142
00143 mrs_natural correct = 0;
00144 mrs_natural lastInstance = instances_.size();
00145
00146
00147 instances_.Sort(attr);
00148
00149 mrs_natural ii = 0;
00150 mrs_natural cl = 0;
00151 mrs_natural it = 0;
00152
00153
00154 while(ii < lastInstance)
00155 {
00156
00157 for(mrs_natural jj=0; jj<(mrs_natural)counts.size(); jj++) counts[jj]=0;
00158 do
00159 {
00160 it = instances_.GetClass(++ii);
00161 counts[it]++;
00162 }while(counts[it] < minBucketSize_ && ii < lastInstance);
00163
00164
00165 while(ii < lastInstance && instances_.GetClass(ii) == it)
00166 {
00167 counts[it]++;
00168 ++ii;
00169 }
00170
00171
00172 while(ii < lastInstance && instances_.at(ii-1)->at(attr) == instances_.at(ii)->at(attr))
00173 {
00174 mrs_natural index = instances_.GetClass(ii++);
00175 counts[index]++;
00176 }
00177
00178 for(mrs_natural jj=0; jj<nClasses; jj++)
00179 {
00180 if(counts[jj] > counts[it])
00181 {
00182 it = jj;
00183 }
00184 }
00185
00186 if(cl > 0)
00187 {
00188 if(counts[classifications[cl-1]] == counts[it])
00189 it = classifications[cl-1];
00190
00191 if(it == classifications[cl-1])
00192 cl--;
00193 }
00194
00195 correct += counts[it];
00196 classifications[cl] = it;
00197
00198 if(ii < lastInstance)
00199 breakpoints[cl] = (((instances_.at(ii-1)->at(attr) + instances_.at(ii)->at(attr)) / 2.0));
00200
00201 cl++;
00202 }
00203
00204
00205 OneRRule *rule = new OneRRule(attr, cl, correct);
00206 for(mrs_natural vv=0; vv<cl; vv++)
00207 {
00208 rule->getClassifications()[vv] = classifications[vv];
00209 if(vv < (cl-1))
00210 rule->getBreakpoints()[vv] = breakpoints[vv];
00211
00212 }
00213
00214 return rule;
00215 }
00216
00217
00218 void
00219 OneRClassifier::Build(mrs_natural nClasses)
00220 {
00221
00222 if(rule_!=NULL)
00223 delete rule_;
00224 rule_ = NULL;
00225
00226
00227 for(mrs_natural enu = 0; enu < instances_.getCols()-1; enu++)
00228 {
00229
00230 OneRClassifier::OneRRule *r = newRule(enu, nClasses);
00231
00232
00233 if(!rule_ || r->getCorrect() > rule_->getCorrect())
00234 {
00235 if(rule_!=NULL)
00236 delete rule_;
00237
00238 rule_ = r;
00239 }
00240 }
00241 }
00242
00243
00244 mrs_natural
00245 OneRClassifier::Predict(const realvec& in)
00246 {
00247 mrs_natural vv = 0;
00248 mrs_real instValue = in(rule_->getAttr());
00249
00250
00251 while(vv < rule_->getnBreaks()-1 && instValue >= rule_->getBreakpoints()[vv])
00252 {
00253 vv++;
00254 }
00255
00256
00257 return rule_->getClassifications()[vv];
00258 }