Graph out training results
This commit is contained in:
parent
6d57fa4650
commit
88a3924637
34
flow.py
34
flow.py
@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
print(tf.__version__)
|
print("Running TensorFlow", tf.__version__)
|
||||||
|
|
||||||
fashion_mnist = keras.datasets.fashion_mnist
|
fashion_mnist = keras.datasets.fashion_mnist
|
||||||
|
|
||||||
@ -33,7 +33,7 @@ class_names = [
|
|||||||
model = keras.Sequential(
|
model = keras.Sequential(
|
||||||
[
|
[
|
||||||
keras.layers.Flatten(input_shape=(28, 28)),
|
keras.layers.Flatten(input_shape=(28, 28)),
|
||||||
keras.layers.Dense(128, activation=tf.nn.relu),
|
keras.layers.Dense(256, activation=tf.nn.relu),
|
||||||
keras.layers.Dense(10, activation=tf.nn.softmax),
|
keras.layers.Dense(10, activation=tf.nn.softmax),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -44,11 +44,35 @@ model.compile(
|
|||||||
metrics=["accuracy"],
|
metrics=["accuracy"],
|
||||||
)
|
)
|
||||||
|
|
||||||
model.fit(train_images, train_labels, epochs=5)
|
|
||||||
|
|
||||||
test_loss, test_acc = model.evaluate(test_images, test_labels)
|
def plot_training(history):
|
||||||
|
acc = history.history["acc"]
|
||||||
|
val_acc = history.history["val_acc"]
|
||||||
|
|
||||||
print("Test accuracy:", test_acc)
|
epochs = range(1, len(acc) + 1)
|
||||||
|
|
||||||
|
plt.plot(epochs, acc, "bo", label="Training acc")
|
||||||
|
plt.plot(epochs, val_acc, "b", label="Validation acc")
|
||||||
|
plt.title("Training and validation accuracy")
|
||||||
|
plt.xlabel("Epochs")
|
||||||
|
plt.ylabel("Accuracy")
|
||||||
|
plt.legend()
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
early_stop = keras.callbacks.EarlyStopping(monitor="val_loss", patience=5)
|
||||||
|
|
||||||
|
history = model.fit(
|
||||||
|
train_images,
|
||||||
|
train_labels,
|
||||||
|
epochs=64,
|
||||||
|
batch_size=512,
|
||||||
|
validation_data=(test_images, test_labels),
|
||||||
|
callbacks=[early_stop],
|
||||||
|
)
|
||||||
|
|
||||||
|
plot_training(history)
|
||||||
|
|
||||||
predictions = model.predict(test_images)
|
predictions = model.predict(test_images)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user