Marsyas  0.5.0-beta1
/Users/jleben/code/marsyas/src/marsyas/marsystems/OneRClassifier.cpp
Go to the documentation of this file.
00001 /*
00002 ** Copyright (C) 1998-2010 George Tzanetakis <gtzan@cs.uvic.ca>
00003 **
00004 ** This program is free software; you can redistribute it and/or modify
00005 ** it under the terms of the GNU General Public License as published by
00006 ** the Free Software Foundation; either version 2 of the License, or
00007 ** (at your option) any later version.
00008 **
00009 ** This program is distributed in the hope that it will be useful,
00010 ** but WITHOUT ANY WARRANTY; without even the implied warranty of
00011 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00012 ** GNU General Public License for more details.
00013 **
00014 ** You should have received a copy of the GNU General Public License
00015 ** along with this program; if not, write to the Free Software
00016 ** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00017 */
00018 
00024 #include "OneRClassifier.h"
00025 #include "../common_source.h"
00026 #include <cstddef>
00027 
00028 using std::ostringstream;
00029 using std::cout;
00030 using std::endl;
00031 using std::vector;
00032 using std::size_t;
00033 
00034 using namespace Marsyas;
00035 
00036 OneRClassifier::OneRClassifier(const mrs_string name) : MarSystem("OneRClassifier", name)
00037 {
00038   addControls();
00039   rule_ = NULL;
00040   lastModePredict_ = false;
00041 }
00042 
00043 //Only thing needing destroying is the current rule.
00044 OneRClassifier::~OneRClassifier()
00045 {
00046   if(rule_ != NULL)
00047     delete rule_;
00048 }
00049 
00050 MarSystem *OneRClassifier::clone() const
00051 {
00052   return new OneRClassifier(*this);
00053 }
00054 
00055 void
00056 OneRClassifier::addControls()
00057 {
00058   addctrl("mrs_string/mode", "train");
00059   addctrl("mrs_natural/nClasses", 1);
00060   setctrlState("mrs_natural/nClasses", true);
00061 }
00062 
00063 void
00064 OneRClassifier::myUpdate(MarControlPtr sender)
00065 {
00066   (void) sender;  //suppress warning of unused parameter(s)
00067   MRSDIAG("OneRClassifier.cpp - OneRClassifier:myUpdate");
00068   ctrl_onSamples_->setValue(ctrl_inSamples_, NOUPDATE);
00069   setctrl("mrs_natural/onObservations", 2);
00070   ctrl_osrate_->setValue(ctrl_israte_->to<mrs_real>());
00071   ctrl_onObsNames_->setValue("OneRClassifier_"
00072                              + ctrl_inObsNames_->to<mrs_string>() , NOUPDATE);
00073 }
00074 
00075 void
00076 OneRClassifier::myProcess(realvec& in, realvec& out)
00077 {
00078   cout << "OneRClassifier::myProcess" << endl;
00079   cout << "in.getCols() = " << in.getCols() << endl;
00080   cout << "in.getRows() = " << in.getRows() << endl;
00081   //get the current mode, either train of predict mode
00082   bool trainMode = (getctrl("mrs_string/mode")->to<mrs_string>() == "train");
00083   row_.stretch(in.getRows());
00084   if (trainMode)
00085   {
00086     if(lastModePredict_ || instances_.getCols()<=0)
00087     {
00088       mrs_natural nAttributes = getctrl("mrs_natural/inObservations")->to<mrs_natural>();
00089       cout << "nAttributes = " << nAttributes << endl;
00090       instances_.Create(nAttributes);
00091     }
00092 
00093     lastModePredict_ = false;
00094 
00095     //get the incoming data and append it to the data table
00096     for (mrs_natural ii=0; ii< inSamples_; ++ii)
00097     {
00098       mrs_real label = in(inObservations_-1, ii);
00099       instances_.Append(in);
00100       out(0,ii) = label;
00101       out(1,ii) = label;
00102     }//for t
00103   }//if
00104   else
00105   { //predict mode
00106 
00107     cout << "OneRClassifier::predict" << endl;
00108     if(!lastModePredict_)
00109     {
00110       //get the number of class labels and build the classifier
00111       mrs_natural nAttributes = getctrl("mrs_natural/inObservations")->to<mrs_natural>();
00112       cout << "BUILD nAttributes = " << nAttributes << endl;
00113       Build(nAttributes);
00114     }//if
00115     lastModePredict_ = true;
00116     cout << "After lastModePredict" << endl;
00117 
00118 
00119     //foreach row of predict data, extract the actual class, then call the
00120     //classifier predict method. Output the actual and predicted classes.
00121     for (mrs_natural ii=0; ii<inSamples_; ++ii)
00122     {
00123       //extract the actual class
00124       mrs_natural label = (mrs_natural)in(inObservations_-1, ii);
00125 
00126       //invoke the classifier predict method to predict the class
00127       in.getCol(ii,row_);
00128       mrs_natural prediction = Predict(row_);
00129       cout << "PREDICTION = " << prediction << endl;
00130       cout << "row_ " << row_ << endl;
00131 
00132       //and output actual/predicted classes
00133       out(0,ii) = (mrs_real)prediction;
00134       out(1,ii) = (mrs_real)label;
00135     }//for t
00136   }//if
00137 
00138 }//myProcess
00139 
00140 //Create a new rule for this attribute.
00141 //Sorts the data table on this attribute and executes the OneR algorithm.
00142 OneRClassifier::OneRRule *OneRClassifier::newRule(mrs_natural attr, mrs_natural nClasses)
00143 {
00144   //create the counting variables
00145   vector<mrs_natural> classifications(instances_.size());
00146   vector<mrs_real> breakpoints(instances_.size());
00147   vector<mrs_natural> counts(nClasses);
00148 
00149   //set correct count to 0
00150   mrs_natural correct = 0;
00151   mrs_natural lastInstance = (mrs_natural) instances_.size();
00152 
00153   //Sort the data table for this attribute
00154   instances_.Sort(attr);
00155 
00156   mrs_natural ii = 0;
00157   mrs_natural cl = 0;           //index of next bucket to create
00158   mrs_natural it = 0;
00159 
00160   //scan thru all rows in table
00161   while(ii < lastInstance)
00162   {
00163     //zero the current counts
00164     for(mrs_natural jj=0; jj<(mrs_natural)counts.size(); jj++) counts[jj]=0;
00165     do
00166     { //fill it until is has enough of the majority class
00167       it = instances_.GetClass(++ii);
00168       counts[it]++;
00169     } while(counts[it] < minBucketSize_ && ii < lastInstance);
00170 
00171     //while class remains the same, keep on filling
00172     while(ii < lastInstance && instances_.GetClass(ii) == it)
00173     {
00174       counts[it]++;
00175       ++ii;
00176     }//while
00177 
00178     //keep on while attr value is the same
00179     while(ii < lastInstance && instances_.at(ii-1)->at(attr) == instances_.at(ii)->at(attr))
00180     {
00181       mrs_natural index = instances_.GetClass(ii++);
00182       counts[index]++;
00183     }//while
00184 
00185     for(mrs_natural jj=0; jj<nClasses; jj++)
00186     {
00187       if(counts[jj] > counts[it])
00188       {
00189         it = jj;
00190       }//if
00191     }//for jj
00192 
00193     if(cl > 0)
00194     { //can we coalesce with previous class?
00195       if(counts[classifications[cl-1]] == counts[it])
00196         it = classifications[cl-1];
00197 
00198       if(it == classifications[cl-1])
00199         cl--;
00200     }//if
00201 
00202     correct += counts[it];
00203     classifications[cl] = it;
00204 
00205     if(ii < lastInstance)
00206       breakpoints[cl] = (((instances_.at(ii-1)->at(attr) + instances_.at(ii)->at(attr)) / 2.0));
00207 
00208     cl++;
00209   }//while
00210 
00211   //create a new rule with cl branches
00212   OneRRule *rule = new OneRRule(attr, cl, correct);
00213   for(mrs_natural vv=0; vv<cl; vv++)
00214   {
00215     rule->getClassifications()[vv] = classifications[vv];
00216     if(vv < (cl-1))
00217       rule->getBreakpoints()[vv] = breakpoints[vv];
00218 
00219   }//for vv
00220 
00221   return rule;
00222 }//newRule
00223 
00224 //Build the classifier from the data table
00225 void
00226 OneRClassifier::Build(mrs_natural nClasses)
00227 {
00228   //make sure any previous rule is out
00229   if(rule_!=NULL)
00230     delete rule_;
00231   rule_ = NULL;
00232 
00233   //scan through all the attributes(columns) of the table
00234   for(mrs_natural enu = 0; enu < instances_.getCols()-1; enu++)
00235   {
00236     //construct a new rule for this attribute
00237     OneRClassifier::OneRRule *r = newRule(enu, nClasses);
00238 
00239     //if a current rule does not exist or this new rule is better, replace old rule
00240     if(!rule_ || r->getCorrect() > rule_->getCorrect())
00241     {
00242       if(rule_!=NULL)
00243         delete rule_;
00244 
00245       rule_ = r;
00246     }//if
00247   }//for enu
00248 }//Build
00249 
00250 //Predict a class given a row of attribute data
00251 mrs_natural
00252 OneRClassifier::Predict(const realvec& in)
00253 {
00254   mrs_natural vv = 0;
00255   mrs_real instValue = in(rule_->getAttr());
00256 
00257   //find the breakpoint whose value exceeds the attribute value.
00258   while(vv < rule_->getnBreaks()-1 && instValue >= rule_->getBreakpoints()[vv])
00259   {
00260     vv++;
00261   }//while
00262 
00263   //return the class for this prediction.
00264   return rule_->getClassifications()[vv];
00265 }//Predict