Querying of a single MNIST sample
This commit is contained in:
parent
ab9dcfbd35
commit
eecd7a0fe6
@ -8,6 +8,21 @@ void MnistLoader::load(const std::string &databaseFileName, const std::string &l
|
|||||||
loadLabels(labelsFileName);
|
loadLabels(labelsFileName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t MnistLoader::getSamleCount() const
|
||||||
|
{
|
||||||
|
return samples.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
const MnistLoader::MnistSample &MnistLoader::getSample(size_t index) const
|
||||||
|
{
|
||||||
|
if (index >= samples.size())
|
||||||
|
{
|
||||||
|
throw std::runtime_error("MNIST sample index out of range");
|
||||||
|
}
|
||||||
|
|
||||||
|
return *(samples[index].get());
|
||||||
|
}
|
||||||
|
|
||||||
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
|
const MnistLoader::MnistSample &MnistLoader::getRandomSample() const
|
||||||
{
|
{
|
||||||
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;
|
size_t sampleIndex = (std::rand() * (samples.size() - 1)) / RAND_MAX;
|
||||||
|
@ -31,6 +31,8 @@ private:
|
|||||||
public:
|
public:
|
||||||
void load(const std::string &databaseFileName, const std::string &labelsFileName);
|
void load(const std::string &databaseFileName, const std::string &labelsFileName);
|
||||||
|
|
||||||
|
size_t getSamleCount() const;
|
||||||
|
const MnistSample &getSample(size_t index) const;
|
||||||
const MnistSample &getRandomSample() const;
|
const MnistSample &getRandomSample() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
Reference in New Issue
Block a user