FFNNMultiClassCrossEntropy.cpp
Go to the documentation of this file.
#include<shark/Data/Dataset.h>
#include<shark/Models/FFNet.h>
#include<shark/Models/FFNet.h>
#include<shark/ObjectiveFunctions/ErrorFunction.h>
#include<shark/ObjectiveFunctions/Loss/CrossEntropy.h>
#include<shark/ObjectiveFunctions/Loss/ZeroOneLoss.h>
#include<shark/Algorithms/GradientDescent/Rprop.h>

#include <cmath>
#include <cstddef>
#include <iostream>
9 
10 using namespace shark;
11 using namespace std;
12 
13 // data generating distribution for our toy
14 // multi-category classification problem
15 /// @cond EXAMPLE_SYMBOLS
16 class Problem : public LabeledDataDistribution<RealVector, unsigned int>
17 {
18 private:
19  double m_noise;
20 public:
21  Problem(double noise):m_noise(noise){}
22  void draw(RealVector& input, unsigned int& label)const
23  {
24  label = random::discrete(random::globalRng, 0, 4);
25  input.resize(2);
26  input(0) = m_noise * random::gauss(random::globalRng) + 3.0 * std::cos((double)label);
27  input(1) = m_noise * random::gauss(random::globalRng) + 3.0 * std::sin((double)label);
28  }
29 };
30 /// @endcond
31 
32 int main(){
33  //get problem data
34  Problem problem(1.0);
35  LabeledData<RealVector,unsigned int> training = problem.generateDataset(1000);
36  LabeledData<RealVector,unsigned int> test = problem.generateDataset(100);
37 
38  std::size_t inputs=inputDimension(training);
39  std::size_t outputs = numberOfClasses(training);
40  std::size_t hiddens = 10;
41  unsigned numberOfSteps = 1000;
42 
43  //create network and initialize weights random uniform
44  FFNet<LogisticNeuron,LinearNeuron> network;
45  network.setStructure(inputs,hiddens,outputs);
46  initRandomUniform(network,-0.1,0.1);
47 
48  //create error function
49  CrossEntropy loss;
50  ErrorFunction error(training,&network,&loss);
51 
52  // loss for evaluation
53  // The zeroOneLoss for multiclass problems assigns the class to the highest output
55 
56  // evaluate initial network
57  Data<RealVector> prediction = network(training.inputs());
58  cout << "classification error before learning:\t" << loss01.eval(training.labels(), prediction) << endl;
59 
60  //initialize Rprop
61  IRpropPlus optimizer;
62  error.init();
63  optimizer.init(error);
64 
65  for(unsigned step = 0; step != numberOfSteps; ++step)
66  optimizer.step(error);
67 
68  // evaluate solution found by training
69  network.setParameterVector(optimizer.solution().point); // set weights to weights found by learning
70  prediction = network(training.inputs());
71  cout << "classification error after learning:\t" << loss01(training.labels(), prediction) << endl;
72 }