Learning from and displaying of digit samples
This commit is contained in:
parent
d98ec63fbd
commit
cd1101dfe2
@ -1,16 +1,6 @@
|
|||||||
#include "mnistloader.h"
|
#include "mnistloader.h"
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <functional>
|
|
||||||
#include <memory>
|
|
||||||
#include <list>
|
|
||||||
|
|
||||||
#include <intrin.h>
|
|
||||||
|
|
||||||
MnistLoader::MnistLoader()
|
|
||||||
{
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
void MnistLoader::load(const std::string &databaseFileName, const std::string &labelsFileName)
|
void MnistLoader::load(const std::string &databaseFileName, const std::string &labelsFileName)
|
||||||
{
|
{
|
||||||
@ -18,6 +8,13 @@ void MnistLoader::load(const std::string &databaseFileName, const std::string &l
|
|||||||
loadLabels(labelsFileName);
|
loadLabels(labelsFileName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
|
||||||
|
{
|
||||||
|
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;
|
||||||
|
|
||||||
|
return *(samples[sampleIndex].get());
|
||||||
|
}
|
||||||
|
|
||||||
void MnistLoader::loadDatabase(const std::string &fileName)
|
void MnistLoader::loadDatabase(const std::string &fileName)
|
||||||
{
|
{
|
||||||
std::ifstream databaseFile;
|
std::ifstream databaseFile;
|
||||||
@ -43,6 +40,8 @@ void MnistLoader::loadDatabase(const std::string &fileName)
|
|||||||
throw std::runtime_error("unexpected sample size loading MNIST database");
|
throw std::runtime_error("unexpected sample size loading MNIST database");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
samples.reserve(samples.size() + sampleCount);
|
||||||
|
|
||||||
for (int32_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex)
|
for (int32_t sampleIndex = 0; sampleIndex < sampleCount; ++sampleIndex)
|
||||||
{
|
{
|
||||||
std::unique_ptr<MnistSample> sample = std::make_unique<MnistSample>();
|
std::unique_ptr<MnistSample> sample = std::make_unique<MnistSample>();
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
#define MNISTLOADER_H
|
#define MNISTLOADER_H
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <list>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <inttypes.h>
|
#include <inttypes.h>
|
||||||
|
|
||||||
@ -26,13 +26,13 @@ public:
|
|||||||
using MnistSample = Sample<SampleWidth, SampleHeight>;
|
using MnistSample = Sample<SampleWidth, SampleHeight>;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::list<std::unique_ptr<MnistSample>> samples;
|
std::vector<std::unique_ptr<MnistSample>> samples;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
MnistLoader();
|
|
||||||
|
|
||||||
void load(const std::string &databaseFileName, const std::string &labelsFileName);
|
void load(const std::string &databaseFileName, const std::string &labelsFileName);
|
||||||
|
|
||||||
|
const MnistSample &getRandomSample() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void loadDatabase(const std::string &fileName);
|
void loadDatabase(const std::string &fileName);
|
||||||
void loadLabels(const std::string &fileName);
|
void loadLabels(const std::string &fileName);
|
||||||
|
@ -12,30 +12,38 @@ void NetLearner::run()
|
|||||||
QElapsedTimer timer;
|
QElapsedTimer timer;
|
||||||
|
|
||||||
emit logMessage("Loading training data...");
|
emit logMessage("Loading training data...");
|
||||||
emit progress(0.0);
|
|
||||||
|
|
||||||
MnistLoader mnistLoader;
|
MnistLoader mnistLoader;
|
||||||
mnistLoader.load("../NeuroUI/MNIST Database/train-images.idx3-ubyte",
|
mnistLoader.load("../NeuroUI/MNIST Database/train-images.idx3-ubyte",
|
||||||
"../NeuroUI/MNIST Database/train-labels.idx1-ubyte");
|
"../NeuroUI/MNIST Database/train-labels.idx1-ubyte");
|
||||||
|
|
||||||
emit logMessage("done");
|
emit logMessage("done");
|
||||||
emit progress(0.0);
|
|
||||||
|
|
||||||
return;
|
Net digitClassifier({28*28, 256, 1});
|
||||||
|
|
||||||
Net digitClassifier({32*32, 16*16, 32, 1});
|
|
||||||
|
|
||||||
timer.start();
|
timer.start();
|
||||||
|
|
||||||
size_t numIterations = 10000;
|
size_t numIterations = 100000;
|
||||||
for (size_t iteration = 0; iteration < numIterations; ++iteration)
|
for (size_t iteration = 0; iteration < numIterations; ++iteration)
|
||||||
{
|
{
|
||||||
|
auto trainingSample = mnistLoader.getRandomSample();
|
||||||
|
|
||||||
|
QImage trainingImage(trainingSample.data, 28, 28, QImage::Format_Grayscale8);
|
||||||
|
emit sampleImageLoaded(trainingImage);
|
||||||
|
|
||||||
std::vector<double> targetValues =
|
std::vector<double> targetValues =
|
||||||
{
|
{
|
||||||
//trainingSample.first / 10.0
|
trainingSample.label / 10.0
|
||||||
};
|
};
|
||||||
|
|
||||||
//digitClassifier.feedForward(trainingSample.second);
|
std::vector<double> trainingData;
|
||||||
|
trainingData.reserve(28*28);
|
||||||
|
for (const uint8_t &val : trainingSample.data)
|
||||||
|
{
|
||||||
|
trainingData.push_back(val / 255.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
digitClassifier.feedForward(trainingData);
|
||||||
|
|
||||||
std::vector<double> outputValues = digitClassifier.getOutput();
|
std::vector<double> outputValues = digitClassifier.getOutput();
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user