00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00036 #include "ClassificationReport.h"
00037
00038
00039 using std::ostringstream;
00040 using std::vector;
00041 using std::cout;
00042 using std::endl;
00043
00044
00045 using namespace Marsyas;
00046
00047 ClassificationReport::ClassificationReport(mrs_string name) : MarSystem("ClassificationReport", name)
00048 {
00049 regCorr.sumClass = 0;
00050 regCorr.sumSqrClass = 0;
00051 regCorr.sumClassPredicted = 0;
00052 regCorr.sumPredicted = 0;
00053 regCorr.sumSqrPredicted = 0;
00054 regCorr.withClass = 0;
00055 addControls();
00056 }
00057
00058
00059 ClassificationReport::~ClassificationReport()
00060 {
00061 }
00062
00063 MarSystem *ClassificationReport::clone() const
00064 {
00065 return new ClassificationReport(*this);
00066 }
00067
00068 void ClassificationReport::addControls()
00069 {
00070 addctrl("mrs_string/mode", "train");
00071 setctrlState("mrs_string/mode", true);
00072 addctrl("mrs_natural/nClasses", 2);
00073 setctrlState("mrs_natural/nClasses", true);
00074 addctrl("mrs_string/classNames", "Music,Speech");
00075 setctrlState("mrs_string/classNames", true);
00076 addctrl("mrs_bool/done", false);
00077 addctrl("mrs_bool/regression", false);
00078 }
00079
00080 void ClassificationReport::myUpdate(MarControlPtr sender)
00081 {
00082 (void) sender;
00083 MRSDIAG("ClassificationReport.cpp - ClassificationReport:myUpdate");
00084
00085 setctrl("mrs_natural/onSamples", getctrl("mrs_natural/inSamples"));
00086 setctrl("mrs_natural/onObservations", (mrs_natural)2);
00087 setctrl("mrs_real/osrate", getctrl("mrs_real/israte"));
00088
00089 mrs_natural nClasses = getctrl("mrs_natural/nClasses")->to<mrs_natural>();
00090 if (confusionMatrix.getRows() != nClasses)
00091 {
00092 confusionMatrix.create(nClasses, nClasses);
00093 }
00094 classNames = getctrl("mrs_string/classNames")->to<mrs_string>();
00095
00096 }
00097
00098 void ClassificationReport::myProcess(realvec& in, realvec& out)
00099 {
00100
00101 static int count = 0;
00102
00103
00104 mrs_natural t;
00105 mrs_string mode = getctrl("mrs_string/mode")->to<mrs_string>();
00106
00107
00108 bool done = getctrl("mrs_bool/done")->to<mrs_bool>();
00109
00110
00111
00112 if ((mode == "train") && !done)
00113 {
00114 for (t=0; t < inSamples_; t++)
00115 {
00116 mrs_real label = in(inObservations_-1, t);
00117 out(0,t) = label;
00118 out(1,t) = label;
00119 }
00120 }
00121 else if ((mode == "predict") && !done)
00122 {
00123 count++;
00124
00125 for (t=0; t < inSamples_; t++)
00126 {
00127 if (getctrl("mrs_bool/regression")->isTrue()) {
00128 mrs_real prediction = in(0, t);
00129 mrs_real actual = in(1, t);
00130
00131 regCorr.sumClass += actual;
00132 regCorr.sumSqrClass += actual*actual;
00133 regCorr.sumClassPredicted += actual*prediction;
00134 regCorr.sumPredicted += prediction;
00135 regCorr.sumSqrPredicted += prediction*prediction;
00136 regCorr.withClass += 1.0;
00137 out(0,t) = prediction;
00138 out(1,t) = actual;
00139 } else {
00140
00141 mrs_natural prediction = (mrs_natural)in(0, t);
00142 mrs_natural actual = (mrs_natural)in(1, t);
00143
00144 confusionMatrix(actual,prediction)++;
00145
00146
00147 out(0,t) = prediction;
00148 out(1,t) = actual;
00149 }
00150 }
00151
00152
00153
00154
00155 }
00156
00157 if (done)
00158 {
00159 if (getctrl("mrs_bool/regression")->isTrue()) {
00160
00161 mrs_real varActual = regCorr.sumSqrClass -
00162 (regCorr.sumClass*regCorr.sumClass) /
00163 regCorr.withClass;
00164 mrs_real varPredicted = regCorr.sumSqrPredicted -
00165 (regCorr.sumPredicted*regCorr.sumPredicted) /
00166 regCorr.withClass;
00167 mrs_real varProd = regCorr.sumClassPredicted -
00168 (regCorr.sumClass*regCorr.sumPredicted) /
00169 regCorr.withClass;
00170
00171 mrs_real correlation;
00172 if (varActual * varPredicted <= 0) {
00173 correlation = 0.0;
00174 } else {
00175 correlation = varProd / sqrt(varActual*varPredicted);
00176 }
00177
00178 mrs_real meanAbsoluteError = 0.0;
00179 mrs_real rootMeanSquaredError = 0.0;
00180 mrs_real relativeAbsoluteError = 0.0;
00181 mrs_real rootRelativeSquaredError = 0.0;
00182 mrs_real instances = 0;
00183 cout << "=== ClassificationReport ===" << endl << endl;
00184 cout << "Correlation coefficient" << "\t\t\t" << correlation << "\t" << endl;
00185 cout << "Mean absolute error" << "\t\t\t" << meanAbsoluteError << endl;
00186 cout << "Root mean squared error" << "\t\t\t" << rootMeanSquaredError << endl;
00187 cout << "Relative absolute error" << "\t\t\t" << relativeAbsoluteError << endl;
00188 cout << "Root relative squared error" << "\t\t" << rootRelativeSquaredError << endl;
00189 cout << "Total Number of Instances" << "\t\t" << instances << endl << endl;
00190 } else {
00191
00192 summaryStatistics stats = computeSummaryStatistics(confusionMatrix);
00193 cout << "=== ClassificationReport ===" << endl << endl;
00194
00195 cout << "Correctly Classified Instances" << "\t\t" << stats.correctInstances << "\t";
00196 cout << (((mrs_real)stats.correctInstances / (mrs_real)stats.instances)*100.0);
00197 cout << " %" << endl;
00198
00199 cout << "Incorrectly Classified Instances" << "\t" << (stats.instances - stats.correctInstances) << "\t";
00200 cout << (((mrs_real)(stats.instances - stats.correctInstances) / (mrs_real)stats.instances)*100.0);
00201 cout << " %" << endl;
00202
00203 cout << "Kappa statistic" << "\t\t\t\t" << stats.kappa << "\t" << endl;
00204 cout << "Mean absolute error" << "\t\t\t" << stats.meanAbsoluteError << endl;
00205 cout << "Root mean squared error" << "\t\t\t" << stats.rootMeanSquaredError << endl;
00206 cout << "Relative absolute error" << "\t\t\t" << stats.relativeAbsoluteError << endl;
00207 cout << "Root relative squared error" << "\t\t" << stats.rootRelativeSquaredError << endl;
00208 cout << "Total Number of Instances" << "\t\t" << stats.instances << endl << endl;
00209
00210 cout << "=== Confusion Matrix ===";
00211 cout << endl; cout << endl;
00212
00213 if(!classNames.size())
00214 classNames = ",";
00215
00216 mrs_string::size_type from = 0;
00217 mrs_string::size_type to = classNames.find(",");
00218
00219 mrs_natural correct = 0;
00220 mrs_natural total = 0;
00221 for (mrs_natural x = 0;x<confusionMatrix.getCols();x++)
00222 cout << "\t" << (char)(x+'a');
00223 cout << "\t" << "<-- classified as";
00224 cout << endl;
00225
00226 for(mrs_natural y = 0;y<confusionMatrix.getRows();y++)
00227 {
00228 for(mrs_natural x = 0;x<confusionMatrix.getCols();x++)
00229 {
00230 mrs_natural value = (mrs_natural)confusionMatrix(y, x);
00231 total += value;
00232 if(x == y)
00233 correct += value;
00234
00235 cout << "\t" << value;
00236 }
00237 cout << "\t" << "| ";
00238 if(from < classNames.size())
00239 {
00240 cout << (char)(y+'a') << " = " << classNames.substr(from, to - from);
00241 from = to + 1;
00242 to = classNames.find(",", from);
00243 if(to == mrs_string::npos)
00244 to = classNames.size();
00245 }
00246 cout << endl;
00247 }
00248 cout << (total > 0 ? correct * 100 / total: 0) << "% classified correctly (" << correct << "/" << total << ")" << endl;
00249 }
00250 }
00251 }
00252
00253 summaryStatistics ClassificationReport::computeSummaryStatistics(const realvec& mat)
00254 {
00255 MRSASSERT(mat.getCols()==mat.getRows());
00256
00257 summaryStatistics stats;
00258
00259 mrs_natural size = mat.getCols();
00260
00261 vector<mrs_natural>rowSums(size);
00262 for(int ii=0; ii<size; ++ii) rowSums[ii] = 0;
00263 vector<mrs_natural>colSums(size);
00264 for(int ii=0; ii<size; ++ii) colSums[ii] = 0;
00265 mrs_natural diagonalSum = 0;
00266
00267 mrs_natural instanceCount = 0;
00268 for(mrs_natural row=0; row<size; row++)
00269 {
00270 for(mrs_natural col=0; col<size; col++)
00271 {
00272 mrs_natural num = (mrs_natural)mat(row,col);
00273 instanceCount += num;
00274
00275 rowSums[row] += num;
00276 colSums[col] += num;
00277
00278 if(row==col)
00279 diagonalSum += num;
00280 }
00281 }
00282
00283
00284
00285
00286
00287
00288
00289 mrs_natural N = instanceCount;
00290 mrs_natural N2 = (N*N);
00291 stats.instances = instanceCount;
00292 stats.correctInstances = diagonalSum;
00293
00294 mrs_natural sum = 0;
00295 for(mrs_natural ii=0; ii<size; ++ii)
00296 {
00297 sum += (rowSums[ii] * colSums[ii]);
00298 }
00299 mrs_real PE = (mrs_real)sum / (mrs_real)N2;
00300 mrs_real PA = (mrs_real)diagonalSum / (mrs_real)N;
00301 stats.kappa = (PA - PE) / (1.0 - PE);
00302
00303 mrs_natural not_diagonal_sum = instanceCount - diagonalSum;
00304 mrs_real MeanAbsoluteError = (mrs_real)not_diagonal_sum / (mrs_real)instanceCount;
00305
00306 stats.meanAbsoluteError = MeanAbsoluteError;
00307
00308 mrs_real RootMeanSquaredError = sqrt(MeanAbsoluteError);
00309
00310 stats.rootMeanSquaredError = RootMeanSquaredError;
00311
00312 mrs_real RelativeAbsoluteError = (MeanAbsoluteError / 0.5) * 100.0;
00313
00314 stats.relativeAbsoluteError = RelativeAbsoluteError;
00315
00316 mrs_real RootRelativeSquaredError = (RootMeanSquaredError / (0.5)) * 100.0;
00317
00318 stats.rootRelativeSquaredError = RootRelativeSquaredError;
00319
00320 return stats;
00321 }