Internal architecture changes (to allow a different activation function for each layer and to allow hidden layers to have different numbers of nodes).

davidjacnogueira
2016-11-09 02:31:53 +00:00
parent f647b05f70
commit 0a636416ed
7 changed files with 193 additions and 100 deletions
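Before this commit, every hidden layer had the same width (m_num_nodes_per_hidden_layer) and every layer used a hard-coded sigmoid. The diff below moves both decisions into per-layer state such as m_layers_nodes. As a rough sketch of the kind of configuration this enables — only m_layers_nodes appears in the diff; the LayerSpec type and its field names are illustrative assumptions, not the repository's API:

#include <cstddef>
#include <string>
#include <vector>

// Hypothetical per-layer description; not the repository's actual types.
struct LayerSpec {
  size_t num_nodes;        // each layer may have its own width
  std::string activation;  // ...and its own activation function
};

// Two hidden layers of different sizes with different activations,
// instead of one m_num_nodes_per_hidden_layer plus a fixed sigmoid:
std::vector<LayerSpec> layers_spec = {
  {8, "sigmoid"},  // hidden layer 1
  {4, "tanh"},     // hidden layer 2
  {1, "sigmoid"}   // output layer
};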


@@ -27,14 +27,13 @@ void MLP::GetOutput(const std::vector<double> &input,
   if (m_num_hidden_layers == 0)
     temp_size = m_num_outputs;
   else
-    temp_size = m_num_nodes_per_hidden_layer;
+    temp_size = m_layers_nodes[1];
   std::vector<double> temp_in(m_num_inputs, 0.0);
   std::vector<double> temp_out(temp_size, 0.0);
   temp_in = input;
-  //m_layers.size() equals (m_num_hidden_layers + 1)
-  for (int i = 0; i < (m_num_hidden_layers + 1); ++i) {
+  for (int i = 0; i < m_layers.size(); ++i) {
     if (i > 0) {
       //Store this layer activation
       if (all_layers_activations != nullptr)
@@ -43,11 +42,9 @@ void MLP::GetOutput(const std::vector<double> &input,
       temp_in.clear();
       temp_in = temp_out;
       temp_out.clear();
-      temp_out.resize((i == m_num_hidden_layers) ?
-                      m_num_outputs :
-                      m_num_nodes_per_hidden_layer);
+      temp_out.resize(m_layers[i].GetOutputSize());
     }
-    m_layers[i].GetOutputAfterSigmoid(temp_in, &temp_out);
+    m_layers[i].GetOutputAfterActivationFunction(temp_in, &temp_out);
   }
   if (temp_out.size() > 1)
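The two hunks above make the forward pass ask each layer for its own output size and activation instead of relying on network-wide constants. A minimal sketch of the Layer interface this implies, assuming dense weights and a std::function-based activation — only GetOutputSize and GetOutputAfterActivationFunction are taken from the diff; the member names and the default sigmoid are assumptions:

#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Illustrative sketch only: GetOutputSize() and
// GetOutputAfterActivationFunction() appear in the diff above;
// everything else here is assumed.
class Layer {
 public:
  // Width of this layer; drives the temp_out.resize() call in GetOutput().
  size_t GetOutputSize() const { return m_num_nodes; }

  // Weighted sum of the inputs followed by this layer's own activation,
  // generalizing the previous hard-coded GetOutputAfterSigmoid().
  void GetOutputAfterActivationFunction(const std::vector<double> &input,
                                        std::vector<double> *output) const {
    for (size_t n = 0; n < m_num_nodes; ++n) {
      double sum = 0.0;
      for (size_t j = 0; j < input.size(); ++j)
        sum += m_weights[n][j] * input[j];
      (*output)[n] = m_activation(sum);
    }
  }

 private:
  size_t m_num_nodes = 0;
  std::vector<std::vector<double>> m_weights;   // one weight row per node
  std::function<double(double)> m_activation =  // per-layer activation
      [](double x) { return 1.0 / (1.0 + std::exp(-x)); };
};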
@@ -106,8 +103,9 @@ void MLP::UpdateMiniBatch(const std::vector<TrainingSample> &training_sample_set
   //  }
   //}
   size_t i = 0;
+  double current_iteration_cost_function = 0.0;
   for (i = 0; i < max_iterations; i++) {
-    double current_iteration_cost_function = 0.0;
+    current_iteration_cost_function = 0.0;
     for (auto & training_sample_with_bias : training_sample_set_with_bias) {
       std::vector<double> predicted_output;
       std::vector< std::vector<double> > all_layers_activations;
@@ -153,7 +151,10 @@ void MLP::UpdateMiniBatch(const std::vector<TrainingSample> &training_sample_set
       break;
     }
     LOG(INFO) << "******************************";
     LOG(INFO) << "Iteration " << i << " cost function f(error): "
       << current_iteration_cost_function;
     LOG(INFO) << "******************************";
   }
+  LOG(INFO) << "******* TRAINING ENDED *******";
+  LOG(INFO) << "******* " << i << " iters *******";
+  LOG(INFO) << "******************************";