Merge pull request #1 from mandlm/digits

Digits
This commit is contained in:
mandlm 2015-10-31 15:08:01 +01:00
commit 1a0d2b9ea7
16 changed files with 225 additions and 55 deletions

View File

@ -21,7 +21,7 @@ void Layer::setOutputValues(const std::vector<double> & outputValues)
for (const double &value : outputValues)
{
(neuronIt++)->setOutputValue(value);
}
}
}
void Layer::feedForward(const Layer &inputLayer)
@ -54,7 +54,7 @@ void Layer::connectTo(const Layer & nextLayer)
void Layer::updateInputWeights(Layer & prevLayer)
{
static const double trainingRate = 0.3;
static const double trainingRate = 0.2;
for (size_t targetLayerIndex = 0; targetLayerIndex < sizeWithoutBiasNeuron(); ++targetLayerIndex)
{

View File

@ -13,9 +13,10 @@ public:
Layer(size_t numNeurons);
void setOutputValues(const std::vector<double> & outputValues);
void feedForward(const Layer &inputLayer);
double getWeightedSum(size_t outputNeuron) const;
void connectTo(const Layer & nextLayer);
void connectTo(const Layer &nextLayer);
void updateInputWeights(Layer &prevLayer);

View File

@ -63,7 +63,7 @@ void Net::feedForward(const std::vector<double> &inputValues)
Layer &nextLayer = *(layerIt + 1);
nextLayer.feedForward(currentLayer);
}
}
}
std::vector<double> Net::getOutput()

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -18,14 +18,16 @@ SOURCES += main.cpp\
../../Net.cpp \
../../Neuron.cpp \
netlearner.cpp \
errorplotter.cpp
errorplotter.cpp \
mnistloader.cpp
HEADERS += neuroui.h \
../../Layer.h \
../../Net.h \
../../Neuron.h \
netlearner.h \
errorplotter.h
errorplotter.h \
mnistloader.h
FORMS += neuroui.ui

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 34 KiB

View File

@ -0,0 +1,97 @@
#include "mnistloader.h"
#include <fstream>
void MnistLoader::load(const std::string &databaseFileName, const std::string &labelsFileName)
{
loadDatabase(databaseFileName);
loadLabels(labelsFileName);
}
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
{
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;
return *(samples[sampleIndex].get());
}
void MnistLoader::loadDatabase(const std::string &fileName)
{
std::ifstream databaseFile;
databaseFile.open(fileName, std::ios::binary);
if (!databaseFile.is_open())
{
throw std::runtime_error("unable to open MNIST database file");
}
int32_t magicNumber = readInt32(databaseFile);
if (magicNumber != DatabaseFileMagicNumber)
{
throw std::runtime_error("unexpected data reading MNIST database file");
}
int32_t sampleCount = readInt32(databaseFile);
int32_t sampleWidth = readInt32(databaseFile);
int32_t sampleHeight = readInt32(databaseFile);
if (sampleWidth != SampleWidth || sampleHeight != SampleHeight)
{
throw std::runtime_error("unexpected sample size loading MNIST database");
}
samples.reserve(samples.size() + sampleCount);
for (int32_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex)
{
std::unique_ptr<MnistSample> sample = std::make_unique<MnistSample>();
databaseFile.read(reinterpret_cast<char *>(sample->data), sampleWidth * sampleHeight);
samples.push_back(std::move(sample));
}
}
void MnistLoader::loadLabels(const std::string &fileName)
{
std::ifstream labelFile;
labelFile.open(fileName, std::ios::binary);
if (!labelFile.is_open())
{
throw std::runtime_error("unable to open MNIST label file");
}
int32_t magicNumber = readInt32(labelFile);
if (magicNumber != LabelFileMagicNumber)
{
throw std::runtime_error("unexpected data reading MNIST label file");
}
int32_t labelCount = readInt32(labelFile);
if (labelCount != static_cast<int32_t>(samples.size()))
{
throw std::runtime_error("MNIST database and label files don't match in size");
}
auto sampleIt = samples.begin();
for (int32_t labelIndex = 0; labelIndex < labelCount; ++labelIndex)
{
(*sampleIt++)->label = readInt8(labelFile);
}
}
int8_t MnistLoader::readInt8(std::ifstream &file)
{
int8_t buf8;
file.read(reinterpret_cast<char *>(&buf8), sizeof(buf8));
return buf8;
}
int32_t MnistLoader::readInt32(std::ifstream &file)
{
int32_t buf32;
file.read(reinterpret_cast<char *>(&buf32), sizeof(buf32));
return _byteswap_ulong(buf32);
}

44
gui/NeuroUI/mnistloader.h Normal file
View File

@ -0,0 +1,44 @@
#ifndef MNISTLOADER_H
#define MNISTLOADER_H
#include <string>
#include <vector>
#include <memory>
#include <inttypes.h>
class MnistLoader
{
private:
static const uint32_t DatabaseFileMagicNumber = 2051;
static const uint32_t LabelFileMagicNumber = 2049;
static const size_t SampleWidth = 28;
static const size_t SampleHeight = 28;
public:
template<size_t SAMPLE_WIDTH, size_t SAMPLE_HEIGHT>
class Sample
{
public:
uint8_t label;
uint8_t data[SAMPLE_WIDTH * SAMPLE_HEIGHT];
};
using MnistSample = Sample<SampleWidth, SampleHeight>;
private:
std::vector<std::unique_ptr<MnistSample>> samples;
public:
void load(const std::string &databaseFileName, const std::string &labelsFileName);
const MnistSample &getRandomSample() const;
private:
void loadDatabase(const std::string &fileName);
void loadLabels(const std::string &fileName);
static int8_t readInt8(std::ifstream &file);
static int32_t readInt32(std::ifstream &file);
};
#endif // MNISTLOADER_H

View File

@ -1,7 +1,9 @@
#include "netlearner.h"
#include "../../Net.h"
#include "mnistloader.h"
#include <QElapsedTimer>
#include <QImage>
void NetLearner::run()
{
@ -9,67 +11,54 @@ void NetLearner::run()
{
QElapsedTimer timer;
Net myNet;
try
{
myNet.load("mynet.nnet");
}
catch (...)
{
myNet.initialize({2, 3, 1});
}
emit logMessage("Loading training data...");
size_t batchSize = 5000;
size_t batchIndex = 0;
double batchMaxError = 0.0;
double batchMeanError = 0.0;
MnistLoader mnistLoader;
mnistLoader.load("../NeuroUI/MNIST Database/train-images.idx3-ubyte",
"../NeuroUI/MNIST Database/train-labels.idx1-ubyte");
emit logMessage("done");
Net digitClassifier({28*28, 256, 1});
timer.start();
size_t numIterations = 1000000;
size_t numIterations = 100000;
for (size_t iteration = 0; iteration < numIterations; ++iteration)
{
std::vector<double> inputValues =
{
std::rand() / (double)RAND_MAX,
std::rand() / (double)RAND_MAX
};
auto trainingSample = mnistLoader.getRandomSample();
QImage trainingImage(trainingSample.data, 28, 28, QImage::Format_Grayscale8);
emit sampleImageLoaded(trainingImage);
std::vector<double> targetValues =
{
(inputValues[0] + inputValues[1]) / 2.0
trainingSample.label / 10.0
};
myNet.feedForward(inputValues);
std::vector<double> trainingData;
trainingData.reserve(28*28);
for (const uint8_t &val : trainingSample.data)
{
trainingData.push_back(val / 255.0);
}
std::vector<double> outputValues = myNet.getOutput();
digitClassifier.feedForward(trainingData);
std::vector<double> outputValues = digitClassifier.getOutput();
double error = outputValues[0] - targetValues[0];
batchMeanError += error;
batchMaxError = std::max<double>(batchMaxError, error);
QString logString;
if (batchIndex++ == batchSize)
{
QString logString;
logString.append("Error: ");
logString.append(QString::number(std::abs(error)));
logString.append("Batch error (");
logString.append(QString::number(batchSize));
logString.append(" iterations, max/mean): ");
logString.append(QString::number(std::abs(batchMaxError)));
logString.append(" / ");
logString.append(QString::number(std::abs(batchMeanError / batchSize)));
emit logMessage(logString);
emit currentNetError(error);
emit progress((double)iteration / (double)numIterations);
emit logMessage(logString);
emit currentNetError(batchMaxError);
emit progress((double)iteration / (double)numIterations);
batchIndex = 0;
batchMaxError = 0.0;
batchMeanError = 0.0;
}
myNet.backProp(targetValues);
digitClassifier.backProp(targetValues);
}
QString timerLogString;
@ -79,7 +68,7 @@ void NetLearner::run()
emit logMessage(timerLogString);
myNet.save("mynet.nnet");
digitClassifier.save("DigitClassifier.nnet");
}
catch (std::exception &ex)
{

View File

@ -14,6 +14,7 @@ signals:
void logMessage(const QString &logMessage);
void progress(double progress);
void currentNetError(double error);
void sampleImageLoaded(const QImage &image);
};
#endif // NETLEARNER_H

View File

@ -31,6 +31,8 @@ void NeuroUI::on_runButton_clicked()
connect(m_netLearner.get(), &NetLearner::finished, this, &NeuroUI::netLearnerFinished);
connect(m_netLearner.get(), &NetLearner::currentNetError, ui->errorPlotter, &ErrorPlotter::addErrorValue);
connect(m_netLearner.get(), &NetLearner::sampleImageLoaded, this, &NeuroUI::setImage);
}
m_netLearner->start();
@ -61,3 +63,10 @@ void NeuroUI::progress(double progress)
ui->progressBar->setValue(value);
}
void NeuroUI::setImage(const QImage &image)
{
QPixmap pixmap;
pixmap.convertFromImage(image);
ui->label->setPixmap(pixmap);
}

View File

@ -28,6 +28,7 @@ private slots:
void netLearnerStarted();
void netLearnerFinished();
void progress(double progress);
void setImage(const QImage &image);
private:
Ui::NeuroUI *ui;

View File

@ -20,11 +20,37 @@
<widget class="QWidget" name="centralWidget">
<layout class="QVBoxLayout" name="verticalLayout_2">
<item>
<widget class="QListWidget" name="logView">
<property name="uniformItemSizes">
<bool>true</bool>
</property>
</widget>
<layout class="QHBoxLayout" name="horizontalLayout_2">
<item>
<widget class="QListWidget" name="logView">
<property name="uniformItemSizes">
<bool>true</bool>
</property>
</widget>
</item>
<item>
<widget class="QLabel" name="label">
<property name="sizePolicy">
<sizepolicy hsizetype="Fixed" vsizetype="Preferred">
<horstretch>0</horstretch>
<verstretch>0</verstretch>
</sizepolicy>
</property>
<property name="minimumSize">
<size>
<width>128</width>
<height>0</height>
</size>
</property>
<property name="text">
<string/>
</property>
<property name="alignment">
<set>Qt::AlignCenter</set>
</property>
</widget>
</item>
</layout>
</item>
<item>
<widget class="ErrorPlotter" name="errorPlotter" native="true">