Marsyas  0.5.0-beta1
/Users/jleben/code/marsyas/src/marsyas/marsystems/WekaSource.h
Go to the documentation of this file.
00001 
00008 #ifndef MARSYAS_newWEKASOURCE_H
00009 #define MARSYAS_newWEKASOURCE_H
00010 
00011 #include <marsyas/system/MarSystem.h>
00012 #include <marsyas/WekaData.h>
00013 #include <list>
00014 #include <vector>
00015 #include <iostream>
00016 #include <cstdlib>
00017 #include <cstring>
00018 //using namespace std;
00019 
00020 namespace Marsyas
00021 {
00022 class WekaFoldData : public WekaData
00023 {
00024 public:
00025   ~WekaFoldData() {}
00026 
00027   typedef enum
00028   {
00029     None,
00030     Training,
00031     Predict
00032   } nextMode;
00033 
00034 private:
00035   mrs_natural foldCount_;
00036 
00037   mrs_real  rstep_;
00038 //      mrs_natural predictSum_;
00039   mrs_natural excludeSectionStart_;
00040   mrs_natural excludeSectionEnd_;
00041 
00042   mrs_natural iteration_;
00043   mrs_natural currentIndex_;
00044 //      mrs_natural predictIndex_;
00045 
00046 public:
00047   void SetupkFoldSections(const WekaData& data, mrs_natural foldCount, mrs_natural classAttr=-1)
00048   {
00049     MRSASSERT(foldCount>0);
00050     foldCount_ = foldCount;
00051 
00052     //create the dataset with same number of columns as input data
00053     this->Create(data.getCols());
00054     if(classAttr<0)
00055     { //if no class specified, copy all data into this instance
00056       this->assign(data.begin(), data.end());
00057     }
00058     else
00059     { //otherwise only copy rows that match input class into this dataset
00060       for(mrs_natural ii=0; ii<(mrs_natural)data.size(); ++ii)
00061         if(data.GetClass(ii)==classAttr)
00062           this->Append(data[ii]);
00063     }//else
00064 
00065     //setup fold sections
00066     this->Reset();
00067 
00068   }//SetupkFoldSections
00069 
00070   //setup the fold sections for this dataset.
00071   void Reset()
00072   {
00073 
00074 
00075     this->Shuffle();
00076 
00077     rstep_ = (mrs_real)this->size() / (mrs_real)foldCount_;
00078 
00079     if (foldCount_ > (mrs_natural)this->size())
00080     {
00081       std::cout << "Folds exceed number of instances" << std::endl;
00082       std::cout << "foldCount_ = " << foldCount_ << std::endl;
00083       std::cout << "size = " << this->size() << std::endl;;
00084       exit(1);
00085     }
00086 
00087 
00088     iteration_ = 0;
00089 
00090     excludeSectionStart_ = 0;
00091     excludeSectionEnd_ = ((mrs_natural)rstep_) - 1;
00092     currentIndex_ = excludeSectionEnd_ + 1;
00093   }
00094 
00095   std::vector<mrs_real> *Next(nextMode& next)
00096   {
00097 
00098     std::vector<mrs_real> *ret = this->at(currentIndex_);
00099 
00100     if(currentIndex_ == excludeSectionEnd_)
00101     {
00102       iteration_++;
00103       if(iteration_ >= foldCount_)
00104       {
00105         next = None;
00106         return ret;
00107       }//if
00108 
00109       excludeSectionStart_ = excludeSectionEnd_ + 1;
00110       if(iteration_ == (foldCount_ - 1))
00111       {
00112         excludeSectionEnd_ = (mrs_natural)this->size() - 1;
00113         currentIndex_ = 0;
00114       }
00115       else
00116       {
00117         excludeSectionEnd_ = ((mrs_natural)((iteration_+1) * rstep_)) - 1;
00118         currentIndex_ = excludeSectionEnd_ + 1;
00119       }
00120 
00121 
00122       next = Training;
00123       return ret;
00124     }//if
00125 
00126     currentIndex_++;
00127 
00128 
00129     if(currentIndex_ >= (mrs_natural)this->size())
00130       currentIndex_ = 0;
00131 
00132     if(currentIndex_ >= excludeSectionStart_ && currentIndex_ <= excludeSectionEnd_)
00133       next = Predict;
00134     else
00135       next = Training;
00136 
00137 
00138     return ret;
00139 
00140   }//Next
00141 
00142 };
00143 
// Selects the validation mode used by WekaSource (held in its
// validationModeEnum_ member). The numeric values are spelled out and equal
// the implicit declaration-order numbering 0..5.
enum ValidationModeEnum
{
  None               = 0,
  kFoldStratified    = 1,
  kFoldNonStratified = 2,
  UseTestSet         = 3,
  PercentageSplit    = 4,
  OutputInstancePair = 5
};
00153 
00154 class marsyas_EXPORT WekaSource : public MarSystem
00155 {
00156 public:
00157   WekaSource(std::string name);
00158   WekaSource(const WekaSource& a);
00159   ~WekaSource();
00160 
00161   MarSystem *clone()const;
00162   void myProcess(realvec& in, realvec& out);
00163 
00164 private:
00165   void addControls();
00166   void myUpdate(MarControlPtr sender);
00167 
00168   //control values
00169   std::string filename_;                        //name of arff file to read
00170   std::string attributesToInclude_;         //list of attributes to include in dataset
00171 
00172   //these are the class names froun in the arff file header
00173   std::vector<std::string>classesFound_;
00174   // if there are no classes, we're doing regression
00175   MarControlPtr ctrl_regression_;
00176 
00177   std::string relation_;
00178 
00179   //these are the attribute names found in the arff file header
00180   std::vector<std::string>attributesFound_;
00181 
00182   //Holds the actual attribute data read from the arff file
00183   WekaData data_;
00184 
00185   //an array of bools that specify if an attribute from the arff file should be included
00186   //in the dataset.
00187   std::vector<bool>attributesIncluded_;
00188 
00189   //the list of attributes that are to be included in the dataset
00190   std::vector<std::string>attributesIncludedList_;
00191 
00192   //the validation mode enum to use
00193   ValidationModeEnum validationModeEnum_;
00194 
00195   //Common validation method data members
00196   mrs_natural currentIndex_;
00197 
00198   //kFold Stratified validation method data members
00199   mrs_natural foldCount_;
00200   WekaFoldData foldData_;
00201   WekaFoldData::nextMode foldCurrentMode_;
00202   WekaFoldData::nextMode foldNextMode_;
00203 
00204   //kFold NonStratified validation method data members
00205   std::vector<WekaFoldData> foldClassData_;
00206   mrs_natural foldClassDataIndex_;
00207 
00208   //UseTestSet validation method data members
00209   WekaData useTestSetData_;
00210 
00211   //PercentageSplit validation method data members
00212   mrs_natural percentageIndex_;
00213 
00214   void handleDefault(bool trainMode, realvec& out);
00215   void handleInstancePair(realvec& out);
00216   void handleFoldingNonStratifiedValidation(bool trainMode, realvec &out);
00217   void handleFoldingStratifiedValidation(bool trainMode, realvec &out);
00218   void handleUseTestSet(bool trainMode, realvec &out);
00219   void handlePercentageSplit(bool trainMode, realvec &out);
00220 
00221 private:
00222   mrs_natural findClass(const char *className)const;
00223   mrs_natural findAttribute(const char *attribute)const;
00224   mrs_natural parseAttribute(const char *attribute)const;
00225 
00226   void parseAttributesToInclude(const std::string& attributesToInclude);
00227   void loadFile(const std::string& filename, const std::string& attributesToExtract, WekaData& data);
00228   void parseHeader(std::ifstream& mis, const std::string& filename, const std::string& attributesToExtract);
00229   void parseData(std::ifstream& mis, const std::string& filename, WekaData& data);
00230 
00231 };//class WekaSource
00232 }//namespace Marsyas
00233 
00234 #endif