Bug fix in Linux compilation.

This commit is contained in:
davidjacnogueira
2018-03-31 02:59:42 +01:00
parent f924edd50a
commit 4f55a15543
3 changed files with 43 additions and 24 deletions

View File

@@ -21,8 +21,16 @@ INITIALIZE_EASYLOGGINGPP
// Disclaimer: This is NOT an example of good machine learning practices // Disclaimer: This is NOT an example of good machine learning practices
// regarding training/testing dataset partitioning. // regarding training/testing dataset partitioning.
const int input_size = 4;
const int number_classes = 3;
#if defined(_WIN32)
const char *iris_dataset = "../../data/iris.data"; const char *iris_dataset = "../../data/iris.data";
const std::array<std::string, 3> class_names = const std::string iris_mlp_weights = "../../data/iris.mlp";
#else
const char *iris_dataset = "./data/iris.data";
const std::string iris_mlp_weights = "./data/iris.mlp";
#endif
const std::array<std::string, number_classes> class_names =
{ "Iris-setosa", "Iris-versicolor", "Iris-virginica" }; { "Iris-setosa", "Iris-versicolor", "Iris-virginica" };
@@ -46,32 +54,41 @@ bool load_data(int *samples,
LOG(INFO) << "Loading " << (*samples) LOG(INFO) << "Loading " << (*samples)
<< " data points from " << iris_dataset << "."; << " data points from " << iris_dataset << ".";
// Allocate memory for input and output data. // Allocate memory for input and output data.
input->resize((*samples) * 4); input->resize((*samples) * input_size);
iris_class->resize((*samples) * 3); iris_class->resize((*samples) * number_classes);
// Read the file into our arrays. // Read the file into our arrays.
int i, j; int i, j;
for (i = 0; i < (*samples); ++i) { for (i = 0; i < (*samples); ++i) {
double *p = &((*input)[0]) + i * 4; double *p = &((*input)[0]) + i * input_size;
double *c = &((*iris_class)[0]) + i * 3; double *c = &((*iris_class)[0]) + i * number_classes;
c[0] = c[1] = c[2] = 0.0; for (int k = 0; k < number_classes; k++) {
c[i] = 0.0;
}
fgets(line, 1024, in); fgets(line, 1024, in);
char *split = strtok(line, ","); char *split = strtok(line, ",");
for (j = 0; j < 4; ++j) { for (j = 0; j < 4; ++j) {
p[j] = atof(split); p[j] = atof(split);
split = strtok(0, ","); split = strtok(NULL, ",");
} }
split[strlen(split) - 1] = 0; if (strlen(split) >= 1 && split[strlen(split) - 1] == '\n')
split[strlen(split) - 1] = '\0';
if (strlen(split) >= 2 && split[strlen(split) - 2] == '\r')
split[strlen(split) - 2] = '\0';
if (strcmp(split, class_names[0].c_str()) == 0) { if (strcmp(split, class_names[0].c_str()) == 0) {
c[0] = 1.0; c[0] = 1.0;
} else if (strcmp(split, class_names[1].c_str()) == 0) { }
else if (strcmp(split, class_names[1].c_str()) == 0) {
c[1] = 1.0; c[1] = 1.0;
} else if (strcmp(split, class_names[2].c_str()) == 0) { }
else if (strcmp(split, class_names[2].c_str()) == 0) {
c[2] = 1.0; c[2] = 1.0;
} else { }
else {
LOG(ERROR) << "Unknown iris_class " << split LOG(ERROR) << "Unknown iris_class " << split
<< "."; << ".";
return false; return false;
@@ -126,11 +143,11 @@ int main(int argc, char *argv[]) {
LOG(INFO) << "Training for " << loops << " loops over data."; LOG(INFO) << "Training for " << loops << " loops over data.";
my_mlp.Train(training_sample_set_with_bias, .01, loops, 0.10, false); my_mlp.Train(training_sample_set_with_bias, .01, loops, 0.10, false);
my_mlp.SaveMLPNetwork(std::string("../../data/iris.mlp")); my_mlp.SaveMLPNetwork(iris_mlp_weights);
} }
//Destruction/Construction of a MLP object to show off saving and loading a trained model //Destruction/Construction of a MLP object to show off saving and loading a trained model
{ {
MLP my_mlp(std::string("../../data/iris.mlp")); MLP my_mlp(iris_mlp_weights);
int correct = 0; int correct = 0;
for (int j = 0; j < samples; ++j) { for (int j = 0; j < samples; ++j) {
@@ -141,9 +158,11 @@ int main(int argc, char *argv[]) {
if (iris_class[j * 3 + 0] == 1.0 && class_id == 0) { if (iris_class[j * 3 + 0] == 1.0 && class_id == 0) {
++correct; ++correct;
} else if (iris_class[j * 3 + 1] == 1.0 && class_id == 1) { }
else if (iris_class[j * 3 + 1] == 1.0 && class_id == 1) {
++correct; ++correct;
} else if (iris_class[j * 3 + 2] == 1.0 && class_id == 2) { }
else if (iris_class[j * 3 + 2] == 1.0 && class_id == 2) {
++correct; ++correct;
} }
} }

View File

@@ -28,7 +28,7 @@ MLP::MLP(const std::vector<uint64_t> & layers_nodes,
constant_weight_init); constant_weight_init);
}; };
MLP::MLP(std::string & filename) { MLP::MLP(const std::string & filename) {
LoadMLPNetwork(filename); LoadMLPNetwork(filename);
} }
@@ -58,7 +58,7 @@ void MLP::CreateMLP(const std::vector<uint64_t> & layers_nodes,
} }
}; };
void MLP::SaveMLPNetwork(std::string & filename)const { void MLP::SaveMLPNetwork(const std::string & filename)const {
FILE * file; FILE * file;
file = fopen(filename.c_str(), "wb"); file = fopen(filename.c_str(), "wb");
fwrite(&m_num_inputs, sizeof(m_num_inputs), 1, file); fwrite(&m_num_inputs, sizeof(m_num_inputs), 1, file);
@@ -71,7 +71,7 @@ void MLP::SaveMLPNetwork(std::string & filename)const {
} }
fclose(file); fclose(file);
}; };
void MLP::LoadMLPNetwork(std::string & filename) { void MLP::LoadMLPNetwork(const std::string & filename) {
m_layers_nodes.clear(); m_layers_nodes.clear();
m_layers.clear(); m_layers.clear();

View File

@@ -24,11 +24,11 @@ public:
const std::vector<std::string> & layers_activfuncs, const std::vector<std::string> & layers_activfuncs,
bool use_constant_weight_init = false, bool use_constant_weight_init = false,
double constant_weight_init = 0.5); double constant_weight_init = 0.5);
MLP(std::string & filename); MLP(const std::string & filename);
~MLP(); ~MLP();
void SaveMLPNetwork(std::string & filename)const; void SaveMLPNetwork(const std::string & filename)const;
void LoadMLPNetwork(std::string & filename); void LoadMLPNetwork(const std::string & filename);
void GetOutput(const std::vector<double> &input, void GetOutput(const std::vector<double> &input,
std::vector<double> * output, std::vector<double> * output,