Add IRIS dataset usage example.

Add LoadModel and SaveModel methods.
davidjacnogueira
2017-03-10 00:20:35 +00:00
parent 2fd468ca63
commit 4005aff367
11 changed files with 659 additions and 57 deletions
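Taken together, the new methods let a trained network be persisted and rebuilt in one call each. Below is a minimal round-trip sketch against the MLP API declared in MLP.h later in this commit; the file path and hyperparameters are illustrative, mirroring src/IrisDatasetTest.cpp.

#include "MLP.h"
#include <string>
#include <vector>

// Sketch: train, save, reload, and query a model with the new API.
// `training_set` is assumed to be prepared as in src/IrisDatasetTest.cpp
// (one bias value already appended to every sample).
void save_load_round_trip(std::vector<TrainingSample> &training_set) {
  // 4 inputs + 1 bias, one hidden layer of 4 nodes, 3 outputs.
  MLP trained({ 4 + 1, 4, 3 }, { "sigmoid", "linear" }, false);
  trained.UpdateMiniBatch(training_set, 0.01, 5000, 0.10, false);

  std::string path("iris.mlp");  // illustrative location
  trained.SaveMLPNetwork(path);  // new in this commit

  MLP restored(path);            // new: build an MLP straight from disk
  std::vector<double> output;
  restored.GetOutput(training_set[0].input_vector(), &output);
}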

@@ -0,0 +1,157 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="14.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <ItemGroup Label="ProjectConfigurations">
    <ProjectConfiguration Include="Debug|Win32">
      <Configuration>Debug</Configuration>
      <Platform>Win32</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Release|Win32">
      <Configuration>Release</Configuration>
      <Platform>Win32</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Debug|x64">
      <Configuration>Debug</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Release|x64">
      <Configuration>Release</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
  </ItemGroup>
  <PropertyGroup Label="Globals">
    <ProjectGuid>{D58D3DD3-DF71-479D-A8EF-C52308C34C11}</ProjectGuid>
    <Keyword>Win32Proj</Keyword>
    <RootNamespace>IrisDatasetTest</RootNamespace>
    <WindowsTargetPlatformVersion>8.1</WindowsTargetPlatformVersion>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
    <ConfigurationType>Application</ConfigurationType>
    <UseDebugLibraries>true</UseDebugLibraries>
    <PlatformToolset>v140</PlatformToolset>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
    <ConfigurationType>Application</ConfigurationType>
    <UseDebugLibraries>false</UseDebugLibraries>
    <PlatformToolset>v140</PlatformToolset>
    <WholeProgramOptimization>true</WholeProgramOptimization>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
    <ConfigurationType>Application</ConfigurationType>
    <UseDebugLibraries>true</UseDebugLibraries>
    <PlatformToolset>v140</PlatformToolset>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
    <ConfigurationType>Application</ConfigurationType>
    <UseDebugLibraries>false</UseDebugLibraries>
    <PlatformToolset>v140</PlatformToolset>
    <WholeProgramOptimization>true</WholeProgramOptimization>
    <CharacterSet>Unicode</CharacterSet>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
  <ImportGroup Label="ExtensionSettings">
  </ImportGroup>
  <ImportGroup Label="Shared">
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <PropertyGroup Label="UserMacros" />
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
    <LinkIncremental>true</LinkIncremental>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <LinkIncremental>true</LinkIncremental>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
    <LinkIncremental>false</LinkIncremental>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <LinkIncremental>false</LinkIncremental>
  </PropertyGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
    <ClCompile>
      <PrecompiledHeader>
      </PrecompiledHeader>
      <WarningLevel>Level3</WarningLevel>
      <Optimization>Disabled</Optimization>
      <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
    </ClCompile>
    <Link>
      <SubSystem>Console</SubSystem>
      <GenerateDebugInformation>true</GenerateDebugInformation>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <ClCompile>
      <PrecompiledHeader>
      </PrecompiledHeader>
      <WarningLevel>Level3</WarningLevel>
      <Optimization>Disabled</Optimization>
      <PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <AdditionalIncludeDirectories>$(SolutionDir)..\deps</AdditionalIncludeDirectories>
    </ClCompile>
    <Link>
      <SubSystem>Console</SubSystem>
      <GenerateDebugInformation>true</GenerateDebugInformation>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <PrecompiledHeader>
      </PrecompiledHeader>
      <Optimization>MaxSpeed</Optimization>
      <FunctionLevelLinking>true</FunctionLevelLinking>
      <IntrinsicFunctions>true</IntrinsicFunctions>
      <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
    </ClCompile>
    <Link>
      <SubSystem>Console</SubSystem>
      <EnableCOMDATFolding>true</EnableCOMDATFolding>
      <OptimizeReferences>true</OptimizeReferences>
      <GenerateDebugInformation>true</GenerateDebugInformation>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <PrecompiledHeader>
      </PrecompiledHeader>
      <Optimization>MaxSpeed</Optimization>
      <FunctionLevelLinking>true</FunctionLevelLinking>
      <IntrinsicFunctions>true</IntrinsicFunctions>
      <PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
      <AdditionalIncludeDirectories>$(SolutionDir)..\deps</AdditionalIncludeDirectories>
    </ClCompile>
    <Link>
      <SubSystem>Console</SubSystem>
      <EnableCOMDATFolding>true</EnableCOMDATFolding>
      <OptimizeReferences>true</OptimizeReferences>
      <GenerateDebugInformation>true</GenerateDebugInformation>
    </Link>
  </ItemDefinitionGroup>
  <ItemGroup>
    <ClCompile Include="..\..\src\IrisDatasetTest.cpp" />
  </ItemGroup>
  <ItemGroup>
    <ProjectReference Include="..\MLP_MVS.vcxproj">
      <Project>{6bfa9d94-b136-4985-83a1-ee76fff6f374}</Project>
    </ProjectReference>
  </ItemGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
  <ImportGroup Label="ExtensionTargets">
  </ImportGroup>
</Project>

@@ -0,0 +1,22 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <ItemGroup>
    <Filter Include="Source Files">
      <UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
      <Extensions>cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
    </Filter>
    <Filter Include="Header Files">
      <UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
      <Extensions>h;hh;hpp;hxx;hm;inl;inc;xsd</Extensions>
    </Filter>
    <Filter Include="Resource Files">
      <UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
      <Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
    </Filter>
  </ItemGroup>
  <ItemGroup>
    <ClCompile Include="..\..\src\IrisDatasetTest.cpp">
      <Filter>Source Files</Filter>
    </ClCompile>
  </ItemGroup>
</Project>

@@ -11,6 +11,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "MLPTest", "MLPTest\MLPTest.
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "LayerTest", "LayerTest\LayerTest.vcxproj", "{10A8D77B-A596-4B06-87DA-B28492D77905}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "IrisDatasetTest", "IrisDatasetTest\IrisDatasetTest.vcxproj", "{D58D3DD3-DF71-479D-A8EF-C52308C34C11}"
EndProject
Global
  GlobalSection(SolutionConfigurationPlatforms) = preSolution
    Debug|x64 = Debug|x64
@@ -51,6 +53,14 @@ Global
    {10A8D77B-A596-4B06-87DA-B28492D77905}.Release|x64.Build.0 = Release|x64
    {10A8D77B-A596-4B06-87DA-B28492D77905}.Release|x86.ActiveCfg = Release|Win32
    {10A8D77B-A596-4B06-87DA-B28492D77905}.Release|x86.Build.0 = Release|Win32
    {D58D3DD3-DF71-479D-A8EF-C52308C34C11}.Debug|x64.ActiveCfg = Debug|x64
    {D58D3DD3-DF71-479D-A8EF-C52308C34C11}.Debug|x64.Build.0 = Debug|x64
    {D58D3DD3-DF71-479D-A8EF-C52308C34C11}.Debug|x86.ActiveCfg = Debug|Win32
    {D58D3DD3-DF71-479D-A8EF-C52308C34C11}.Debug|x86.Build.0 = Debug|Win32
    {D58D3DD3-DF71-479D-A8EF-C52308C34C11}.Release|x64.ActiveCfg = Release|x64
    {D58D3DD3-DF71-479D-A8EF-C52308C34C11}.Release|x64.Build.0 = Release|x64
    {D58D3DD3-DF71-479D-A8EF-C52308C34C11}.Release|x86.ActiveCfg = Release|Win32
    {D58D3DD3-DF71-479D-A8EF-C52308C34C11}.Release|x86.Build.0 = Release|Win32
  EndGlobalSection
  GlobalSection(SolutionProperties) = preSolution
    HideSolutionNode = FALSE

data/iris.data Normal file

@@ -0,0 +1,150 @@
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica

data/iris.mlp Normal file
Binary file not shown.
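The layout of iris.mlp is fixed by SaveMLPNetwork, SaveLayer, and SaveNode later in this diff: three ints (input, output, and hidden-layer counts), then one uint64_t node count per layer, then each layer's counts, activation-name string, and per-node bias/weight doubles. Below is a hedged sketch that dumps only the leading header; field widths assume the same build that wrote the file.

#include <cstdio>

// Sketch: read back the fixed header that MLP::SaveMLPNetwork writes first.
int main() {
  FILE *f = fopen("data/iris.mlp", "rb");  // path relative to the repo root
  if (!f) return 1;
  int num_inputs = 0, num_outputs = 0, num_hidden_layers = 0;
  fread(&num_inputs, sizeof(num_inputs), 1, f);
  fread(&num_outputs, sizeof(num_outputs), 1, f);
  fread(&num_hidden_layers, sizeof(num_hidden_layers), 1, f);
  printf("%d inputs, %d outputs, %d hidden layer(s)\n",
         num_inputs, num_outputs, num_hidden_layers);
  fclose(f);
  return 0;
}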

src/IrisDatasetTest.cpp Normal file

@@ -0,0 +1,157 @@
//============================================================================
// Name : IrisDatasetTest.cpp
// Author : David Nogueira
//============================================================================
#include "MLP.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <iostream>
#include <sstream>
#include <fstream>
#include <vector>
#include <array>
#include <algorithm>
#include "microunit.h"
#include "easylogging++.h"

INITIALIZE_EASYLOGGINGPP

// Example illustrating practical use of this MLP lib.
// Disclaimer: This is NOT an example of good machine learning practices
// regarding training/testing dataset partitioning.
const char *iris_dataset = "../../data/iris.data";
const std::array<std::string, 3> class_names =
  { "Iris-setosa", "Iris-versicolor", "Iris-virginica" };

bool load_data(int *samples,
               std::vector<double> *input,
               std::vector<double> *iris_class) {
  // Load the iris data-set.
  FILE *in = fopen(iris_dataset, "r");
  if (!in) {
    LOG(ERROR) << "Could not open file: " << iris_dataset << ".";
    return false;
  }
  // Loop through the data to get a count.
  char line[1024];
  while (!feof(in) && fgets(line, 1024, in)) {
    ++(*samples);
  }
  fseek(in, 0, SEEK_SET);
  LOG(INFO) << "Loading " << (*samples)
    << " data points from " << iris_dataset << ".";
  // Allocate memory for input and output data.
  input->resize((*samples) * 4);
  iris_class->resize((*samples) * 3);
  // Read the file into our arrays.
  int i, j;
  for (i = 0; i < (*samples); ++i) {
    double *p = &((*input)[0]) + i * 4;
    double *c = &((*iris_class)[0]) + i * 3;
    c[0] = c[1] = c[2] = 0.0;
    fgets(line, 1024, in);
    char *split = strtok(line, ",");
    for (j = 0; j < 4; ++j) {
      p[j] = atof(split);
      split = strtok(0, ",");
    }
    split[strlen(split) - 1] = 0;
    if (strcmp(split, class_names[0].c_str()) == 0) {
      c[0] = 1.0;
    } else if (strcmp(split, class_names[1].c_str()) == 0) {
      c[1] = 1.0;
    } else if (strcmp(split, class_names[2].c_str()) == 0) {
      c[2] = 1.0;
    } else {
      LOG(ERROR) << "Unknown iris_class " << split << ".";
      return false;
    }
  }
  fclose(in);
  return true;
}

int main(int argc, char *argv[]) {
  LOG(INFO) << "Train MLP with IRIS dataset using backpropagation.";
  int samples = 0;
  std::vector<double> input;
  std::vector<double> iris_class;
  // Load the data from file.
  if (!load_data(&samples, &input, &iris_class)) {
    LOG(ERROR) << "Error processing input file.";
    return -1;
  }
  std::vector<TrainingSample> training_set;
  for (int j = 0; j < samples; ++j) {
    std::vector<double> training_set_input;
    std::vector<double> training_set_output;
    training_set_input.reserve(4);
    for (int i = 0; i < 4; i++)
      training_set_input.push_back(*(&(input[0]) + j * 4 + i));
    training_set_output.reserve(3);
    for (int i = 0; i < 3; i++)
      training_set_output.push_back(*(&(iris_class[0]) + j * 3 + i));
    training_set.emplace_back(std::move(training_set_input),
                              std::move(training_set_output));
  }
  std::vector<TrainingSample> training_sample_set_with_bias(std::move(training_set));
  // Set up bias.
  for (auto & training_sample_with_bias : training_sample_set_with_bias) {
    training_sample_with_bias.AddBiasValue(1);
  }
  {
    /* 4 inputs + 1 bias.
     * 1 hidden layer(s) of 4 neurons.
     * 3 outputs (1 per iris_class)
     */
    MLP my_mlp({ 4 + 1, 4, 3 }, { "sigmoid", "linear" }, false);
    int loops = 5000;
    // Train the network with backpropagation.
    LOG(INFO) << "Training for " << loops << " loops over data.";
    my_mlp.UpdateMiniBatch(training_sample_set_with_bias, .01, loops, 0.10, false);
    my_mlp.SaveMLPNetwork(std::string("../../data/iris.mlp"));
  }
  // Destruction/construction of an MLP object to show off saving and loading a trained model.
  {
    MLP my_mlp(std::string("../../data/iris.mlp"));
    int correct = 0;
    for (int j = 0; j < samples; ++j) {
      std::vector<double> guess;
      my_mlp.GetOutput(training_sample_set_with_bias[j].input_vector(), &guess);
      size_t class_id;
      my_mlp.GetOutputClass(guess, &class_id);
      if (iris_class[j * 3 + 0] == 1.0 && class_id == 0) {
        ++correct;
      } else if (iris_class[j * 3 + 1] == 1.0 && class_id == 1) {
        ++correct;
      } else if (iris_class[j * 3 + 2] == 1.0 && class_id == 2) {
        ++correct;
      }
    }
    LOG(INFO) << correct << "/" << samples
      << " (" << ((double)correct / samples * 100.0) << "%).";
  }
  return 0;
}

@@ -47,6 +47,7 @@ public:
    assert(ret_val);
    m_activation_function = (*pair).first;
    m_deriv_activation_function = (*pair).second;
    m_activation_function_str = activation_function;
  };
  ~Layer() {
@@ -114,11 +115,52 @@ public:
    }
  };
  void SaveLayer(FILE * file) const {
    fwrite(&m_num_nodes, sizeof(m_num_nodes), 1, file);
    fwrite(&m_num_inputs_per_node, sizeof(m_num_inputs_per_node), 1, file);
    size_t str_size = m_activation_function_str.size();
    fwrite(&str_size, sizeof(size_t), 1, file);
    fwrite(m_activation_function_str.c_str(), sizeof(char), str_size, file);
    for (int i = 0; i < m_nodes.size(); i++) {
      m_nodes[i].SaveNode(file);
    }
  };
  void LoadLayer(FILE * file) {
    m_nodes.clear();
    fread(&m_num_nodes, sizeof(m_num_nodes), 1, file);
    fread(&m_num_inputs_per_node, sizeof(m_num_inputs_per_node), 1, file);
    size_t str_size = 0;
    fread(&str_size, sizeof(size_t), 1, file);
    m_activation_function_str.resize(str_size);
    fread(&(m_activation_function_str[0]), sizeof(char), str_size, file);
    std::pair<std::function<double(double)>,
              std::function<double(double)> > *pair;
    bool ret_val = utils::ActivationFunctionsManager::Singleton().
      GetActivationFunctionPair(m_activation_function_str,
                                &pair);
    assert(ret_val);
    m_activation_function = (*pair).first;
    m_deriv_activation_function = (*pair).second;
    m_nodes.resize(m_num_nodes);
    for (int i = 0; i < m_nodes.size(); i++) {
      m_nodes[i].LoadNode(file);
    }
  };
protected:
  int m_num_inputs_per_node{ 0 };
  int m_num_nodes{ 0 };
  std::vector<Node> m_nodes;
  std::string m_activation_function_str;
  std::function<double(double)> m_activation_function;
  std::function<double(double)> m_deriv_activation_function;
};
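Because SaveLayer and LoadLayer operate on a plain FILE *, a single layer can be round-tripped through an anonymous temporary stream. A minimal sketch follows; it assumes the header is named Layer.h and that Layer is default-constructible, which the m_layers.resize call in MLP::LoadMLPNetwork below already requires.

#include <cstdio>
#include "Layer.h"  // assumed header name for the Layer class above

// Sketch: serialize one layer and restore it from the same bytes.
void layer_round_trip(const Layer &layer) {
  FILE *tmp = tmpfile();    // anonymous read/write binary stream
  if (!tmp) return;
  layer.SaveLayer(tmp);     // counts, activation-name string, then nodes
  rewind(tmp);              // seek back to the start before reloading
  Layer restored;
  restored.LoadLayer(tmp);  // rebuilds topology, activation, and weights
  fclose(tmp);
}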

@@ -12,11 +12,82 @@
#include <algorithm>
#include "easylogging++.h"
bool MLP::ExportNNWeights(std::vector<double> *weights) const {
  return true;
//desired call syntax : MLP({64*64,20,4}, {"sigmoid", "linear"},
MLP::MLP(const std::vector<uint64_t> & layers_nodes,
         const std::vector<std::string> & layers_activfuncs,
         bool use_constant_weight_init,
         double constant_weight_init) {
  assert(layers_nodes.size() >= 2);
  assert(layers_activfuncs.size() + 1 == layers_nodes.size());
  CreateMLP(layers_nodes,
            layers_activfuncs,
            use_constant_weight_init,
            constant_weight_init);
};
bool MLP::ImportNNWeights(const std::vector<double> & weights) {
  return true;
MLP::MLP(std::string & filename) {
  LoadMLPNetwork(filename);
}
MLP::~MLP() {
  m_num_inputs = 0;
  m_num_outputs = 0;
  m_num_hidden_layers = 0;
  m_layers_nodes.clear();
  m_layers.clear();
};
void MLP::CreateMLP(const std::vector<uint64_t> & layers_nodes,
                    const std::vector<std::string> & layers_activfuncs,
                    bool use_constant_weight_init,
                    double constant_weight_init) {
  m_layers_nodes = layers_nodes;
  m_num_inputs = m_layers_nodes[0];
  m_num_outputs = m_layers_nodes[m_layers_nodes.size() - 1];
  m_num_hidden_layers = m_layers_nodes.size() - 2;
  for (int i = 0; i < m_layers_nodes.size() - 1; i++) {
    m_layers.emplace_back(Layer(m_layers_nodes[i],
                                m_layers_nodes[i + 1],
                                layers_activfuncs[i],
                                use_constant_weight_init,
                                constant_weight_init));
  }
};
void MLP::SaveMLPNetwork(std::string & filename) const {
  FILE * file;
  file = fopen(filename.c_str(), "wb");
  fwrite(&m_num_inputs, sizeof(m_num_inputs), 1, file);
  fwrite(&m_num_outputs, sizeof(m_num_outputs), 1, file);
  fwrite(&m_num_hidden_layers, sizeof(m_num_hidden_layers), 1, file);
  if (!m_layers_nodes.empty())
    fwrite(&m_layers_nodes[0], sizeof(m_layers_nodes[0]), m_layers_nodes.size(), file);
  for (int i = 0; i < m_layers.size(); i++) {
    m_layers[i].SaveLayer(file);
  }
  fclose(file);
};
void MLP::LoadMLPNetwork(std::string & filename) {
  m_layers_nodes.clear();
  m_layers.clear();
  FILE * file;
  file = fopen(filename.c_str(), "rb");
  fread(&m_num_inputs, sizeof(m_num_inputs), 1, file);
  fread(&m_num_outputs, sizeof(m_num_outputs), 1, file);
  fread(&m_num_hidden_layers, sizeof(m_num_hidden_layers), 1, file);
  m_layers_nodes.resize(m_num_hidden_layers + 2);
  if (!m_layers_nodes.empty())
    fread(&m_layers_nodes[0], sizeof(m_layers_nodes[0]), m_layers_nodes.size(), file);
  m_layers.resize(m_layers_nodes.size() - 1);
  for (int i = 0; i < m_layers.size(); i++) {
    m_layers[i].LoadLayer(file);
  }
  fclose(file);
};
void MLP::GetOutput(const std::vector<double> &input,
@@ -80,7 +151,8 @@ void MLP::UpdateWeights(const std::vector<std::vector<double>> & all_layers_acti
void MLP::UpdateMiniBatch(const std::vector<TrainingSample> &training_sample_set_with_bias,
                          double learning_rate,
                          int max_iterations,
                          double min_error_cost) {
                          double min_error_cost,
                          bool output_log) {
  int num_examples = training_sample_set_with_bias.size();
  int num_features = training_sample_set_with_bias[0].GetInputVectorSize();
@@ -102,23 +174,28 @@ void MLP::UpdateMiniBatch(const std::vector<TrainingSample> &training_sample_set
  //  }
  //  }
  //}
  size_t i = 0;
  double current_iteration_cost_function = 0.0;
  for (i = 0; i < max_iterations; i++) {
    current_iteration_cost_function = 0.0;
    for (auto & training_sample_with_bias : training_sample_set_with_bias) {
      std::vector<double> predicted_output;
      std::vector< std::vector<double> > all_layers_activations;
      GetOutput(training_sample_with_bias.input_vector(),
                &predicted_output,
                &all_layers_activations);
      const std::vector<double> & correct_output =
        training_sample_with_bias.output_vector();
      assert(correct_output.size() == predicted_output.size());
      std::vector<double> deriv_error_output(predicted_output.size());
      if ((i % (max_iterations / 100)) == 0) {
      if (output_log && ((i % (max_iterations / 10)) == 0)) {
        std::stringstream temp_training;
        temp_training << training_sample_with_bias << "\t\t";
@@ -144,7 +221,7 @@ void MLP::UpdateMiniBatch(const std::vector<TrainingSample> &training_sample_set
                  learning_rate);
    }
    if ((i % (max_iterations / 100)) == 0)
    if (output_log && ((i % (max_iterations / 10)) == 0))
      LOG(INFO) << "Iteration " << i << " cost function f(error): "
        << current_iteration_cost_function;
    if (current_iteration_cost_function < min_error_cost)
@@ -22,28 +22,13 @@ public:
  //desired call syntax : MLP({64*64,20,4}, {"sigmoid", "linear"},
  MLP(const std::vector<uint64_t> & layers_nodes,
      const std::vector<std::string> & layers_activfuncs,
      bool use_constant_weight_init = true,
      double constant_weight_init = 0.5) {
    assert(layers_nodes.size() >= 2);
    assert(layers_activfuncs.size() + 1 == layers_nodes.size());
      bool use_constant_weight_init = false,
      double constant_weight_init = 0.5);
  MLP(std::string & filename);
  ~MLP();
    CreateMLP(layers_nodes,
              layers_activfuncs,
              use_constant_weight_init,
              constant_weight_init);
  }
  ~MLP() {
    m_num_inputs = 0;
    m_num_outputs = 0;
    m_num_hidden_layers = 0;
    m_layers_nodes.clear();
    m_layers.clear();
  };
  bool ExportNNWeights(std::vector<double> *weights) const;
  bool ImportNNWeights(const std::vector<double> & weights);
  void SaveMLPNetwork(std::string & filename) const;
  void LoadMLPNetwork(std::string & filename);
  void GetOutput(const std::vector<double> &input,
                 std::vector<double> * output,
@@ -53,7 +38,8 @@ public:
  void UpdateMiniBatch(const std::vector<TrainingSample> &training_sample_set_with_bias,
                       double learning_rate,
                       int max_iterations = 5000,
                       double min_error_cost = 0.001);
                       double min_error_cost = 0.001,
                       bool output_log = true);
protected:
  void UpdateWeights(const std::vector<std::vector<double>> & all_layers_activations,
                     const std::vector<double> &error,
@@ -62,20 +48,7 @@ private:
  void CreateMLP(const std::vector<uint64_t> & layers_nodes,
                 const std::vector<std::string> & layers_activfuncs,
                 bool use_constant_weight_init,
                 double constant_weight_init = 0.5) {
    m_layers_nodes = layers_nodes;
    m_num_inputs = m_layers_nodes[0];
    m_num_outputs = m_layers_nodes[m_layers_nodes.size() - 1];
    m_num_hidden_layers = m_layers_nodes.size() - 2;
    for (int i = 0; i < m_layers_nodes.size() - 1; i++) {
      m_layers.emplace_back(Layer(m_layers_nodes[i],
                                  m_layers_nodes[i + 1],
                                  layers_activfuncs[i],
                                  use_constant_weight_init,
                                  constant_weight_init));
    }
  }
                 double constant_weight_init = 0.5);
  int m_num_inputs{ 0 };
  int m_num_outputs{ 0 };
  int m_num_hidden_layers{ 0 };

@@ -39,7 +39,7 @@ UNIT(LearnAND) {
  size_t num_examples = training_sample_set_with_bias.size();
  size_t num_features = training_sample_set_with_bias[0].GetInputVectorSize();
  size_t num_outputs = training_sample_set_with_bias[0].GetOutputVectorSize();
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" }, false);
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" });
  //Train MLP
  my_mlp.UpdateMiniBatch(training_sample_set_with_bias, 0.5, 500, 0.25);
@@ -79,7 +79,7 @@ UNIT(LearnNAND) {
  size_t num_examples = training_sample_set_with_bias.size();
  size_t num_features = training_sample_set_with_bias[0].GetInputVectorSize();
  size_t num_outputs = training_sample_set_with_bias[0].GetOutputVectorSize();
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" }, false);
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" });
  //Train MLP
  my_mlp.UpdateMiniBatch(training_sample_set_with_bias, 0.5, 500, 0.25);
@@ -119,7 +119,7 @@ UNIT(LearnOR) {
  size_t num_examples = training_sample_set_with_bias.size();
  size_t num_features = training_sample_set_with_bias[0].GetInputVectorSize();
  size_t num_outputs = training_sample_set_with_bias[0].GetOutputVectorSize();
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" }, false);
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" });
  //Train MLP
  my_mlp.UpdateMiniBatch(training_sample_set_with_bias, 0.5, 500, 0.25);
@@ -159,7 +159,7 @@ UNIT(LearnNOR) {
  size_t num_examples = training_sample_set_with_bias.size();
  size_t num_features = training_sample_set_with_bias[0].GetInputVectorSize();
  size_t num_outputs = training_sample_set_with_bias[0].GetOutputVectorSize();
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" }, false);
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" });
  //Train MLP
  my_mlp.UpdateMiniBatch(training_sample_set_with_bias, 0.5, 500, 0.25);
@@ -197,7 +197,7 @@ UNIT(LearnXOR) {
  size_t num_examples = training_sample_set_with_bias.size();
  size_t num_features = training_sample_set_with_bias[0].GetInputVectorSize();
  size_t num_outputs = training_sample_set_with_bias[0].GetOutputVectorSize();
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" }, false);
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" });
  //Train MLP
  my_mlp.UpdateMiniBatch(training_sample_set_with_bias, 0.5, 500, 0.25);
@@ -233,7 +233,7 @@ UNIT(LearnNOT) {
  size_t num_examples = training_sample_set_with_bias.size();
  size_t num_features = training_sample_set_with_bias[0].GetInputVectorSize();
  size_t num_outputs = training_sample_set_with_bias[0].GetOutputVectorSize();
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" }, false);
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" });
  //Train MLP
  my_mlp.UpdateMiniBatch(training_sample_set_with_bias, 0.5, 500, 0.25);
@@ -271,7 +271,7 @@ UNIT(LearnX1) {
  size_t num_examples = training_sample_set_with_bias.size();
  size_t num_features = training_sample_set_with_bias[0].GetInputVectorSize();
  size_t num_outputs = training_sample_set_with_bias[0].GetOutputVectorSize();
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" }, false);
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" });
  //Train MLP
  my_mlp.UpdateMiniBatch(training_sample_set_with_bias, 0.5, 500, 0.25);
@@ -309,7 +309,7 @@ UNIT(LearnX2) {
  size_t num_examples = training_sample_set_with_bias.size();
  size_t num_features = training_sample_set_with_bias[0].GetInputVectorSize();
  size_t num_outputs = training_sample_set_with_bias[0].GetOutputVectorSize();
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" }, false);
  MLP my_mlp({ num_features, 2 ,num_outputs }, { "sigmoid", "linear" });
  //Train MLP
  my_mlp.UpdateMiniBatch(training_sample_set_with_bias, 0.5, 500, 0.25);

@@ -109,7 +109,7 @@ public:
                        double threshold = 0.5) const {
    double value;
    GetOutputAfterActivationFunction(input, activation_function, &value);
    *bool_output = (value >threshold) ? true : false;
    *bool_output = (value > threshold) ? true : false;
  };
  void UpdateWeights(const std::vector<double> &x,
@@ -126,6 +126,20 @@ public:
      m_weights[weight_id] += learning_rate*increment;
  }
  void SaveNode(FILE * file) const {
    fwrite(&m_num_inputs, sizeof(m_num_inputs), 1, file);
    fwrite(&m_bias, sizeof(m_bias), 1, file);
    fwrite(&m_weights[0], sizeof(m_weights[0]), m_weights.size(), file);
  };
  void LoadNode(FILE * file) {
    m_weights.clear();
    fread(&m_num_inputs, sizeof(m_num_inputs), 1, file);
    fread(&m_bias, sizeof(m_bias), 1, file);
    m_weights.resize(m_num_inputs);
    fread(&m_weights[0], sizeof(m_weights[0]), m_weights.size(), file);
  };
protected:
  int m_num_inputs{ 0 };
  double m_bias{ 0.0 };