00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019 #include "GaussianClassifier.h"
00020
00021 using std::ostringstream;
00022 using namespace Marsyas;
00023
00024 GaussianClassifier::GaussianClassifier(mrs_string name):MarSystem("GaussianClassifier",name)
00025 {
00026 prev_mode_= "predict";
00027 addControls();
00028 }
00029
00030
00031 GaussianClassifier::GaussianClassifier(const GaussianClassifier& a):MarSystem(a)
00032 {
00033 ctrl_mode_ = getctrl("mrs_string/mode");
00034 ctrl_nClasses_ = getctrl("mrs_natural/nClasses");
00035 ctrl_means_ = getctrl("mrs_realvec/means");
00036 ctrl_covars_ = getctrl("mrs_realvec/covars");
00037 prev_mode_ = "predict";
00038 }
00039
00040
00041 GaussianClassifier::~GaussianClassifier()
00042 {
00043 }
00044
00045
00046 MarSystem*
00047 GaussianClassifier::clone() const
00048 {
00049 return new GaussianClassifier(*this);
00050 }
00051
00052 void
00053 GaussianClassifier::addControls()
00054 {
00055 addctrl("mrs_string/mode", "train", ctrl_mode_);
00056 setctrlState("mrs_string/mode", true);
00057
00058 addctrl("mrs_natural/nClasses", 1, ctrl_nClasses_);
00059 setctrlState("mrs_natural/nClasses", true);
00060 addctrl("mrs_realvec/means", realvec(), ctrl_means_);
00061 addctrl("mrs_realvec/covars", realvec(), ctrl_covars_);
00062 }
00063
00064
00065 void
00066 GaussianClassifier::myUpdate(MarControlPtr sender)
00067 {
00068 mrs_natural o;
00069 (void) sender;
00070 MRSDIAG("GaussianClassifier.cpp - GaussianClassifier:myUpdate");
00071
00072 setctrl("mrs_natural/onSamples", getctrl("mrs_natural/inSamples"));
00073 setctrl("mrs_natural/onObservations", (mrs_natural)3);
00074 setctrl("mrs_real/osrate", getctrl("mrs_real/israte"));
00075
00076 mrs_natural nClasses = getctrl("mrs_natural/nClasses")->to<mrs_natural>();
00077
00078 setctrl("mrs_natural/onObservations", (mrs_natural)2 + nClasses);
00079
00080 mrs_natural mrows = (getctrl("mrs_realvec/means")->to<mrs_realvec>()).getRows();
00081 mrs_natural mcols = (getctrl("mrs_realvec/means")->to<mrs_realvec>()).getCols();
00082 mrs_string mode = getctrl("mrs_string/mode")->to<mrs_string>();
00083
00084 if (active_) {
00085
00086
00087
00088
00089 if ((nClasses != mrows) || (inObservations_ != mcols))
00090 {
00091 MarControlAccessor acc_means(ctrl_means_);
00092 MarControlAccessor acc_covars(ctrl_covars_);
00093
00094 realvec& means = acc_means.to<mrs_realvec>();
00095 realvec& covars = acc_covars.to<mrs_realvec>();
00096
00097 means.create(nClasses, inObservations_);
00098 covars.create(nClasses, inObservations_);
00099 labelSizes_.create(nClasses);
00100 }
00101
00102 if ((prev_mode_ == "train") && (mode == "predict"))
00103 {
00104
00105 MarControlAccessor acc_means(ctrl_means_);
00106 MarControlAccessor acc_covars(ctrl_covars_);
00107
00108 realvec& means = acc_means.to<mrs_realvec>();
00109 realvec& covars = acc_covars.to<mrs_realvec>();
00110
00111
00112 for (int l=0; l < nClasses; l++)
00113 {
00114 for (o=0; o < inObservations_; o++)
00115 {
00116 means(l,o) = means(l,o) / labelSizes_(l);
00117 covars(l,o) = covars(l,o) / labelSizes_(l);
00118 covars(l, o) = covars(l,o) - (means(l,o) * means(l,o));
00119 if (covars(l,o) != 0.0)
00120 {
00121 covars(l,o) = (mrs_real)(1.0 / covars(l,o));
00122 }
00123 }
00124 }
00125 prev_mode_ = mode;
00126 }
00127 }
00128 }
00129
00130 void
00131 GaussianClassifier::myProcess(realvec& in, realvec& out)
00132 {
00133 mrs_natural o,t;
00134 mrs_real v;
00135 mrs_string mode = ctrl_mode_->to<mrs_string>();
00136 mrs_natural nClasses = ctrl_nClasses_->to<mrs_natural>();
00137
00138 mrs_natural l;
00139 mrs_natural prediction = 0;
00140 mrs_real label;
00141
00142 mrs_real diff;
00143 mrs_real sq_sum=0.0;
00144
00145 MarControlAccessor acc_means(ctrl_means_);
00146 MarControlAccessor acc_covars(ctrl_covars_);
00147 realvec& means = acc_means.to<mrs_realvec>();
00148 realvec& covars = acc_covars.to<mrs_realvec>();
00149
00150
00151
00152
00153 if ((prev_mode_ == "predict") && (mode == "train"))
00154 {
00155 means.setval(0.0);
00156 covars.setval(0.0);
00157 labelSizes_.setval(0.0);
00158 }
00159
00160 if (mode == "train")
00161 {
00162 for (t = 0; t < inSamples_; t++)
00163 {
00164 label = in(inObservations_-1, t);
00165 if(label>=0)
00166 {
00167 for (o=0; o < inObservations_-1; o++)
00168 {
00169 v = in(o,t);
00170 means((mrs_natural)label,o) = means((mrs_natural)label,o) + v;
00171 covars((mrs_natural)label,o) = covars((mrs_natural)label,o) + v*v;
00172 out(0,t) = (mrs_real)label;
00173 out(1,t) = (mrs_real)label;
00174 for (int j=0; j < nClasses; j++)
00175 {
00176 out(j,t) = (mrs_real)0;
00177 if (j == label)
00178 out(j,t) = (mrs_real)1;
00179 }
00180 }
00181 labelSizes_((mrs_natural)label) = labelSizes_((mrs_natural)label) + 1;
00182 }
00183 }
00184 }
00185
00186 if (mode == "predict")
00187 {
00188
00189 mrs_real min = MAXREAL;
00190
00191 for (t = 0; t < inSamples_; t++)
00192 {
00193 label = in(inObservations_-1, t);
00194
00195 for (l=0; l < nClasses; l++)
00196 {
00197 sq_sum = 0.0;
00198
00199 for (o=0; o < inObservations_-1; o++)
00200 {
00201 v = in(o,t);
00202 diff = (v - means(l,o));
00203 sq_sum += (diff * covars(l,o) * diff);
00204 }
00205
00206 if (sq_sum < min)
00207 {
00208 min = sq_sum;
00209 prediction = l;
00210 }
00211
00212
00213 out (2+l, t) = sq_sum;
00214 }
00215 out(0,t) = (mrs_real)prediction;
00216 out(1,t) = (mrs_real)label;
00217 }
00218 }
00219
00220 prev_mode_ = mode;
00221 }