Backprop seems to be working, yay

This commit is contained in:
mandlm 2015-10-18 21:20:37 +02:00
parent a79abb5db1
commit 6ef1f9657c
4 changed files with 32 additions and 8 deletions

View File

@ -53,19 +53,37 @@ void Layer::connectTo(const Layer & nextLayer)
void Layer::updateInputWeights(Layer & prevLayer)
{
static const double trainingRate = 0.8;
static const double trainingRate = 0.5;
for (size_t currentLayerIndex = 0; currentLayerIndex < size() - 1; ++currentLayerIndex)
for (size_t currentLayerIndex = 0; currentLayerIndex < sizeWithoutBiasNeuron(); ++currentLayerIndex)
{
Neuron &targetNeuron = at(currentLayerIndex);
for (size_t prevLayerIndex = 0; prevLayerIndex < prevLayer.size(); ++prevLayerIndex)
{
Neuron &sourceNeuron = prevLayer.at(prevLayerIndex);
sourceNeuron.setOutputWeight(currentLayerIndex,
sourceNeuron.getOutputWeight(currentLayerIndex) +
sourceNeuron.getOutputValue() * targetNeuron.getGradient() * trainingRate);
}
}
}
void Layer::addBiasNeuron()
{
push_back(Neuron(1.0));
hasBiasNeuron = true;
}
size_t Layer::sizeWithoutBiasNeuron() const
{
if (hasBiasNeuron)
{
return size() - 1;
}
else
{
return size();
}
}

View File

@ -6,6 +6,9 @@
class Layer : public std::vector < Neuron >
{
private:
bool hasBiasNeuron = false;
public:
Layer(size_t numNeurons);
@ -15,4 +18,8 @@ public:
void connectTo(const Layer & nextLayer);
void updateInputWeights(Layer &prevLayer);
void addBiasNeuron();
size_t sizeWithoutBiasNeuron() const;
};

View File

@ -17,8 +17,7 @@ Net::Net(std::initializer_list<size_t> layerSizes)
Layer &currentLayer = *layerIt;
const Layer &nextLayer = *(layerIt + 1);
Neuron biasNeuron(1.0);
currentLayer.push_back(biasNeuron);
currentLayer.addBiasNeuron();
currentLayer.connectTo(nextLayer);
}

View File

@ -9,12 +9,12 @@ int main()
{
std::cout << "Neuro running" << std::endl;
std::vector<double> inputValues = { 1.0, 4.0, 5.0 };
std::vector<double> targetValues = { 3.0 };
std::vector<double> inputValues = { 0.1, 0.2, 0.8 };
std::vector<double> targetValues = { 0.8 };
Net myNet({ inputValues.size(), 4, targetValues.size() });
for (int i = 0; i < 20; ++i)
for (int i = 0; i < 200; ++i)
{
myNet.feedForward(inputValues);