00001
00008 #ifndef MARSYAS_newWEKASOURCE_H
00009 #define MARSYAS_newWEKASOURCE_H
00010
00011 #include "MarSystem.h"
00012 #include "WekaData.h"
00013 #include <list>
00014 #include <vector>
00015 #include <iostream>
00016 #include <cstdlib>
00017 #include <cstring>
00018
00019
00020 namespace Marsyas
00021 {
00022 class WekaFoldData : public WekaData
00023 {
00024 public:
00025 ~WekaFoldData() {}
00026
00027 typedef enum
00028 {
00029 None,
00030 Training,
00031 Predict
00032 }nextMode;
00033
00034 private:
00035 mrs_natural foldCount_;
00036
00037 mrs_real rstep_;
00038
00039 mrs_natural excludeSectionStart_;
00040 mrs_natural excludeSectionEnd_;
00041
00042 mrs_natural iteration_;
00043 mrs_natural currentIndex_;
00044
00045
00046 public:
00047 void SetupkFoldSections(const WekaData& data, mrs_natural foldCount, mrs_natural classAttr=-1)
00048 {
00049 MRSASSERT(foldCount>0);
00050 foldCount_ = foldCount;
00051
00052
00053 this->Create(data.getCols());
00054 if(classAttr<0)
00055 {
00056 this->assign(data.begin(), data.end());
00057 }
00058 else
00059 {
00060 for(mrs_natural ii=0; ii<(mrs_natural)data.size(); ++ii)
00061 if(data.GetClass(ii)==classAttr)
00062 this->Append(data[ii]);
00063 }
00064
00065
00066 this->Reset();
00067
00068 }
00069
00070
00071 void Reset()
00072 {
00073
00074
00075 this->Shuffle();
00076
00077 rstep_ = (mrs_real)this->size() / (mrs_real)foldCount_;
00078
00079 if (foldCount_ > (mrs_natural)this->size())
00080 {
00081 std::cout << "Folds exceed number of instances" << std::endl;
00082 std::cout << "foldCount_ = " << foldCount_ << std::endl;
00083 std::cout << "size = " << this->size() << std::endl;;
00084 exit(1);
00085 }
00086
00087
00088 iteration_ = 0;
00089
00090 excludeSectionStart_ = 0;
00091 excludeSectionEnd_ = ((mrs_natural)rstep_) - 1;
00092 currentIndex_ = excludeSectionEnd_ + 1;
00093 }
00094
00095 std::vector<mrs_real> *Next(nextMode& next)
00096 {
00097
00098 std::vector<mrs_real> *ret = this->at(currentIndex_);
00099
00100 if(currentIndex_ == excludeSectionEnd_)
00101 {
00102 iteration_++;
00103 if(iteration_ >= foldCount_)
00104 {
00105 next = None;
00106 return ret;
00107 }
00108
00109 excludeSectionStart_ = excludeSectionEnd_ + 1;
00110 if(iteration_ == (foldCount_ - 1))
00111 {
00112 excludeSectionEnd_ = (mrs_natural)this->size() - 1;
00113 currentIndex_ = 0;
00114 }
00115 else
00116 {
00117 excludeSectionEnd_ = ((mrs_natural)((iteration_+1) * rstep_)) - 1;
00118 currentIndex_ = excludeSectionEnd_ + 1;
00119 }
00120
00121
00122 next = Training;
00123 return ret;
00124 }
00125
00126 currentIndex_++;
00127
00128
00129 if(currentIndex_ >= (mrs_natural)this->size())
00130 currentIndex_ = 0;
00131
00132 if(currentIndex_ >= excludeSectionStart_ && currentIndex_ <= excludeSectionEnd_)
00133 next = Predict;
00134 else
00135 next = Training;
00136
00137
00138 return ret;
00139
00140 }
00141
00142 };
00143
/**
 * Validation strategies selectable for WekaSource.
 * Enumerators take the values 0..5 in declaration order.
 * NOTE(review): per-mode behaviour lives in WekaSource's .cpp (cf. the
 * handle* helpers); the notes below are inferred from names only.
 */
enum ValidationModeEnum
{
  None,                ///< no validation scheme (presumably handleDefault)
  kFoldStratified,     ///< k-fold cross-validation with per-class folds
  kFoldNonStratified,  ///< k-fold cross-validation over all rows together
  UseTestSet,          ///< evaluate on a separately loaded test set
  PercentageSplit,     ///< split one data set into train/test by percentage
  OutputInstancePair   ///< instance-pair output (presumably handleInstancePair)
};
00153
00154 class marsyas_EXPORT WekaSource : public MarSystem
00155 {
00156 public:
00157 WekaSource(std::string name);
00158 WekaSource(const WekaSource& a);
00159 ~WekaSource();
00160
00161 MarSystem *clone()const;
00162 void myProcess(realvec& in, realvec& out);
00163
00164 private:
00165 void addControls();
00166 void myUpdate(MarControlPtr sender);
00167
00168
00169 std::string filename_;
00170 std::string attributesToInclude_;
00171 std::string validationMode_;
00172
00173
00174 std::vector<std::string>classesFound_;
00175
00176 MarControlPtr ctrl_regression_;
00177
00178 std::string relation_;
00179
00180
00181 std::vector<std::string>attributesFound_;
00182
00183
00184 WekaData data_;
00185
00186
00187
00188 std::vector<bool>attributesIncluded_;
00189
00190
00191 std::vector<std::string>attributesIncludedList_;
00192
00193
00194 ValidationModeEnum validationModeEnum_;
00195
00196
00197 mrs_natural currentIndex_;
00198
00199
00200 mrs_natural foldCount_;
00201 WekaFoldData foldData_;
00202
00203
00204 std::vector<WekaFoldData> foldClassData_;
00205 mrs_natural foldClassDataIndex_;
00206
00207
00208 std::string useTestSetFilename_;
00209 WekaData useTestSetData_;
00210
00211
00212 mrs_natural percentageSplit_;
00213 mrs_natural percentageIndex_;
00214
00215 void handleDefault(bool trainMode, realvec& out);
00216 void handleInstancePair(realvec& out);
00217 void handleFoldingNonStratifiedValidation(bool trainMode, realvec &out);
00218 void handleFoldingStratifiedValidation(bool trainMode, realvec &out);
00219 void handleUseTestSet(bool trainMode, realvec &out);
00220 void handlePercentageSplit(bool trainMode, realvec &out);
00221
00222 private:
00223 mrs_natural findClass(const char *className)const;
00224 mrs_natural findAttribute(const char *attribute)const;
00225 mrs_natural parseAttribute(const char *attribute)const;
00226
00227 void parseAttributesToInclude(const std::string& attributesToInclude);
00228 void loadFile(const std::string& filename, const std::string& attributesToExtract, WekaData& data);
00229 void parseHeader(std::ifstream& mis, const std::string& filename, const std::string& attributesToExtract);
00230 void parseData(std::ifstream& mis, const std::string& filename, WekaData& data);
00231
00232 };
00233 }
00234
00235 #endif