mirror of
https://github.com/davidalbertonogueira/MLP.git
synced 2025-12-17 12:24:40 +03:00
Bug fix in Linux compilation.
This commit is contained in:
@@ -21,8 +21,16 @@ INITIALIZE_EASYLOGGINGPP
|
|||||||
// Disclaimer: This is NOT an example of good machine learning practices
|
// Disclaimer: This is NOT an example of good machine learning practices
|
||||||
// regarding training/testing dataset partitioning.
|
// regarding training/testing dataset partitioning.
|
||||||
|
|
||||||
|
const int input_size = 4;
|
||||||
|
const int number_classes = 3;
|
||||||
|
#if defined(_WIN32)
|
||||||
const char *iris_dataset = "../../data/iris.data";
|
const char *iris_dataset = "../../data/iris.data";
|
||||||
const std::array<std::string, 3> class_names =
|
const std::string iris_mlp_weights = "../../data/iris.mlp";
|
||||||
|
#else
|
||||||
|
const char *iris_dataset = "./data/iris.data";
|
||||||
|
const std::string iris_mlp_weights = "./data/iris.mlp";
|
||||||
|
#endif
|
||||||
|
const std::array<std::string, number_classes> class_names =
|
||||||
{ "Iris-setosa", "Iris-versicolor", "Iris-virginica" };
|
{ "Iris-setosa", "Iris-versicolor", "Iris-virginica" };
|
||||||
|
|
||||||
|
|
||||||
@@ -46,32 +54,41 @@ bool load_data(int *samples,
|
|||||||
LOG(INFO) << "Loading " << (*samples)
|
LOG(INFO) << "Loading " << (*samples)
|
||||||
<< " data points from " << iris_dataset << ".";
|
<< " data points from " << iris_dataset << ".";
|
||||||
// Allocate memory for input and output data.
|
// Allocate memory for input and output data.
|
||||||
input->resize((*samples) * 4);
|
input->resize((*samples) * input_size);
|
||||||
iris_class->resize((*samples) * 3);
|
iris_class->resize((*samples) * number_classes);
|
||||||
|
|
||||||
// Read the file into our arrays.
|
// Read the file into our arrays.
|
||||||
int i, j;
|
int i, j;
|
||||||
for (i = 0; i < (*samples); ++i) {
|
for (i = 0; i < (*samples); ++i) {
|
||||||
double *p = &((*input)[0]) + i * 4;
|
double *p = &((*input)[0]) + i * input_size;
|
||||||
double *c = &((*iris_class)[0]) + i * 3;
|
double *c = &((*iris_class)[0]) + i * number_classes;
|
||||||
c[0] = c[1] = c[2] = 0.0;
|
for (int k = 0; k < number_classes; k++) {
|
||||||
|
c[i] = 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
fgets(line, 1024, in);
|
fgets(line, 1024, in);
|
||||||
|
|
||||||
char *split = strtok(line, ",");
|
char *split = strtok(line, ",");
|
||||||
for (j = 0; j < 4; ++j) {
|
for (j = 0; j < 4; ++j) {
|
||||||
p[j] = atof(split);
|
p[j] = atof(split);
|
||||||
split = strtok(0, ",");
|
split = strtok(NULL, ",");
|
||||||
}
|
}
|
||||||
|
|
||||||
split[strlen(split) - 1] = 0;
|
if (strlen(split) >= 1 && split[strlen(split) - 1] == '\n')
|
||||||
|
split[strlen(split) - 1] = '\0';
|
||||||
|
if (strlen(split) >= 2 && split[strlen(split) - 2] == '\r')
|
||||||
|
split[strlen(split) - 2] = '\0';
|
||||||
|
|
||||||
if (strcmp(split, class_names[0].c_str()) == 0) {
|
if (strcmp(split, class_names[0].c_str()) == 0) {
|
||||||
c[0] = 1.0;
|
c[0] = 1.0;
|
||||||
} else if (strcmp(split, class_names[1].c_str()) == 0) {
|
}
|
||||||
|
else if (strcmp(split, class_names[1].c_str()) == 0) {
|
||||||
c[1] = 1.0;
|
c[1] = 1.0;
|
||||||
} else if (strcmp(split, class_names[2].c_str()) == 0) {
|
}
|
||||||
|
else if (strcmp(split, class_names[2].c_str()) == 0) {
|
||||||
c[2] = 1.0;
|
c[2] = 1.0;
|
||||||
} else {
|
}
|
||||||
|
else {
|
||||||
LOG(ERROR) << "Unknown iris_class " << split
|
LOG(ERROR) << "Unknown iris_class " << split
|
||||||
<< ".";
|
<< ".";
|
||||||
return false;
|
return false;
|
||||||
@@ -126,11 +143,11 @@ int main(int argc, char *argv[]) {
|
|||||||
LOG(INFO) << "Training for " << loops << " loops over data.";
|
LOG(INFO) << "Training for " << loops << " loops over data.";
|
||||||
my_mlp.Train(training_sample_set_with_bias, .01, loops, 0.10, false);
|
my_mlp.Train(training_sample_set_with_bias, .01, loops, 0.10, false);
|
||||||
|
|
||||||
my_mlp.SaveMLPNetwork(std::string("../../data/iris.mlp"));
|
my_mlp.SaveMLPNetwork(iris_mlp_weights);
|
||||||
}
|
}
|
||||||
//Destruction/Construction of a MLP object to show off saving and loading a trained model
|
//Destruction/Construction of a MLP object to show off saving and loading a trained model
|
||||||
{
|
{
|
||||||
MLP my_mlp(std::string("../../data/iris.mlp"));
|
MLP my_mlp(iris_mlp_weights);
|
||||||
|
|
||||||
int correct = 0;
|
int correct = 0;
|
||||||
for (int j = 0; j < samples; ++j) {
|
for (int j = 0; j < samples; ++j) {
|
||||||
@@ -141,9 +158,11 @@ int main(int argc, char *argv[]) {
|
|||||||
|
|
||||||
if (iris_class[j * 3 + 0] == 1.0 && class_id == 0) {
|
if (iris_class[j * 3 + 0] == 1.0 && class_id == 0) {
|
||||||
++correct;
|
++correct;
|
||||||
} else if (iris_class[j * 3 + 1] == 1.0 && class_id == 1) {
|
}
|
||||||
|
else if (iris_class[j * 3 + 1] == 1.0 && class_id == 1) {
|
||||||
++correct;
|
++correct;
|
||||||
} else if (iris_class[j * 3 + 2] == 1.0 && class_id == 2) {
|
}
|
||||||
|
else if (iris_class[j * 3 + 2] == 1.0 && class_id == 2) {
|
||||||
++correct;
|
++correct;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ MLP::MLP(const std::vector<uint64_t> & layers_nodes,
|
|||||||
constant_weight_init);
|
constant_weight_init);
|
||||||
};
|
};
|
||||||
|
|
||||||
MLP::MLP(std::string & filename) {
|
MLP::MLP(const std::string & filename) {
|
||||||
LoadMLPNetwork(filename);
|
LoadMLPNetwork(filename);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ void MLP::CreateMLP(const std::vector<uint64_t> & layers_nodes,
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
void MLP::SaveMLPNetwork(std::string & filename)const {
|
void MLP::SaveMLPNetwork(const std::string & filename)const {
|
||||||
FILE * file;
|
FILE * file;
|
||||||
file = fopen(filename.c_str(), "wb");
|
file = fopen(filename.c_str(), "wb");
|
||||||
fwrite(&m_num_inputs, sizeof(m_num_inputs), 1, file);
|
fwrite(&m_num_inputs, sizeof(m_num_inputs), 1, file);
|
||||||
@@ -71,7 +71,7 @@ void MLP::SaveMLPNetwork(std::string & filename)const {
|
|||||||
}
|
}
|
||||||
fclose(file);
|
fclose(file);
|
||||||
};
|
};
|
||||||
void MLP::LoadMLPNetwork(std::string & filename) {
|
void MLP::LoadMLPNetwork(const std::string & filename) {
|
||||||
m_layers_nodes.clear();
|
m_layers_nodes.clear();
|
||||||
m_layers.clear();
|
m_layers.clear();
|
||||||
|
|
||||||
|
|||||||
@@ -24,11 +24,11 @@ public:
|
|||||||
const std::vector<std::string> & layers_activfuncs,
|
const std::vector<std::string> & layers_activfuncs,
|
||||||
bool use_constant_weight_init = false,
|
bool use_constant_weight_init = false,
|
||||||
double constant_weight_init = 0.5);
|
double constant_weight_init = 0.5);
|
||||||
MLP(std::string & filename);
|
MLP(const std::string & filename);
|
||||||
~MLP();
|
~MLP();
|
||||||
|
|
||||||
void SaveMLPNetwork(std::string & filename)const;
|
void SaveMLPNetwork(const std::string & filename)const;
|
||||||
void LoadMLPNetwork(std::string & filename);
|
void LoadMLPNetwork(const std::string & filename);
|
||||||
|
|
||||||
void GetOutput(const std::vector<double> &input,
|
void GetOutput(const std::vector<double> &input,
|
||||||
std::vector<double> * output,
|
std::vector<double> * output,
|
||||||
|
|||||||
Reference in New Issue
Block a user