Reduced the progress update messages to take load from the UI,

implemented load-or-create in test funtion
This commit is contained in:
mandlm 2015-10-26 19:50:01 +01:00
parent 9bb927d2d2
commit 1e716979a9
3 changed files with 44 additions and 22 deletions

50
Net.cpp
View File

@ -4,27 +4,14 @@
#include <iostream>
#include <fstream>
Net::Net()
{
}
Net::Net(std::initializer_list<size_t> layerSizes)
{
if (layerSizes.size() < 2)
{
throw std::exception("A net needs at least 2 layers");
}
for (size_t numNeurons : layerSizes)
{
push_back(Layer(numNeurons));
}
for (auto layerIt = begin(); layerIt != end() - 1; ++layerIt)
{
Layer &currentLayer = *layerIt;
const Layer &nextLayer = *(layerIt + 1);
currentLayer.addBiasNeuron();
currentLayer.connectTo(nextLayer);
}
initialize(layerSizes);
}
Net::Net(const std::string &filename)
@ -32,6 +19,31 @@ Net::Net(const std::string &filename)
load(filename);
}
void Net::initialize(std::initializer_list<size_t> layerSizes)
{
clear();
if (layerSizes.size() < 2)
{
throw std::exception("A net needs at least 2 layers");
}
for (size_t numNeurons : layerSizes)
{
push_back(Layer(numNeurons));
}
for (auto layerIt = begin(); layerIt != end() - 1; ++layerIt)
{
Layer &currentLayer = *layerIt;
const Layer &nextLayer = *(layerIt + 1);
currentLayer.addBiasNeuron();
currentLayer.connectTo(nextLayer);
}
}
void Net::feedForward(const std::vector<double> &inputValues)
{
Layer &inputLayer = front();

3
Net.h
View File

@ -7,9 +7,12 @@
class Net : public std::vector < Layer >
{
public:
Net();
Net(std::initializer_list<size_t> layerSizes);
Net(const std::string &filename);
void initialize(std::initializer_list<size_t> layerSizes);
void feedForward(const std::vector<double> &inputValues);
std::vector<double> getOutput();
void backProp(const std::vector<double> &targetValues);

View File

@ -9,7 +9,15 @@ void NetLearner::run()
{
QElapsedTimer timer;
Net myNet({2, 3, 1});
Net myNet;
try
{
myNet.load("mynet.nnet");
}
catch (...)
{
myNet.initialize({2, 3, 1});
}
size_t batchSize = 5000;
size_t batchIndex = 0;
@ -54,6 +62,7 @@ void NetLearner::run()
emit logMessage(logString);
emit currentNetError(batchMaxError);
emit progress((double)iteration / (double)numIterations);
batchIndex = 0;
batchMaxError = 0.0;
@ -61,8 +70,6 @@ void NetLearner::run()
}
myNet.backProp(targetValues);
emit progress((double)iteration / (double)numIterations);
}
QString timerLogString;