mirror of
https://github.com/davidalbertonogueira/MLP.git
synced 2025-12-16 20:07:07 +03:00
WIP project still in dev.
MISC Remove bias from examples. Change in the license type. Update legal instrument.
This commit is contained in:
229
.gitignore
vendored
229
.gitignore
vendored
@@ -1,28 +1,213 @@
|
||||
# Compiled Object files
|
||||
*.slo
|
||||
*.lo
|
||||
*.o
|
||||
## Ignore Visual Studio temporary files, build results, and
|
||||
## files generated by popular Visual Studio add-ons.
|
||||
|
||||
# User-specific files
|
||||
*.suo
|
||||
*.user
|
||||
*.userosscache
|
||||
*.sln.docstates
|
||||
|
||||
# User-specific files (MonoDevelop/Xamarin Studio)
|
||||
*.userprefs
|
||||
|
||||
# Build results
|
||||
[Dd]ebug/
|
||||
[Dd]ebugPublic/
|
||||
[Rr]elease/
|
||||
[Rr]eleases/
|
||||
x64/
|
||||
x86/
|
||||
build/
|
||||
bld/
|
||||
[Bb]in/
|
||||
[Oo]bj/
|
||||
|
||||
# Visual Studio 2015 cache/options directory
|
||||
.vs/
|
||||
|
||||
# MSTest test Results
|
||||
[Tt]est[Rr]esult*/
|
||||
[Bb]uild[Ll]og.*
|
||||
|
||||
# NUNIT
|
||||
*.VisualState.xml
|
||||
TestResult.xml
|
||||
|
||||
# Build Results of an ATL Project
|
||||
[Dd]ebugPS/
|
||||
[Rr]eleasePS/
|
||||
dlldata.c
|
||||
|
||||
# DNX
|
||||
project.lock.json
|
||||
artifacts/
|
||||
|
||||
*_i.c
|
||||
*_p.c
|
||||
*_i.h
|
||||
*.ilk
|
||||
*.meta
|
||||
*.obj
|
||||
|
||||
# Precompiled Headers
|
||||
*.gch
|
||||
*.pch
|
||||
*.pdb
|
||||
*.pgc
|
||||
*.pgd
|
||||
*.rsp
|
||||
*.sbr
|
||||
*.tlb
|
||||
*.tli
|
||||
*.tlh
|
||||
*.tmp
|
||||
*.tmp_proj
|
||||
*.log
|
||||
*.vspscc
|
||||
*.vssscc
|
||||
.builds
|
||||
*.pidb
|
||||
*.svclog
|
||||
*.scc
|
||||
|
||||
# Compiled Dynamic libraries
|
||||
*.so
|
||||
*.dylib
|
||||
*.dll
|
||||
# Chutzpah Test files
|
||||
_Chutzpah*
|
||||
|
||||
# Fortran module files
|
||||
*.mod
|
||||
# Visual C++ cache files
|
||||
ipch/
|
||||
*.aps
|
||||
*.ncb
|
||||
*.opendb
|
||||
*.opensdf
|
||||
*.sdf
|
||||
*.cachefile
|
||||
|
||||
# Compiled Static libraries
|
||||
*.lai
|
||||
*.la
|
||||
*.a
|
||||
*.lib
|
||||
# Visual Studio profiler
|
||||
*.psess
|
||||
*.vsp
|
||||
*.vspx
|
||||
|
||||
# Executables
|
||||
*.exe
|
||||
*.out
|
||||
*.app
|
||||
# TFS 2012 Local Workspace
|
||||
$tf/
|
||||
|
||||
# Guidance Automation Toolkit
|
||||
*.gpState
|
||||
|
||||
# ReSharper is a .NET coding add-in
|
||||
_ReSharper*/
|
||||
*.[Rr]e[Ss]harper
|
||||
*.DotSettings.user
|
||||
|
||||
# JustCode is a .NET coding add-in
|
||||
.JustCode
|
||||
|
||||
# TeamCity is a build add-in
|
||||
_TeamCity*
|
||||
|
||||
# DotCover is a Code Coverage Tool
|
||||
*.dotCover
|
||||
|
||||
# NCrunch
|
||||
_NCrunch_*
|
||||
.*crunch*.local.xml
|
||||
|
||||
# MightyMoose
|
||||
*.mm.*
|
||||
AutoTest.Net/
|
||||
|
||||
# Web workbench (sass)
|
||||
.sass-cache/
|
||||
|
||||
# Installshield output folder
|
||||
[Ee]xpress/
|
||||
|
||||
# DocProject is a documentation generator add-in
|
||||
DocProject/buildhelp/
|
||||
DocProject/Help/*.HxT
|
||||
DocProject/Help/*.HxC
|
||||
DocProject/Help/*.hhc
|
||||
DocProject/Help/*.hhk
|
||||
DocProject/Help/*.hhp
|
||||
DocProject/Help/Html2
|
||||
DocProject/Help/html
|
||||
|
||||
# Click-Once directory
|
||||
publish/
|
||||
|
||||
# Publish Web Output
|
||||
*.[Pp]ublish.xml
|
||||
*.azurePubxml
|
||||
## TODO: Comment the next line if you want to checkin your
|
||||
## web deploy settings but do note that will include unencrypted
|
||||
## passwords
|
||||
#*.pubxml
|
||||
|
||||
*.publishproj
|
||||
|
||||
# NuGet Packages
|
||||
*.nupkg
|
||||
# The packages folder can be ignored because of Package Restore
|
||||
**/packages/*
|
||||
# except build/, which is used as an MSBuild target.
|
||||
!**/packages/build/
|
||||
# Uncomment if necessary however generally it will be regenerated when needed
|
||||
#!**/packages/repositories.config
|
||||
|
||||
# Windows Azure Build Output
|
||||
csx/
|
||||
*.build.csdef
|
||||
|
||||
# Windows Store app package directory
|
||||
AppPackages/
|
||||
|
||||
# Visual Studio cache files
|
||||
# files ending in .cache can be ignored
|
||||
*.[Cc]ache
|
||||
# but keep track of directories ending in .cache
|
||||
!*.[Cc]ache/
|
||||
|
||||
# Others
|
||||
ClientBin/
|
||||
[Ss]tyle[Cc]op.*
|
||||
~$*
|
||||
*~
|
||||
*.dbmdl
|
||||
*.dbproj.schemaview
|
||||
*.pfx
|
||||
*.publishsettings
|
||||
node_modules/
|
||||
orleans.codegen.cs
|
||||
|
||||
# RIA/Silverlight projects
|
||||
Generated_Code/
|
||||
|
||||
# Backup & report files from converting an old project file
|
||||
# to a newer Visual Studio version. Backup files are not needed,
|
||||
# because we have git ;-)
|
||||
_UpgradeReport_Files/
|
||||
Backup*/
|
||||
UpgradeLog*.XML
|
||||
UpgradeLog*.htm
|
||||
|
||||
# SQL Server files
|
||||
*.mdf
|
||||
*.ldf
|
||||
|
||||
# Business Intelligence projects
|
||||
*.rdl.data
|
||||
*.bim.layout
|
||||
*.bim_*.settings
|
||||
|
||||
# Microsoft Fakes
|
||||
FakesAssemblies/
|
||||
|
||||
# Node.js Tools for Visual Studio
|
||||
.ntvs_analysis.dat
|
||||
|
||||
# Visual Studio 6 build log
|
||||
*.plg
|
||||
|
||||
# Visual Studio 6 workspace options file
|
||||
*.opt
|
||||
|
||||
# LightSwitch generated files
|
||||
GeneratedArtifacts/
|
||||
_Pvt_Extensions/
|
||||
ModelManifest.xml
|
||||
|
||||
29
LICENSE
29
LICENSE
@@ -1,21 +1,16 @@
|
||||
The MIT License (MIT)
|
||||
Copyright (c) 2016, David Alberto Jácome Nogueira
|
||||
All rights reserved.
|
||||
|
||||
Copyright (c) 2016
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
--
|
||||
Additional Disclaimer:
|
||||
This project may include other projects which may contain their own licenses in their respective folders and/or in the respective header files. Please check 'deps' folder for further legal instruments.
|
||||
@@ -88,6 +88,8 @@
|
||||
<WarningLevel>Level3</WarningLevel>
|
||||
<Optimization>Disabled</Optimization>
|
||||
<PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<AdditionalIncludeDirectories>
|
||||
</AdditionalIncludeDirectories>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
@@ -101,6 +103,7 @@
|
||||
<WarningLevel>Level3</WarningLevel>
|
||||
<Optimization>Disabled</Optimization>
|
||||
<PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<AdditionalIncludeDirectories>$(SolutionDir)..\deps</AdditionalIncludeDirectories>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
@@ -116,6 +119,8 @@
|
||||
<FunctionLevelLinking>true</FunctionLevelLinking>
|
||||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<AdditionalIncludeDirectories>
|
||||
</AdditionalIncludeDirectories>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
@@ -133,6 +138,7 @@
|
||||
<FunctionLevelLinking>true</FunctionLevelLinking>
|
||||
<IntrinsicFunctions>true</IntrinsicFunctions>
|
||||
<PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<AdditionalIncludeDirectories>$(SolutionDir)..\deps</AdditionalIncludeDirectories>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<SubSystem>Console</SubSystem>
|
||||
@@ -142,6 +148,7 @@
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="..\deps\microunit.h" />
|
||||
<ClInclude Include="..\src\Layer.h" />
|
||||
<ClInclude Include="..\src\MLP.h" />
|
||||
<ClInclude Include="..\src\Node.h" />
|
||||
|
||||
@@ -30,6 +30,9 @@
|
||||
<ClInclude Include="..\src\Node.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="..\deps\microunit.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="..\src\Main.cpp">
|
||||
|
||||
@@ -3,4 +3,4 @@
|
||||
Simple multilayer perceptron c++ implementation.
|
||||
|
||||
|
||||
David Nogueira, 2016.01.16
|
||||
David Nogueira, 2016.03.26
|
||||
|
||||
368
deps/microunit.h
vendored
Normal file
368
deps/microunit.h
vendored
Normal file
@@ -0,0 +1,368 @@
|
||||
/**
|
||||
* @file microunit.h
|
||||
* @author Sebastiao Salvador de Miranda (ssm)
|
||||
* @brief Tiny library for cpp unit testing. Should work on any c++11 compiler.
|
||||
*
|
||||
* Simply include this header in your test implementation file (e.g., main.cpp)
|
||||
* and call microunit::UnitTester::Run() in the function main(). To register
|
||||
* a new unit test case, use the macro UNIT (See the example below). Inside the
|
||||
* test case body, you can use the following macros to control the result
|
||||
* of the test.
|
||||
*
|
||||
* @li PASS() : Pass the test and return.
|
||||
* @li FAIL() : Fail the test and return.
|
||||
* @li ASSERT_TRUE(condition) : If the condition does not hold, fail and return.
|
||||
* @li ASSERT_FALSE(condition) : If the condition holds, fail and return.
|
||||
*
|
||||
* @code{.cpp}
|
||||
* UNIT(Test_Two_Plus_Two) {
|
||||
* ASSERT_TRUE(2 + 2 == 4);
|
||||
* };
|
||||
* // ...
|
||||
* int main(){
|
||||
* return microunit::UnitTester::Run() ? 0 : -1;
|
||||
* }
|
||||
* @endcode
|
||||
*
|
||||
* @copyright Copyright (c) 2016, Sebastiao Salvador de Miranda.
|
||||
* All rights reserved. See licence below.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are
|
||||
* met:
|
||||
*
|
||||
* (1) Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
*
|
||||
* (2) Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in
|
||||
* the documentation and/or other materials provided with the
|
||||
* distribution.
|
||||
*
|
||||
* (3) The name of the author may not be used to
|
||||
* endorse or promote products derived from this software without
|
||||
* specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
|
||||
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
|
||||
* INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
|
||||
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
|
||||
* IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
* POSSIBILITY OF SUCH DAMAGE.
|
||||
*/
|
||||
|
||||
#ifndef _MICROUNIT_MICROUNIT_H_
|
||||
#define _MICROUNIT_MICROUNIT_H_
|
||||
#include <string.h>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
/**
|
||||
* @brief Helper macros to get current logging filename
|
||||
*/
|
||||
#if defined(_WIN32)
|
||||
#define __FILENAME__ (strrchr(__FILE__, '\\') ? \
|
||||
strrchr(__FILE__, '\\') + 1 : __FILE__)
|
||||
#else
|
||||
#define __FILENAME__ (strrchr(__FILE__, '/') ? \
|
||||
strrchr(__FILE__, '/') + 1 : __FILE__)
|
||||
#endif
|
||||
#define MICROUNIT_SEPARATOR "----------------------------------------" \
|
||||
"----------------------------------------"
|
||||
#if defined(_WIN32)
|
||||
#include "windows.h"
|
||||
#endif
|
||||
|
||||
namespace microunit {
|
||||
const static int COLORCODE_GREY{ 7 };
|
||||
const static int COLORCODE_GREEN{ 10 };
|
||||
const static int COLORCODE_RED{ 12 };
|
||||
const static int COLORCODE_YELLOW{ 14 };
|
||||
|
||||
/**
|
||||
* @brief Helper class to convert from color codes to ansi escape codes
|
||||
* Used to print color in non-win32 systems.
|
||||
*/
|
||||
std::string ColorCodeToANSI(const int color_code) {
|
||||
switch (color_code) {
|
||||
case COLORCODE_GREY: return "\033[22;37m";
|
||||
case COLORCODE_GREEN: return "\033[01;31m";
|
||||
case COLORCODE_RED: return "\033[01;32m";
|
||||
case COLORCODE_YELLOW: return "\033[01;33m";
|
||||
default: return "";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper function to change the current terminal color.
|
||||
* @param [in] color_code Input color code.
|
||||
*/
|
||||
void SetTerminalColor(int color_code) {
|
||||
#if defined(_WIN32)
|
||||
HANDLE handler = GetStdHandle(STD_OUTPUT_HANDLE);
|
||||
CONSOLE_SCREEN_BUFFER_INFO buffer_info;
|
||||
GetConsoleScreenBufferInfo(handler, &buffer_info);
|
||||
SetConsoleTextAttribute(handler, ((buffer_info.wAttributes & 0xFFF0) |
|
||||
(WORD)color_code));
|
||||
#else
|
||||
std::cout << ColorCodeToANSI(color_code);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper class to be used as a iostream manipulator and change the
|
||||
* terminal color.
|
||||
*/
|
||||
class Color {
|
||||
public:
|
||||
Color(int code) : code_(code) {}
|
||||
void Set() const {
|
||||
SetTerminalColor(code_);
|
||||
}
|
||||
int code() const { return code_; }
|
||||
private:
|
||||
int code_;
|
||||
};
|
||||
|
||||
const static Color Grey{ COLORCODE_GREY };
|
||||
const static Color Green{ COLORCODE_GREEN };
|
||||
const static Color Red{ COLORCODE_RED };
|
||||
const static Color Yellow{ COLORCODE_YELLOW };
|
||||
|
||||
/**
|
||||
* @brief Helper class to be used in a cout streaming statement. Resets to
|
||||
* the default terminal color upon statement completion.
|
||||
*/
|
||||
class SaveColor {
|
||||
public:
|
||||
~SaveColor() {
|
||||
SetTerminalColor(Grey.code());
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Helper class to be used in a cout streaming statement. Puts a line
|
||||
* break upon statement completion.
|
||||
*/
|
||||
class EndingLineBreak {
|
||||
public:
|
||||
~EndingLineBreak() {
|
||||
std::cout << std::endl;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
/** @brief Operator to allow using SaveColor class with an ostream */
|
||||
inline std::ostream& operator<<(std::ostream& os,
|
||||
const microunit::SaveColor& obj) {
|
||||
return os;
|
||||
}
|
||||
|
||||
/** @brief Operator to allow using EndingLineBreak class with an ostream */
|
||||
inline std::ostream& operator<<(std::ostream& os,
|
||||
const microunit::EndingLineBreak& obj) {
|
||||
return os;
|
||||
}
|
||||
|
||||
/** @brief Operator to allow using Color class with an ostream */
|
||||
inline std::ostream& operator<<(std::ostream& os,
|
||||
const microunit::Color& color) {
|
||||
color.Set();
|
||||
return os;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Macro for writing to the terminal an INFO-level log
|
||||
*/
|
||||
#define TERMINAL_INFO std::cout << microunit::SaveColor{} << \
|
||||
microunit::EndingLineBreak{} << microunit::Yellow << "[ ] "
|
||||
#define LOG_INFO TERMINAL_INFO << __FILENAME__ << ":" << __LINE__ << ": "
|
||||
|
||||
/**
|
||||
* @brief Macro for writing to the terminal a BAD-level log
|
||||
*/
|
||||
#define TERMINAL_BAD std::cout << microunit::SaveColor{} << \
|
||||
microunit::EndingLineBreak{} << microunit::Red << "[ ] "
|
||||
#define LOG_BAD TERMINAL_BAD << __FILENAME__ << ":" << __LINE__ << ": "
|
||||
|
||||
/**
|
||||
* @brief Macro for writing to the terminal a GOOD-level log
|
||||
*/
|
||||
#define TERMINAL_GOOD std::cout << microunit::SaveColor{} << \
|
||||
microunit::EndingLineBreak{} << microunit::Green << "[ ] "
|
||||
#define LOG_GOOD TERMINAL_GOOD << __FILENAME__ << ":" << __LINE__ << ": "
|
||||
|
||||
namespace microunit {
|
||||
/**
|
||||
* @brief Result of a unit test.
|
||||
*/
|
||||
struct UnitFunctionResult {
|
||||
bool success{ true };
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Unit test function type.
|
||||
*/
|
||||
typedef void(*UnitFunction)(UnitFunctionResult*);
|
||||
|
||||
/**
|
||||
* @brief Main class for unit test management. This class is a singleton
|
||||
* and maintains a list of all registered unit test cases.
|
||||
*/
|
||||
class UnitTester {
|
||||
public:
|
||||
/**
|
||||
* @brief Run all the registered unit test cases.
|
||||
* @returns True if all tests pass, false otherwise.
|
||||
*/
|
||||
static bool Run() {
|
||||
std::vector<std::string> failures;
|
||||
|
||||
// Iterate all registered unit tests
|
||||
for (auto& unit : Instance().unitfunction_map_) {
|
||||
std::cout << MICROUNIT_SEPARATOR << std::endl;
|
||||
TERMINAL_GOOD << "Test case '" << unit.first << "'";
|
||||
|
||||
// Run the unit test
|
||||
UnitFunctionResult result;
|
||||
unit.second(&result);
|
||||
|
||||
if (!result.success) {
|
||||
TERMINAL_BAD << "Failed test";
|
||||
failures.push_back(unit.first);
|
||||
} else {
|
||||
TERMINAL_GOOD << "Passed test";
|
||||
}
|
||||
}
|
||||
std::cout
|
||||
<< MICROUNIT_SEPARATOR << std::endl
|
||||
<< MICROUNIT_SEPARATOR << std::endl;
|
||||
|
||||
// Output result summary
|
||||
if (failures.empty()) {
|
||||
TERMINAL_GOOD << "All tests passed";
|
||||
std::cout << MICROUNIT_SEPARATOR << std::endl;
|
||||
return true;
|
||||
} else {
|
||||
TERMINAL_BAD << "Failed " << failures.size()
|
||||
<< " test cases:";
|
||||
for (const auto& failure : failures) {
|
||||
TERMINAL_BAD << failure;
|
||||
}
|
||||
std::cout << MICROUNIT_SEPARATOR << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Register a unit test case function. In regular library client usage,
|
||||
* this doesn't need to be called, and the macro UNIT should be used
|
||||
* instead.
|
||||
* @param [in] name Name of the unit test case.
|
||||
* @param [in] function Pointer to unit test case function.
|
||||
* @returns True if all tests pass, false otherwise.
|
||||
*/
|
||||
static void RegisterFunction(const std::string &name,
|
||||
UnitFunction function) {
|
||||
Instance().unitfunction_map_.emplace(name, function);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper class to register a unit test in construction time. This is
|
||||
* used to call RegisterFunction in the construction of a static
|
||||
* helper object. Used by the REGISTER_UNIT macro, which in turn is
|
||||
* used by the UNIT macro.
|
||||
* @returns True if all tests pass, false otherwise.
|
||||
*/
|
||||
class Registrator {
|
||||
public:
|
||||
Registrator(const std::string &name,
|
||||
UnitFunction function) {
|
||||
UnitTester::RegisterFunction(name, function);
|
||||
};
|
||||
Registrator(const Registrator&) = delete;
|
||||
Registrator(Registrator&&) = delete;
|
||||
~Registrator() {};
|
||||
};
|
||||
|
||||
~UnitTester() {};
|
||||
UnitTester(const UnitTester&) = delete;
|
||||
UnitTester(UnitTester&&) = delete;
|
||||
|
||||
private:
|
||||
UnitTester() {};
|
||||
static UnitTester& Instance() {
|
||||
static UnitTester instance;
|
||||
return instance;
|
||||
}
|
||||
std::map<std::string, UnitFunction> unitfunction_map_;
|
||||
};
|
||||
}
|
||||
|
||||
#define MACROCAT_NEXP(A, B) A ## B
|
||||
#define MACROCAT(A, B) MACROCAT_NEXP(A, B)
|
||||
|
||||
/**
|
||||
* @brief Register a unit function using a helper static Registrator object.
|
||||
*/
|
||||
#define REGISTER_UNIT(FUNCTION) \
|
||||
static microunit::UnitTester::Registrator \
|
||||
MACROCAT(MICROUNIT_REGISTRATION, __COUNTER__)(#FUNCTION, FUNCTION);
|
||||
|
||||
/**
|
||||
* @brief Define a unit function body. This macro is the one which should be used
|
||||
* by client code to define unit test cases.
|
||||
* @code{.cpp}
|
||||
* UNIT(Test_Two_Plus_Two) {
|
||||
* ASSERT_TRUE(2 + 2 == 4);
|
||||
* };
|
||||
* @endcode
|
||||
*/
|
||||
#define UNIT(FUNCTION) \
|
||||
void FUNCTION(microunit::UnitFunctionResult*); \
|
||||
REGISTER_UNIT(FUNCTION); \
|
||||
void FUNCTION(microunit::UnitFunctionResult *__microunit_testresult)
|
||||
|
||||
/**
|
||||
* @brief Pass the test and return from the test case.
|
||||
*/
|
||||
#define PASS() { \
|
||||
LOG_GOOD << "Test stopped: Pass"; \
|
||||
__microunit_testresult->success = true; \
|
||||
return; \
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Fail the test and return from the test case.
|
||||
*/
|
||||
#define FAIL() { \
|
||||
LOG_BAD << "Test stopped: Fail"; \
|
||||
__microunit_testresult->success = false; \
|
||||
return; \
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check a particular test condition. If the condition does not hold,
|
||||
* fail the test and return.
|
||||
*/
|
||||
#define ASSERT_TRUE(condition) if(!(condition)) { \
|
||||
LOG_BAD << "Assert-true failed: " #condition; \
|
||||
FAIL(); \
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Check a particular test condition. If the condition holds, fail the
|
||||
* test and return.
|
||||
*/
|
||||
#define ASSERT_FALSE(condition) if((condition)) { \
|
||||
LOG_BAD << "Assert-false failed: " #condition << std::endl; \
|
||||
FAIL(); \
|
||||
}
|
||||
#endif
|
||||
25
src/Layer.h
25
src/Layer.h
@@ -14,6 +14,7 @@
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cassert> // for assert()
|
||||
|
||||
class Layer {
|
||||
public:
|
||||
@@ -25,14 +26,36 @@ public:
|
||||
|
||||
Layer(int num_nodes, int num_inputs_per_node) {
|
||||
m_num_nodes = num_nodes;
|
||||
m_num_inputs_per_node = num_inputs_per_node;
|
||||
m_nodes = std::vector<Node>(num_nodes, Node(num_inputs_per_node));
|
||||
};
|
||||
|
||||
~Layer() {
|
||||
|
||||
m_nodes.clear();
|
||||
};
|
||||
|
||||
void GetOutput(const std::vector<double> &input, std::vector<double> * output) const {
|
||||
assert(input.size() == m_num_inputs_per_node);
|
||||
|
||||
output->resize(m_num_nodes);
|
||||
|
||||
for (int i = 0; i < m_num_nodes; ++i) {
|
||||
(*output)[i] = m_nodes[i].GetOutput(input);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateWeights(const std::vector<double> &x,
|
||||
double m_learning_rate,
|
||||
double error) {
|
||||
assert(x.size() == m_num_inputs_per_node);
|
||||
|
||||
for (size_t i = 0; i < m_nodes.size(); i++)
|
||||
m_nodes[i].UpdateWeights(x, m_learning_rate, error);
|
||||
};
|
||||
|
||||
protected:
|
||||
int m_num_nodes;
|
||||
int m_num_inputs_per_node;
|
||||
std::vector<Node> m_nodes;
|
||||
};
|
||||
|
||||
|
||||
22
src/MLP.h
22
src/MLP.h
@@ -23,9 +23,7 @@ public:
|
||||
int num_outputs,
|
||||
int num_hidden_layers,
|
||||
int num_nodes_per_hidden_layer,
|
||||
double learning_rate,
|
||||
int max_iterations,
|
||||
double threshold) {
|
||||
double learning_rate) {
|
||||
|
||||
m_num_inputs = num_inputs;
|
||||
m_num_outputs = num_outputs;
|
||||
@@ -33,8 +31,6 @@ public:
|
||||
m_num_nodes_per_hidden_layer = num_nodes_per_hidden_layer;
|
||||
|
||||
m_learning_rate = learning_rate;
|
||||
m_max_iterations = max_iterations;
|
||||
m_threshold = threshold;
|
||||
};
|
||||
|
||||
~MLP() {
|
||||
@@ -57,15 +53,18 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
size_t GetWeightMatrixCardinality()const;
|
||||
bool ExportWeights(std::vector<double> *weights)const;
|
||||
bool ImportWeights(const std::vector<double> & weights);
|
||||
|
||||
std::vector<double> & GetOutputValues(const std::vector<double> &input);
|
||||
int GetOutputClass(const std::vector<double> &input);
|
||||
void GetOutput(const std::vector<double> &input, std::vector<double> * output) const;
|
||||
void GetOutputClass(const std::vector<double> &output, size_t * class_id) const;
|
||||
|
||||
void Train(const std::vector<TrainingSample> &training_sample_set,
|
||||
bool bias_already_in);
|
||||
|
||||
//void UpdateWeight(const std::vector<double> &x,
|
||||
// double error);
|
||||
int max_iterations);
|
||||
protected:
|
||||
void UpdateWeights(const std::vector<double> &x,
|
||||
double error);
|
||||
private:
|
||||
|
||||
int m_num_inputs;
|
||||
@@ -75,7 +74,6 @@ private:
|
||||
|
||||
double m_learning_rate;
|
||||
int m_max_iterations;
|
||||
double m_threshold;
|
||||
|
||||
std::vector<Layer> m_layers;
|
||||
};
|
||||
|
||||
116
src/Main.cpp
116
src/Main.cpp
@@ -10,9 +10,9 @@
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include "microunit.h"
|
||||
|
||||
void LearnAND() {
|
||||
UNIT(LearnAND) {
|
||||
std::cout << "Train AND function with mlp." << std::endl;
|
||||
|
||||
std::vector<TrainingSample> training_set =
|
||||
@@ -23,18 +23,22 @@ void LearnAND() {
|
||||
{{ 1, 1 },{0,1}}
|
||||
};
|
||||
|
||||
MLP my_mlp(2, 2, 1, 5, 0.1, 100, 0.5);
|
||||
my_mlp.Train(training_set, false);
|
||||
MLP my_mlp(2, 2, 1, 5, 0.1);
|
||||
my_mlp.Train(training_set, 100);
|
||||
|
||||
assert(my_mlp.GetOutputClass({ 0, 0 }) == 0);
|
||||
assert(my_mlp.GetOutputClass({ 0, 1 }) == 0);
|
||||
assert(my_mlp.GetOutputClass({ 1, 0 }) == 0);
|
||||
assert(my_mlp.GetOutputClass({ 1, 1 }) == 1);
|
||||
for (const auto & training_sample : training_set){
|
||||
size_t class_id;
|
||||
my_mlp.GetOutputClass(training_sample.input_vector(), &class_id);
|
||||
ASSERT_TRUE(class_id ==
|
||||
std::distance(training_sample.output_vector().begin(),
|
||||
std::max_element(training_sample.output_vector().begin(),
|
||||
training_sample.output_vector().end()) ));
|
||||
}
|
||||
std::cout << "Trained with success." << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void LearnNAND() {
|
||||
UNIT(LearnNAND) {
|
||||
std::cout << "Train NAND function with mlp." << std::endl;
|
||||
|
||||
std::vector<TrainingSample> training_set =
|
||||
@@ -45,18 +49,22 @@ void LearnNAND() {
|
||||
{{ 1, 1 },{1,0}}
|
||||
};
|
||||
|
||||
MLP my_mlp(2, 2, 1, 5, 0.1, 100, 0.5);
|
||||
my_mlp.Train(training_set, false);
|
||||
MLP my_mlp(2, 2, 1, 5, 0.1);
|
||||
my_mlp.Train(training_set, 100);
|
||||
|
||||
assert(my_mlp.GetOutputClass({ 0, 0 }) == 1);
|
||||
assert(my_mlp.GetOutputClass({ 0, 1 }) == 1);
|
||||
assert(my_mlp.GetOutputClass({ 1, 0 }) == 1);
|
||||
assert(my_mlp.GetOutputClass({ 1, 1 }) == 0);
|
||||
for (const auto & training_sample : training_set) {
|
||||
size_t class_id;
|
||||
my_mlp.GetOutputClass(training_sample.input_vector(), &class_id);
|
||||
ASSERT_TRUE(class_id ==
|
||||
std::distance(training_sample.output_vector().begin(),
|
||||
std::max_element(training_sample.output_vector().begin(),
|
||||
training_sample.output_vector().end())));
|
||||
}
|
||||
std::cout << "Trained with success." << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void LearnOR() {
|
||||
UNIT(LearnOR) {
|
||||
std::cout << "Train OR function with mlp." << std::endl;
|
||||
|
||||
std::vector<TrainingSample> training_set =
|
||||
@@ -67,18 +75,22 @@ void LearnOR() {
|
||||
{{ 1, 1 },{0,1}}
|
||||
};
|
||||
|
||||
MLP my_mlp(2, 2, 1, 5, 0.1, 100, 0.5);
|
||||
my_mlp.Train(training_set, false);
|
||||
MLP my_mlp(2, 2, 1, 5, 0.1);
|
||||
my_mlp.Train(training_set, 100);
|
||||
|
||||
assert(my_mlp.GetOutputClass({ 0, 0 }) == 0);
|
||||
assert(my_mlp.GetOutputClass({ 0, 1 }) == 1);
|
||||
assert(my_mlp.GetOutputClass({ 1, 0 }) == 1);
|
||||
assert(my_mlp.GetOutputClass({ 1, 1 }) == 1);
|
||||
for (const auto & training_sample : training_set) {
|
||||
size_t class_id;
|
||||
my_mlp.GetOutputClass(training_sample.input_vector(), &class_id);
|
||||
ASSERT_TRUE(class_id ==
|
||||
std::distance(training_sample.output_vector().begin(),
|
||||
std::max_element(training_sample.output_vector().begin(),
|
||||
training_sample.output_vector().end())));
|
||||
}
|
||||
std::cout << "Trained with success." << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void LearnNOR() {
|
||||
UNIT(LearnNOR) {
|
||||
std::cout << "Train NOR function with mlp." << std::endl;
|
||||
|
||||
std::vector<TrainingSample> training_set =
|
||||
@@ -89,18 +101,22 @@ void LearnNOR() {
|
||||
{{ 1, 1 },{1,0}}
|
||||
};
|
||||
|
||||
MLP my_mlp(2, 2, 1, 5, 0.1, 100, 0.5);
|
||||
my_mlp.Train(training_set, false);
|
||||
MLP my_mlp(2, 2, 1, 5, 0.1);
|
||||
my_mlp.Train(training_set, 100);
|
||||
|
||||
assert(my_mlp.GetOutputClass({ 0, 0 }) == 1);
|
||||
assert(my_mlp.GetOutputClass({ 0, 1 }) == 0);
|
||||
assert(my_mlp.GetOutputClass({ 1, 0 }) == 0);
|
||||
assert(my_mlp.GetOutputClass({ 1, 1 }) == 0);
|
||||
for (const auto & training_sample : training_set) {
|
||||
size_t class_id;
|
||||
my_mlp.GetOutputClass(training_sample.input_vector(), &class_id);
|
||||
ASSERT_TRUE(class_id ==
|
||||
std::distance(training_sample.output_vector().begin(),
|
||||
std::max_element(training_sample.output_vector().begin(),
|
||||
training_sample.output_vector().end())));
|
||||
}
|
||||
std::cout << "Trained with success." << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void LearnXOR() {
|
||||
UNIT(LearnXOR) {
|
||||
std::cout << "Train XOR function with mlp." << std::endl;
|
||||
|
||||
std::vector<TrainingSample> training_set =
|
||||
@@ -111,18 +127,22 @@ void LearnXOR() {
|
||||
{ { 1, 1 },{ 1,0 } }
|
||||
};
|
||||
|
||||
MLP my_mlp(2, 2, 1, 5, 0.1, 100, 0.5);
|
||||
my_mlp.Train(training_set, false);
|
||||
MLP my_mlp(2, 2, 1, 5, 0.1);
|
||||
my_mlp.Train(training_set, 100);
|
||||
|
||||
assert(my_mlp.GetOutputClass({ 0, 0 }) == 0);
|
||||
assert(my_mlp.GetOutputClass({ 0, 1 }) == 1);
|
||||
assert(my_mlp.GetOutputClass({ 1, 0 }) == 1);
|
||||
assert(my_mlp.GetOutputClass({ 1, 1 }) == 0);
|
||||
for (const auto & training_sample : training_set) {
|
||||
size_t class_id;
|
||||
my_mlp.GetOutputClass(training_sample.input_vector(), &class_id);
|
||||
ASSERT_TRUE(class_id ==
|
||||
std::distance(training_sample.output_vector().begin(),
|
||||
std::max_element(training_sample.output_vector().begin(),
|
||||
training_sample.output_vector().end())));
|
||||
}
|
||||
std::cout << "Trained with success." << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void LearnNOT() {
|
||||
UNIT(LearnNOT) {
|
||||
std::cout << "Train NOT function with mlp." << std::endl;
|
||||
|
||||
std::vector<TrainingSample> training_set =
|
||||
@@ -131,22 +151,22 @@ void LearnNOT() {
|
||||
{{ 1},{1,1}}
|
||||
};
|
||||
|
||||
MLP my_mlp(1, 2, 1, 5, 0.1, 100, 0.5);
|
||||
my_mlp.Train(training_set, false);
|
||||
MLP my_mlp(1, 2, 1, 5, 0.1);
|
||||
my_mlp.Train(training_set, 100);
|
||||
|
||||
assert(my_mlp.GetOutputClass({ 0 }) == 1);
|
||||
assert(my_mlp.GetOutputClass({ 1 }) == 0);
|
||||
for (const auto & training_sample : training_set) {
|
||||
size_t class_id;
|
||||
my_mlp.GetOutputClass(training_sample.input_vector(), &class_id);
|
||||
ASSERT_TRUE(class_id ==
|
||||
std::distance(training_sample.output_vector().begin(),
|
||||
std::max_element(training_sample.output_vector().begin(),
|
||||
training_sample.output_vector().end())));
|
||||
}
|
||||
std::cout << "Trained with success." << std::endl;
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
int main() {
|
||||
LearnAND();
|
||||
LearnNAND();
|
||||
LearnOR();
|
||||
LearnNOR();
|
||||
LearnXOR();
|
||||
LearnNOT();
|
||||
|
||||
microunit::UnitTester::Run();
|
||||
return 0;
|
||||
}
|
||||
66
src/Node.h
66
src/Node.h
@@ -12,71 +12,91 @@
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cassert> // for assert()
|
||||
|
||||
#define ZERO_WEIGHT_INITIALIZATION 1
|
||||
#define USE_SIGMOID 1
|
||||
|
||||
class Node {
|
||||
public:
|
||||
Node() {
|
||||
m_bias = 0.0;
|
||||
//m_old_bias = 0.0;
|
||||
m_num_inputs = 0;
|
||||
m_weights.clear();
|
||||
//m_old_weights.clear();
|
||||
};
|
||||
Node(int num_inputs) {
|
||||
m_bias = 0.0;
|
||||
//m_old_bias = 0.0;
|
||||
m_num_inputs = num_inputs;
|
||||
m_num_inputs = num_inputs + 1;
|
||||
m_weights.clear();
|
||||
//m_old_weights.clear();
|
||||
m_weights = std::vector<double>(num_inputs);
|
||||
//m_old_weights = std::vector<double>(num_inputs);
|
||||
m_weights = std::vector<double>(m_num_inputs);
|
||||
|
||||
//initialize weight vector
|
||||
std::generate_n(m_weights.begin(),
|
||||
num_inputs,
|
||||
(ZERO_WEIGHT_INITIALIZATION) ?
|
||||
m_num_inputs,
|
||||
(ZERO_WEIGHT_INITIALIZATION) ?
|
||||
utils::gen_rand(0) : utils::gen_rand());
|
||||
};
|
||||
~Node() {
|
||||
m_weights.clear();
|
||||
//m_old_weights.clear();
|
||||
};
|
||||
int GetInputSize() {
|
||||
int GetInputSize() const {
|
||||
return m_num_inputs;
|
||||
}
|
||||
void SetInputSize(int num_inputs) {
|
||||
m_num_inputs = num_inputs;
|
||||
}
|
||||
double GetBias() {
|
||||
double GetBias() const {
|
||||
return m_bias;
|
||||
}
|
||||
//double GetOldBias() {
|
||||
// return m_old_bias;
|
||||
//}
|
||||
|
||||
void SetBias(double bias) {
|
||||
m_bias = bias;
|
||||
}
|
||||
//void SetOldBias(double old_bias) {
|
||||
// m_old_bias = old_bias;
|
||||
//}
|
||||
|
||||
std::vector<double> & GetWeights() {
|
||||
return m_weights;
|
||||
}
|
||||
//std::vector<double> & GetOldWeights() {
|
||||
// return m_old_weights;
|
||||
//}
|
||||
uint32_t GetWeightsVectorSize() const {
|
||||
|
||||
const std::vector<double> & GetWeights() const {
|
||||
return m_weights;
|
||||
}
|
||||
|
||||
size_t GetWeightsVectorSize() const {
|
||||
return m_weights.size();
|
||||
}
|
||||
|
||||
void GetOutput(const std::vector<double> &input, double * output) const {
|
||||
assert(input.size() == m_weights.size());
|
||||
double inner_prod = std::inner_product(begin(input),
|
||||
end(input),
|
||||
begin(m_weights),
|
||||
0.0);
|
||||
*output = inner_prod;
|
||||
}
|
||||
|
||||
void GetFilteredOutput(const std::vector<double> &input, double * bool_output) {
|
||||
double inner_prod;
|
||||
GetOutput(input, &inner_prod);
|
||||
#if USE_SIGMOID == 1
|
||||
double y = utils::sigmoid(inner_prod);
|
||||
*bool_output = (y > 0) ? true : false;
|
||||
#else
|
||||
*bool_output = (inner_prod > 0) ? true : false;
|
||||
#endif
|
||||
};
|
||||
|
||||
void UpdateWeights(const std::vector<double> &x,
|
||||
double m_learning_rate,
|
||||
double error) {
|
||||
assert(x.size() == m_weights.size());
|
||||
for (size_t i = 0; i < m_weights.size(); i++)
|
||||
m_weights[i] += x[i] * m_learning_rate * error;
|
||||
};
|
||||
protected:
|
||||
int m_num_inputs;
|
||||
double m_bias;
|
||||
//double m_old_bias;
|
||||
std::vector<double> m_weights;
|
||||
//std::vector<double> m_old_weights;
|
||||
};
|
||||
|
||||
#endif //NODE_H
|
||||
@@ -14,10 +14,10 @@ public:
|
||||
|
||||
m_input_vector = input_vector;
|
||||
}
|
||||
std::vector<double> & input_vector() {
|
||||
const std::vector<double> & input_vector() const {
|
||||
return m_input_vector;
|
||||
}
|
||||
uint32_t GetInputVectorSize() const {
|
||||
size_t GetInputVectorSize() const {
|
||||
return m_input_vector.size();
|
||||
}
|
||||
void AddBiasValue(double bias_value) {
|
||||
@@ -35,10 +35,10 @@ public:
|
||||
Sample(input_vector) {
|
||||
m_output_vector = output_vector;
|
||||
}
|
||||
std::vector<double> & output_vector() {
|
||||
const std::vector<double> & output_vector() const {
|
||||
return m_output_vector;
|
||||
}
|
||||
uint32_t GetOutputVectorSize() const {
|
||||
size_t GetOutputVectorSize() const {
|
||||
return m_output_vector.size();
|
||||
}
|
||||
protected:
|
||||
|
||||
27
src/Utils.h
27
src/Utils.h
@@ -15,6 +15,14 @@
|
||||
#include <sys/time.h>
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
|
||||
|
||||
|
||||
namespace utils {
|
||||
|
||||
struct gen_rand {
|
||||
@@ -39,6 +47,25 @@ inline double deriv_sigmoid(double x) {
|
||||
return sigmoid(x)*(1 - sigmoid(x));
|
||||
};
|
||||
|
||||
void Softmax(std::vector<double> *output) {
|
||||
size_t num_elements = output->size();
|
||||
std::vector<double> exp_output(num_elements);
|
||||
double exp_total = 0.0;
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
exp_output[i] = exp((*output)[i]);
|
||||
exp_total += exp_output[i];
|
||||
}
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
(*output)[i] = exp_output[i] / exp_total;
|
||||
}
|
||||
}
|
||||
|
||||
void GetIdMaxElement(const std::vector<double> &output, size_t * class_id) {
|
||||
*class_id = std::distance(output.begin(),
|
||||
std::max_element(output.begin(),
|
||||
output.end()));
|
||||
}
|
||||
|
||||
class Chronometer {
|
||||
public:
|
||||
Chronometer() {
|
||||
|
||||
Reference in New Issue
Block a user