{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Programming Assignment"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Data pipeline with Keras and tf.data"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Instructions\n",
|
||
|
"\n",
|
||
|
"In this notebook, you will implement a data processing pipeline using tools from both Keras and the tf.data module. You will use the `ImageDataGenerator` class in the tf.keras module to feed a network with training and test images from a local directory containing a subset of the LSUN dataset, and train the model both with and without data augmentation. You will then use the `map` and `filter` functions of the `Dataset` class with the CIFAR-100 dataset to train a network to classify a processed subset of the images.\n",
|
||
|
"\n",
|
||
|
"Some code cells are provided you in the notebook. You should avoid editing provided code, and make sure to execute the cells in order to avoid unexpected errors. Some cells begin with the line:\n",
|
||
|
"\n",
|
||
|
"`#### GRADED CELL ####`\n",
|
||
|
"\n",
|
||
|
"Don't move or edit this first line - this is what the automatic grader looks for to recognise graded cells. These cells require you to write your own code to complete them, and are automatically graded when you submit the notebook. Don't edit the function name or signature provided in these cells, otherwise the automatic grader might not function properly. Inside these graded cells, you can use any functions or classes that are imported below, but make sure you don't use any variables that are outside the scope of the function.\n",
|
||
|
"\n",
|
||
|
"### How to submit\n",
|
||
|
"\n",
|
||
|
"Complete all the tasks you are asked for in the worksheet. When you have finished and are happy with your code, press the **Submit Assignment** button at the top of this notebook.\n",
|
||
|
"\n",
|
||
|
"### Let's get started!\n",
|
||
|
"\n",
|
||
|
"We'll start running some imports, and loading the dataset. Do not edit the existing imports in the following cell. If you would like to make further Tensorflow imports, you should add them here."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 56,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"#### PACKAGE IMPORTS ####\n",
|
||
|
"\n",
|
||
|
"# Run this cell first to import all required packages. Do not make any imports elsewhere in the notebook\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"import tensorflow as tf\n",
|
||
|
"from tensorflow.keras.datasets import cifar100\n",
|
||
|
"import numpy as np\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import json\n",
|
||
|
"%matplotlib inline\n",
|
||
|
"\n",
|
||
|
"# If you would like to make further imports from tensorflow, add them here\n",
|
||
|
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
|
||
|
"\n",
|
||
|
"from tensorflow.keras.layers import Conv2D,MaxPooling2D,Dense,Flatten,Input\n",
|
||
|
"from tensorflow.keras.models import Model\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Part 1: tf.keras\n",
|
||
|
"<table><tr>\n",
|
||
|
"<td> <img src=\"data/lsun/church.png\" alt=\"Church\" style=\"height: 210px;\"/> </td>\n",
|
||
|
"<td> <img src=\"data/lsun/classroom.png\" alt=\"Classroom\" style=\"height: 210px;\"/> </td>\n",
|
||
|
" <td> <img src=\"data/lsun/conference_room.png\" alt=\"Conference Room\" style=\"height: 210px;\"/> </td>\n",
|
||
|
"</tr></table>\n",
|
||
|
" \n",
|
||
|
"#### The LSUN Dataset\n",
|
||
|
"\n",
|
||
|
"In the first part of this assignment, you will use a subset of the [LSUN dataset](https://www.yf.io/p/lsun). This is a large-scale image dataset with 10 scene and 20 object categories. A subset of the LSUN dataset has been provided, and has already been split into training and test sets. The three classes included in the subset are `church_outdoor`, `classroom` and `conference_room`.\n",
|
||
|
"\n",
|
||
|
"* F. Yu, A. Seff, Y. Zhang, S. Song, T. Funkhouser and J. Xia. \"LSUN: Construction of a Large-scale Image Dataset using Deep Learning with Humans in the Loop\". arXiv:1506.03365, 10 Jun 2015 \n",
|
||
|
"\n",
|
||
|
"Your goal is to use the Keras preprocessing tools to construct a data ingestion and augmentation pipeline to train a neural network to classify the images into the three classes."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Save the directory locations for the training, validation and test sets\n",
|
||
|
"\n",
|
||
|
"train_dir = 'data/lsun/train'\n",
|
||
|
"valid_dir = 'data/lsun/valid'\n",
|
||
|
"test_dir = 'data/lsun/test'"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Create a data generator using the ImageDataGenerator class"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"You should first write a function that creates an `ImageDataGenerator` object, which rescales the image pixel values by a factor of 1/255."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"#### GRADED CELL ####\n",
|
||
|
"\n",
|
||
|
"# Complete the following function. \n",
|
||
|
"# Make sure to not change the function name or arguments.\n",
|
||
|
"\n",
|
||
|
"def get_ImageDataGenerator():\n",
|
||
|
" \"\"\"\n",
|
||
|
" This function should return an instance of the ImageDataGenerator class.\n",
|
||
|
" This instance should be set up to rescale the data with the above scaling factor.\n",
|
||
|
" \"\"\"\n",
|
||
|
" imagedatagenerator = ImageDataGenerator(rescale =(1/255.))\n",
|
||
|
" \n",
|
||
|
" return imagedatagenerator\n",
|
||
|
" "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 28,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Call the function to get an ImageDataGenerator as specified\n",
|
||
|
"\n",
|
||
|
"image_gen = get_ImageDataGenerator()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"You should now write a function that returns a generator object that will yield batches of images and labels from the training and test set directories. The generators should:\n",
|
||
|
"\n",
|
||
|
"* Generate batches of size 20.\n",
|
||
|
"* Resize the images to 64 x 64 x 3.\n",
|
||
|
"* Return one-hot vectors for labels. These should be encoded as follows:\n",
|
||
|
" * `classroom` $\\rightarrow$ `[1., 0., 0.]`\n",
|
||
|
" * `conference_room` $\\rightarrow$ `[0., 1., 0.]`\n",
|
||
|
" * `church_outdoor` $\\rightarrow$ `[0., 0., 1.]`\n",
|
||
|
"* Pass in an optional random `seed` for shuffling (this should be passed into the `flow_from_directory` method).\n",
|
||
|
"\n",
|
||
|
"**Hint:** you may need to refer to the [documentation](https://keras.io/preprocessing/image/#imagedatagenerator-class) for the `ImageDataGenerator`."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 81,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"#### GRADED CELL ####\n",
|
||
|
"\n",
|
||
|
"# Complete the following function.\n",
|
||
|
"# Make sure not to change the function name or arguments.\n",
|
||
|
"\n",
|
||
|
"def get_generator(image_data_generator, directory, seed=None):\n",
|
||
|
" \"\"\"\n",
|
||
|
" This function takes an ImageDataGenerator object in the first argument and a \n",
|
||
|
" directory path in the second argument.\n",
|
||
|
" It should use the ImageDataGenerator to return a generator object according \n",
|
||
|
" to the above specifications. \n",
|
||
|
" The seed argument should be passed to the flow_from_directory method.\n",
|
||
|
" \n",
|
||
|
" \"\"\"\n",
|
||
|
" # I couldn't get this one right. If you get it let me know So, I can recitfy.\n",
|
||
|
" image_data_gen = image_data_generator.flow_from_directory(\n",
|
||
|
" directory = directory,\n",
|
||
|
" batch_size = 20,\n",
|
||
|
" target_size = (64,64),\n",
|
||
|
" \n",
|
||
|
" class_mode = \"categorical\"\n",
|
||
|
" \n",
|
||
|
" )\n",
|
||
|
" return image_data_gen\n",
|
||
|
" "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 82,
|
||
|
"metadata": {
|
||
|
"scrolled": true
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Found 300 images belonging to 3 classes.\n",
|
||
|
"Found 120 images belonging to 3 classes.\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Run this cell to define training and validation generators\n",
|
||
|
"\n",
|
||
|
"train_generator = get_generator(image_gen, train_dir)\n",
|
||
|
"valid_generator = get_generator(image_gen, valid_dir)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"We are using a small subset of the dataset for demonstrative purposes in this assignment."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Display sample images and labels from the training set\n",
|
||
|
"\n",
|
||
|
"The following cell depends on your function `get_generator` to be implemented correctly. If it raises an error, go back and check the function specifications carefully."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {
|
||
|
"scrolled": false
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Display a few images and labels from the training set\n",
|
||
|
"\n",
|
||
|
"batch = next(train_generator)\n",
|
||
|
"batch_images = np.array(batch[0])\n",
|
||
|
"batch_labels = np.array(batch[1])\n",
|
||
|
"lsun_classes = ['classroom', 'conference_room', 'church_outdoor']\n",
|
||
|
"\n",
|
||
|
"plt.figure(figsize=(16,10))\n",
|
||
|
"for i in range(20):\n",
|
||
|
" ax = plt.subplot(4, 5, i+1)\n",
|
||
|
" plt.imshow(batch_images[i])\n",
|
||
|
" plt.title(lsun_classes[np.where(batch_labels[i] == 1.)[0][0]])\n",
|
||
|
" plt.axis('off')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 32,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Found 300 images belonging to 3 classes.\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Reset the training generator\n",
|
||
|
"\n",
|
||
|
"train_generator = get_generator(image_gen, train_dir)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Build the neural network model\n",
|
||
|
"\n",
|
||
|
"You will now build and compile a convolutional neural network classifier. Using the functional API, build your model according to the following specifications:\n",
|
||
|
"\n",
|
||
|
"* The model should use the `input_shape` in the function argument to define the Input layer.\n",
|
||
|
"* The first hidden layer should be a Conv2D layer with 8 filters, a 8x8 kernel size.\n",
|
||
|
"* The second hidden layer should be a MaxPooling2D layer with a 2x2 pooling window size.\n",
|
||
|
"* The third hidden layer should be a Conv2D layer with 4 filters, a 4x4 kernel size.\n",
|
||
|
"* The fourth hidden layer should be a MaxPooling2D layer with a 2x2 pooling window size.\n",
|
||
|
"* This should be followed by a Flatten layer, and then a Dense layer with 16 units and ReLU activation.\n",
|
||
|
"* The final layer should be a Dense layer with 3 units and softmax activation.\n",
|
||
|
"* All Conv2D layers should use `\"SAME\"` padding and a ReLU activation function.\n",
|
||
|
"\n",
|
||
|
"In total, the network should have 8 layers. The model should then be compiled with the Adam optimizer with learning rate 0.0005, categorical cross entropy loss, and categorical accuracy metric."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 33,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"#### GRADED CELL ####\n",
|
||
|
"\n",
|
||
|
"# Complete the following function.\n",
|
||
|
"# Make sure not to change the function name or arguments.\n",
|
||
|
"\n",
|
||
|
"def get_model(input_shape):\n",
|
||
|
" \"\"\"\n",
|
||
|
" This function should build and compile a CNN model according to the above specification,\n",
|
||
|
" using the functional API. Your function should return the model.\n",
|
||
|
" \"\"\"\n",
|
||
|
" inputs = Input(input_shape)\n",
|
||
|
" h = Conv2D(8,(8,8),padding = \"SAME\")(inputs)\n",
|
||
|
" h = MaxPooling2D((2,2))(h)\n",
|
||
|
" h = Conv2D(4,(4,4),padding = \"SAME\")(h)\n",
|
||
|
" h = MaxPooling2D((2,2))(h)\n",
|
||
|
" h = Flatten()(h)\n",
|
||
|
" h = Dense(16, activation = \"relu\")(h)\n",
|
||
|
" outputs = Dense(3, activation = \"softmax\")(h)\n",
|
||
|
" \n",
|
||
|
" model = Model(inputs = inputs,outputs = outputs)\n",
|
||
|
" \n",
|
||
|
" model.compile(\n",
|
||
|
" optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0005),\n",
|
||
|
" loss = \"categorical_crossentropy\",\n",
|
||
|
" metrics = ['categorical_accuracy']\n",
|
||
|
" )\n",
|
||
|
" \n",
|
||
|
" return model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 34,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Model: \"model_1\"\n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"Layer (type) Output Shape Param # \n",
|
||
|
"=================================================================\n",
|
||
|
"input_2 (InputLayer) [(None, 64, 64, 3)] 0 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"conv2d_2 (Conv2D) (None, 64, 64, 8) 1544 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"max_pooling2d_2 (MaxPooling2 (None, 32, 32, 8) 0 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"conv2d_3 (Conv2D) (None, 32, 32, 4) 516 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"max_pooling2d_3 (MaxPooling2 (None, 16, 16, 4) 0 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"flatten_1 (Flatten) (None, 1024) 0 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"dense_2 (Dense) (None, 16) 16400 \n",
|
||
|
"_________________________________________________________________\n",
|
||
|
"dense_3 (Dense) (None, 3) 51 \n",
|
||
|
"=================================================================\n",
|
||
|
"Total params: 18,511\n",
|
||
|
"Trainable params: 18,511\n",
|
||
|
"Non-trainable params: 0\n",
|
||
|
"_________________________________________________________________\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Build and compile the model, print the model summary\n",
|
||
|
"\n",
|
||
|
"lsun_model = get_model((64, 64, 3))\n",
|
||
|
"lsun_model.summary()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Train the neural network model\n",
|
||
|
"\n",
|
||
|
"You should now write a function to train the model for a specified number of epochs (specified in the `epochs` argument). The function takes a `model` argument, as well as `train_gen` and `valid_gen` arguments for the training and validation generators respectively, which you should use for training and validation data in the training run. You should also use the following callbacks:\n",
|
||
|
"\n",
|
||
|
"* An `EarlyStopping` callback that monitors the validation accuracy and has patience set to 10. \n",
|
||
|
"* A `ReduceLROnPlateau` callback that monitors the validation loss and has the factor set to 0.5 and minimum learning set to 0.0001\n",
|
||
|
"\n",
|
||
|
"Your function should return the training history."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 35,
|
||
|
"metadata": {
|
||
|
"scrolled": true
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"#### GRADED CELL ####\n",
|
||
|
"\n",
|
||
|
"# Complete the following function.\n",
|
||
|
"# Make sure not to change the function name or arguments.\n",
|
||
|
"\n",
|
||
|
"def train_model(model, train_gen, valid_gen, epochs):\n",
|
||
|
" \"\"\"\n",
|
||
|
" This function should define the callback objects specified above, and then use the\n",
|
||
|
" train_gen and valid_gen generator object arguments to train the model for the (maximum) \n",
|
||
|
" number of epochs specified in the function argument, using the defined callbacks.\n",
|
||
|
" The function should return the training history.\n",
|
||
|
" \"\"\"\n",
|
||
|
" history = model.fit(train_gen,\n",
|
||
|
" validation_data = valid_gen,\n",
|
||
|
" epochs = epochs,\n",
|
||
|
" callbacks = [tf.keras.callbacks.EarlyStopping(patience = 10),tf.keras.callbacks.ReduceLROnPlateau(factor = 0.5,min_delta = 0.0001)]\n",
|
||
|
" )\n",
|
||
|
" \n",
|
||
|
" return history"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 36,
|
||
|
"metadata": {
|
||
|
"scrolled": true
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Train for 15 steps, validate for 6 steps\n",
|
||
|
"Epoch 1/50\n",
|
||
|
"15/15 [==============================] - 11s 743ms/step - loss: 1.0774 - categorical_accuracy: 0.3967 - val_loss: 0.9810 - val_categorical_accuracy: 0.5667\n",
|
||
|
"Epoch 2/50\n",
|
||
|
"15/15 [==============================] - 10s 647ms/step - loss: 0.9670 - categorical_accuracy: 0.5367 - val_loss: 0.8666 - val_categorical_accuracy: 0.6000\n",
|
||
|
"Epoch 3/50\n",
|
||
|
"15/15 [==============================] - 9s 633ms/step - loss: 0.8358 - categorical_accuracy: 0.6100 - val_loss: 0.8097 - val_categorical_accuracy: 0.6667\n",
|
||
|
"Epoch 4/50\n",
|
||
|
"15/15 [==============================] - 9s 633ms/step - loss: 0.7275 - categorical_accuracy: 0.6867 - val_loss: 0.7484 - val_categorical_accuracy: 0.6583\n",
|
||
|
"Epoch 5/50\n",
|
||
|
"15/15 [==============================] - 10s 653ms/step - loss: 0.6560 - categorical_accuracy: 0.7033 - val_loss: 0.7666 - val_categorical_accuracy: 0.6417\n",
|
||
|
"Epoch 6/50\n",
|
||
|
"15/15 [==============================] - 10s 673ms/step - loss: 0.6627 - categorical_accuracy: 0.7133 - val_loss: 0.7740 - val_categorical_accuracy: 0.6750\n",
|
||
|
"Epoch 7/50\n",
|
||
|
"15/15 [==============================] - 10s 640ms/step - loss: 0.6027 - categorical_accuracy: 0.7333 - val_loss: 0.8330 - val_categorical_accuracy: 0.6000\n",
|
||
|
"Epoch 8/50\n",
|
||
|
"15/15 [==============================] - 9s 633ms/step - loss: 0.6221 - categorical_accuracy: 0.7500 - val_loss: 0.7647 - val_categorical_accuracy: 0.6917\n",
|
||
|
"Epoch 9/50\n",
|
||
|
"15/15 [==============================] - 9s 627ms/step - loss: 0.6372 - categorical_accuracy: 0.7133 - val_loss: 0.8383 - val_categorical_accuracy: 0.6417\n",
|
||
|
"Epoch 10/50\n",
|
||
|
"15/15 [==============================] - 9s 625ms/step - loss: 0.6037 - categorical_accuracy: 0.7367 - val_loss: 0.7184 - val_categorical_accuracy: 0.7083\n",
|
||
|
"Epoch 11/50\n",
|
||
|
"15/15 [==============================] - 10s 634ms/step - loss: 0.5290 - categorical_accuracy: 0.8000 - val_loss: 0.7152 - val_categorical_accuracy: 0.7000\n",
|
||
|
"Epoch 12/50\n",
|
||
|
"15/15 [==============================] - 9s 631ms/step - loss: 0.4615 - categorical_accuracy: 0.8400 - val_loss: 0.7218 - val_categorical_accuracy: 0.7083\n",
|
||
|
"Epoch 13/50\n",
|
||
|
"15/15 [==============================] - 9s 627ms/step - loss: 0.4641 - categorical_accuracy: 0.8233 - val_loss: 0.8215 - val_categorical_accuracy: 0.6833\n",
|
||
|
"Epoch 14/50\n",
|
||
|
"15/15 [==============================] - 10s 638ms/step - loss: 0.4746 - categorical_accuracy: 0.7767 - val_loss: 0.7595 - val_categorical_accuracy: 0.6750\n",
|
||
|
"Epoch 15/50\n",
|
||
|
"15/15 [==============================] - 10s 640ms/step - loss: 0.4272 - categorical_accuracy: 0.8367 - val_loss: 0.7713 - val_categorical_accuracy: 0.6583\n",
|
||
|
"Epoch 16/50\n",
|
||
|
"15/15 [==============================] - 10s 640ms/step - loss: 0.3775 - categorical_accuracy: 0.8600 - val_loss: 0.7660 - val_categorical_accuracy: 0.6583\n",
|
||
|
"Epoch 17/50\n",
|
||
|
"15/15 [==============================] - 9s 627ms/step - loss: 0.3475 - categorical_accuracy: 0.8733 - val_loss: 0.8566 - val_categorical_accuracy: 0.6167\n",
|
||
|
"Epoch 18/50\n",
|
||
|
"15/15 [==============================] - 9s 627ms/step - loss: 0.3619 - categorical_accuracy: 0.8633 - val_loss: 0.7964 - val_categorical_accuracy: 0.6333\n",
|
||
|
"Epoch 19/50\n",
|
||
|
"15/15 [==============================] - 9s 627ms/step - loss: 0.3050 - categorical_accuracy: 0.9033 - val_loss: 0.8057 - val_categorical_accuracy: 0.6583\n",
|
||
|
"Epoch 20/50\n",
|
||
|
"15/15 [==============================] - 10s 653ms/step - loss: 0.2810 - categorical_accuracy: 0.9033 - val_loss: 0.8329 - val_categorical_accuracy: 0.6750\n",
|
||
|
"Epoch 21/50\n",
|
||
|
"15/15 [==============================] - 11s 733ms/step - loss: 0.2612 - categorical_accuracy: 0.9333 - val_loss: 0.8622 - val_categorical_accuracy: 0.6667\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Train the model for (maximum) 50 epochs\n",
|
||
|
"\n",
|
||
|
"history = train_model(lsun_model, train_generator, valid_generator, epochs=50)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Plot the learning curves"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 37,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA3sAAAFNCAYAAAC5cXZ6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzs3XdclXX7wPHPxVYBNRAnigMRB6Ii7r2ztMwcOdI027unzOrXzsoybTxPWamlptkwtUxz5Sz3nqCiIk4cOEDW9/fHfTRUZB8O4PV+vc5LOPe67js697nu6zvEGINSSimllFJKqaLFydEBKKWUUkoppZTKe5rsKaWUUkoppVQRpMmeUkoppZRSShVBmuwppZRSSimlVBGkyZ5SSimllFJKFUGa7CmllFJKKaVUEaTJnlLK7kQkQESMiLg4OhallFLKHkTkdRGZ6ug4lEpLkz1VqInIXyJyRkTcHR2LUkopVZiJSJSIdHR0HEqpvKPJniq0RCQAaAUYoEc+H1srVEoppZRSqkDTZE8VZoOBf4DJwP1pF4hIMRH5SEQOisg5EVkpIsVsy1qKyGoROSsih0VkiO39v0RkeJp9DBGRlWl+NyLymIhEABG298bb9hEnIhtEpFWa9Z1FZJSI7BOR87bl/iLyuYh8dF28c0Xk6etPUES+EJEPr3tvtog8a/v5RRE5Ytv/HhHpkJULJyIVRORnETkpIgdE5Mk0y14XkZ9E5AfbfjeKSP00y4Nt1+qsiOwQkR5plt30utsMEJFDInJKRF5Os124iKy3XcfjIjI2K+ehlFIqf4jIgyISKSKnRWSOiFSwvS8i8rGInLB97m8Vkbq2ZbeLyE7bveSIiDyfzn7dbfeTumneKyMi8SLiJyK+IvKbbZ3TIrJCRLL0/VVE7hCRzbZtV4tISJplUSLyki2+MyIySUQ8Mjtf27I6IrLQtuy4iIxKc1g3EfnOds47RCQszXY5umcrlSvGGH3pq1C+gEjgUaARkASUTbPsc+AvoCLgDDQH3IHKwHmgP+AK+AChtm3+Aoan2ccQYGWa3w2wELgNKGZ7b6BtHy7Ac8AxwMO27D/ANiAIEKC+bd1wIAZwsq3nC1xKG3+aY7YGDgNi+700EA9UsO33MFDBtiwAqJ6F6+YEbAD+D3ADqgH7gS625a/brmdv2zV6Hjhg+9nVdt1H2bZtb7ueQZlc9wDb9fsKKGa7FpeBYNt2fwODbD97Ak0d/felL33pS1+32guIAjqm83574BTQ0PaZ/imw3Lasi+2eUsp2rwsGytuWHQVa2X4uDTS8yXEnAu+k+f0xYL7t59HAF2nuQa2u3BMzOZeGwAmgie1+dL/t/NzTnOt2wB/rvr4KeDsL5+tlO6/nAA/b701sy14HEoDbbcccDfxjW5aje7a+9JXbl1b2VKEkIi2BKsBMY8wGYB9wn22ZE/AA8JQx5ogxJsUYs9oYcxkYACwyxkw3xiQZY2KNMZuzcejRxpjTxph4AGPMVNs+ko0xH2HdFIJs6w4HXjHG7DGWLbZ11wLngCtP9PoBfxljjqdzvBVYSdKVimFv4G9jTAyQYjtebRFxNcZEGWP2ZeEcGgNljDFvGmMSjTH7sZKwfmnW2WCM+ckYkwSMxbqhNbW9PIH3bNsuAX4D+mdy3a94wxgTb4zZAmzBSvrASi5riIivMeaCMeafLJyHUkqp/DEAmGiM2Wj7TH8JaCZWd4okrISnFlYStssYc9S2XRLWPcrbGHPGGLPxJvv/Hush7BX32d67so/yQBXbfXuFMcZkIeYHgS+NMWts96NvsR4yNk2zzmfGmMPGmNPAO2liyOh87wCOGWM+MsYkGGPOG2PWpNnnSmPMPGNMCjCFf+9zOb1nK5Urmuypwup+4E9jzCnb79/zb1NOX6zkJL0PUf+bvJ9Vh9P+IiLPicguW9OVs0BJ2/EzO9a3WFVBbP9OSW8l2w1tBv/egO4DptmWRQJPYz1JPCEiM9I2M8lAFaCCrVnLWVvco4Cy6Z2nMSYViMaqJlYADtveu+IgViUvo+t+xbE0P1/CShwBhgE1gd0isk5E7sjCeSillMofFbA+6wEwxlwAYoGKtod+n2G17DguIhNExNu26j1YVa6DIrJMRJrdZP9LgGIi0kREqgChwCzbsjFYLUr+FJH9IjIyizFXAZ677l7nbzuXK9Le0w+mWXbT8yXz7xHX3+c8RMQlF/dspXJFkz1V6Nj6gPUB2ojIMRE5BjwD1Lf1LTuF1YyiejqbH77J+wAXgeJpfi+XzjpXnyaK1T/vRVsspY0xpbAqdpKFY00FetriDQZ+vcl6ANOB3rYbYBPg56vBGPO9MeZKldMA72ewnysOAweMMaXSvLyMMbenWcc/zXk6AZWwmp7GAP7X9ZeoDBwh4+ueIWNMhDGmP+BnO4efRKREdvejlFLKLmKw7jMA2D6ffbA++zHGfGKMaQTUwXpw9x/b++uMMT2xPtt/BWamt3PbA8SZWA827wN+M8acty07b4x5zhhTDbgTeDaLfd0OYzUNTXuvK26MmZ5mHf80P1e2nWdm55vRvT1DObxnK5UrmuypwugurOYQtbGe/oViJUwrgMG2m8ZEYKxYA5E4i0gzsaZnmAZ0FJE+IuIiIj4iEmrb72agl4gUF5EaWNWmjHgBycBJwEVE/g/wTrP8a+AtEQm0dWAPEREfAGNMNLAOq6L385VmoekxxmyyHeNrYIEx5iyAiASJSHvbeSVg9eVLyfzysRaIs3UUL2a7PnVFpHGadRqJSC+xRh19Gqvpyz/AGqyk+AURcRWRtlg33xmZXPcMichAESlj28dZ29tZORellFJ5y1VEPNK8XLBazwwVkVDbZ/q7wBpjTJSINLZV5Fyx7g8JQIqIuInIABEpaesSEEfGn+vfA32xmlBeacJ5ZZCVGiIiafaRlfvDV8DDtthEREqISHcR8UqzzmMiUklEbsNq4fJDmljSPV+srgvlRORpsQaX8RKRJpkFk4t7tlK5osmeKozuByYZYw4ZY45deWE1IxlguzE9jzU4yjrgNNbTMydjzCGsJiXP2d7fzL/t6T8GEoHjWM0sp2USxwLgD2AvVnOPBK5tEjIW60nln1g3qG+wBie54lugHjdpwnmd6UBH0twAsdr+v4dVUTuG9eR0FIDtBrsjvR3Z+hHciZUkH7Bt/zVWE9QrZmPddM8Ag4Betr4SiVjTXHSzbfdfrAR7t227dK97Fs6vK7BDRC4A44F+xpiELGynlFIqb83DSkSuvF43xiwGXsVqWXIUq7J1pZ+3N1ZidQbrXhgLXBlFehAQJSJxwMP8233hBrZ+bxexmlD+kWZRILAIuIA1mNd/jTF/AYjIH3LtSJhp97ceq9/eZ7bYIrEGXkvre6x79H7b623btjc9X1vFsRPWffQY1ujc7W52Xmnc9J6tlD1dGeFPKZXPRKQ1VnPOgOv6wDmUiLwO1DDG3PSmrJRSShVmIhKFNQL3IkfHopQ9aWVPKQewNXd5Cvi6ICV6SimllFKq6NBkT
6l8JiLBWP3SygPjHByOUkoppZQqorQZp1JKKaWUUkoVQVrZU0oppZRSSqkiSJM9pZRSSimllCqCXBwdQHb5+vqagIAAR4ehlFIqH2zYsOGUMaaMo+MoLPQeqZRSt4as3h8LXbIXEBDA+vXrHR2GUkqpfCAiBx0dQ2Gi90illLo1ZPX+qM04lVJKqXwmIhNF5ISIbL/J8loi8reIXBaR5/M7PqWUUkWDJntKKaVU/psMdM1g+WngSeDDfIlGKaVUkaTJnlJKKZXPjDHLsRK6my0/YYxZByTlX1RKKaWKmkLXZ08ppZRSSilV8CQlJREdHU1CQoKjQykyPDw8qFSpEq6urjnaXpM9pZRSqhATkRHACIDKlSs7OBql1K0sOjoaLy8vAgICEBFHh1PoGWOIjY0lOjqaqlWr5mgf2oxTKaWUKsSMMROMMWHGmLAyZXSWCqWU4yQkJODj46OJXh4REXx8fHJVKdVkTymllFJKKZUnNNHLW7m9nprsKaWUUvlMRKYDfwNBIhItIsNE5GERedi
|
||
|
"text/plain": [
|
||
|
"<Figure size 1080x360 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Run this cell to plot accuracy vs epoch and loss vs epoch\n",
|
||
|
"\n",
|
||
|
"plt.figure(figsize=(15,5))\n",
|
||
|
"plt.subplot(121)\n",
|
||
|
"try:\n",
|
||
|
" plt.plot(history.history['accuracy'])\n",
|
||
|
" plt.plot(history.history['val_accuracy'])\n",
|
||
|
"except KeyError:\n",
|
||
|
" try:\n",
|
||
|
" plt.plot(history.history['acc'])\n",
|
||
|
" plt.plot(history.history['val_acc'])\n",
|
||
|
" except KeyError:\n",
|
||
|
" plt.plot(history.history['categorical_accuracy'])\n",
|
||
|
" plt.plot(history.history['val_categorical_accuracy'])\n",
|
||
|
"plt.title('Accuracy vs. epochs')\n",
|
||
|
"plt.ylabel('Accuracy')\n",
|
||
|
"plt.xlabel('Epoch')\n",
|
||
|
"plt.legend(['Training', 'Validation'], loc='lower right')\n",
|
||
|
"\n",
|
||
|
"plt.subplot(122)\n",
|
||
|
"plt.plot(history.history['loss'])\n",
|
||
|
"plt.plot(history.history['val_loss'])\n",
|
||
|
"plt.title('Loss vs. epochs')\n",
|
||
|
"plt.ylabel('Loss')\n",
|
||
|
"plt.xlabel('Epoch')\n",
|
||
|
"plt.legend(['Training', 'Validation'], loc='upper right')\n",
|
||
|
"plt.show() "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"You may notice overfitting in the above plots, through a growing discrepancy between the training and validation loss and accuracy. We will aim to mitigate this using data augmentation. Given our limited dataset, we may be able to improve the performance by applying random modifications to the images in the training data, effectively increasing the size of the dataset."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Create a new data generator with data augmentation\n",
|
||
|
"\n",
|
||
|
"You should now write a function to create a new `ImageDataGenerator` object, which performs the following data preprocessing and augmentation:\n",
|
||
|
"\n",
|
||
|
"* Scales the image pixel values by a factor of 1/255.\n",
|
||
|
"* Randomly rotates images by up to 30 degrees\n",
|
||
|
"* Randomly alters the brightness (picks a brightness shift value) from the range (0.5, 1.5)\n",
|
||
|
"* Randomly flips images horizontally\n",
|
||
|
"\n",
|
||
|
"Hint: you may need to refer to the [documentation](https://keras.io/preprocessing/image/#imagedatagenerator-class) for the `ImageDataGenerator`."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 38,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"#### GRADED CELL ####\n",
|
||
|
"\n",
|
||
|
"# Complete the following function. \n",
|
||
|
"# Make sure to not change the function name or arguments.\n",
|
||
|
"\n",
|
||
|
"def get_ImageDataGenerator_augmented():\n",
|
||
|
" \"\"\"\n",
|
||
|
" This function should return an instance of the ImageDataGenerator class \n",
|
||
|
" with the above specifications.\n",
|
||
|
" \"\"\"\n",
|
||
|
" \n",
|
||
|
" image_data_gen = ImageDataGenerator(\n",
|
||
|
" rescale = (1/255.),\n",
|
||
|
" rotation_range = 30,\n",
|
||
|
" brightness_range = (0.5,1.5),\n",
|
||
|
" horizontal_flip = True)\n",
|
||
|
" \n",
|
||
|
" return image_data_gen\n",
|
||
|
" "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 39,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Call the function to get an ImageDataGenerator as specified\n",
|
||
|
"\n",
|
||
|
"image_gen_aug = get_ImageDataGenerator_augmented()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 40,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Found 120 images belonging to 3 classes.\n",
|
||
|
"Found 300 images belonging to 3 classes.\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Run this cell to define training and validation generators \n",
|
||
|
"\n",
|
||
|
"valid_generator_aug = get_generator(image_gen_aug, valid_dir)\n",
|
||
|
"train_generator_aug = get_generator(image_gen_aug, train_dir, seed=10)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 41,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Found 300 images belonging to 3 classes.\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Reset the original train_generator with the same random seed\n",
|
||
|
"\n",
|
||
|
"train_generator = get_generator(image_gen, train_dir, seed=10)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Display sample augmented images and labels from the training set\n",
|
||
|
"\n",
|
||
|
"The following cell depends on your function `get_generator` to be implemented correctly. If it raises an error, go back and check the function specifications carefully. \n",
|
||
|
"\n",
|
||
|
"The cell will display augmented and non-augmented images (and labels) from the training dataset, using the `train_generator_aug` and `train_generator` objects defined above (if the images do not correspond to each other, check you have implemented the `seed` argument correctly)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 42,
|
||
|
"metadata": {
|
||
|
"scrolled": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4AAAAFTCAYAAABoAWL/AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsnXeYXVXV/z/r1umTnpBCAoTeu2AoItJERLCgooCACir62kVRBDtYUHnFF1RQBAV+gIqAIlWUjvQeEkIK6dPLbef3x1rrzL0nM0koJoG7v88zz53T9tlnn7X32Xt9V5EoiggICAgICAgICAgICAh44yO1visQEBAQEBAQEBAQEBAQsG4QFoABAQEBAQEBAQEBAQF1grAADAgICAgICAgICAgIqBOEBWBAQEBAQEBAQEBAQECdICwAAwICAgICAgICAgIC6gRhARgQEBAQEBAQEBAQEFAnCAvAgICAgA0EInKxiMwf4dj+IhKJyIHrul5vFIjIkSLy2f9CuReLyNy1OC8SkTNf6/sHBAQEBAS8HGTWdwUCAgICAgLWEY4EDgR+tJ7uvxcw7AI/ICAgICBgXSEsAAMCAgICAtYBoii6e33XISAgICAgIJiABgQEBLxOISJzReRSETlGRJ4UkV4RuV9EZiXO211ErhKR+SLSLyJPi8h3RKRxmPIuHuY+q5guisj7ReQpERkQkUdF5AgRuU1Ebqs6x81WjxSRX4rIChFZKSI/FpG01etOq/fjInLwMPfeT0RuFpFuO+9vIrJd4pzbrJwDReRBEekTkcdE5Miqcy4GjgOmWJ2iarNNERknIr8QkQUiMmjP9tFh6vNWu8eAiMwWkY+N8HpWQbIdReRM27eVPVeviMwTkRPs+IesHj0icquIbJYo7xgRuUVElto5/xGR44a573gRuVxEuqz9f2PvKxKR/RPnHiUid1sbdojIlSKyceKcD9i9ekSk097/WrdDQEBAQMD6RWAAAwICAl7f2AfYEjgDGADOBq4TkRlRFHXYORsDDwEXA93AtsDXgU2BY17uDUXkbcDvgT8DnwPGAT8BGoBnhrnkJ8DVwPuAfYGvod+fA4FzgAW272oRmR5F0TK7z9uBPwF/BY61sr4E/FNEdoii6MWqe2wGnAd8F1hm9bpKRLaKoug5a5fxwO7AEXbNoN2nDfgX0AicCcwBDgZ+ISL5KIp+ZudtDVwP3G/tlrfzW4Dyy2jCJK4ELgTOBU4Ffi0imwP7A18GsvZslwF7Vl23KXAV8D2ggrbtRSLSGEXRBVXnXQ1sD3wFeA44GvhZshIi8nHgF8BvgLOAVnu+2629u025cCnwU+ALqCJ5K2DUq3j+gICAgIB1iSiKwl/4C3/hL/xtAH/oAm3+CMf2ByLgwKp9c4GVwOiqfbvZeR8YoRxBF1/HoouGsYnyLh7mmgg4s2r738BjgFTt28XOu22YOv86Ud6Dtn9W1b4dbN9xVfueA25OXNuGLvB+UrXvNqAIbF61bwK6KDt9Te3L0OJ588T+C+1eGdv+vW03V50zDSgAc9fi/Sbb8Uzb9+GqfaOBErAcaKvaf5qdO32EslP2Xi8EHq7af5Bd997E+X+2/fvbdgvQOcy7mmHP9xnb/jywYn33lfAX/sJf+At/r/wvmIAGBAQEvL5xVxRFK6u2H7Xf2GxPRNpE5PsiMhtlvYrA79DF4OYv52YikkYXmf8viqLI90dR9CDKnA2HGxLbTwG9URTdmdgHuqDCGLDNgN+LSMb/gD7gLpTtqsazURQ9W1WfJcASqtphNTgEuAeYk7jX34CxwDZ23l7A9VEU9Vbd50WUPXw1iNvH3uUS4O4oirqqzqlpH9A2MtPOBeg7LQInoYyw403oQviaxD2vSmzvhS6uk+093+7t7X0fMFrU9PhwEQnMX0BAQMDrDMEENCAgIGDDQQlIj3AsXXVONVZUb0RRNCgioOaYjt+g5pZfR01Be4E9gPMT560NxqEmiUuGObZ4hGtWJrYLQEf1jiiKCol6T7DfX9lfEvMS2yuGOWeQtXu+CcBMdAE1HMba70YM/4yLgU3W4j4jYbj2GW4f2POISAtwE7og/jIw2845BfhI1XUbASujKEo+W/I5vL3/sbo6RlF0u4i8B/gUtqgUkduBz0ZR9MgI1wYEBAQEbEAIC8CAgICADQdLgHEikouiqJA4Ntl+R1pkDQsRaQDeiZoenle1f/thTh8AconrxyTOWYYulCawKiay6sLslWK5/X6F4RclyfZ5tfdaAnx6hONP2+8i9BmTGG7ffxt7AdOBfaqZVGPtqrEIZeyyiUVgss7e3scDjw9zv27/J4qiq1D/yhbUzPf7wI0iMjWKosoreJaAgICAgHWIsAAMCAgI2HBwK7rgOYJVTfSORifzTycvWgPyKHuYZICOH+bcF4DtEvsOr96IoqgsIvcDR4vImW4GKiK7oizYa7UAfBr1Sdw2iqLvvUZlDqKBXpK4EWW05pnp6Ei4CzhMRJrdDFREpgFvBha+RnVcWzTZb/xeRWQ0utivxt3o+38XcEXV/vckzvs3usibGUXRJWtTgSiKetCAQ5uiQWrGAkvX9gECAgICAtYPwgIwICAgYMPBP1CzvotFZCvUL60VjTj5TuCEl8uwRFHUKSJ3A58TkUUog/cRYMowp/8BjUD5Y+A6YEeGXyh+A/g7cI2I/B9qFnom8BIaWOZVI4qiSEQ+AfxJRHLo4mUZylztjS7WXm5C9yeAMSJyChrJcyCKokeBH6MRSv9pz/400IxGt9wniiJfVH0LXTj9XUTOQdnSb/IyWdnXCP8GuoDzReQbVt+voW3U7idFUfR3EbkT+D8RGYcG1nk3+m7B3lcURV0i8gUrbzzql9iJysl+aHCfy0TkLPQd3IoueqeiAWoeiqIoLP4CAgICXgcIQWACAgICNhAYm3YEmjbhw+gi7BLUj+vIKIoufoVFvx94APX5uxhdqA1n7ngJurg7CvgLmgrhXcPU8ybgg8DWqB/Yl9C0Cy+hi4bXBFEUXY8GH2kGLkKDsvwAmISycS8XF6GL3O8A96LPSBRFneii8nr0Wf4G/BpddN9aVZ8ngcNQ9u2PaPqFnwA3v4K6vCrYYutdKLt3FZr+4iI0RUMSR6Es5/fRhXQDGvkUqt5XFEW/ROVvSzRI0A3oAjeD+o6CKiVmoIvmm6zM24G3v1bPFhAQEBDw34VUBXELCAgICAh4RRCRqSi79O0ois5e3/UJWD1E5HyU3R0TRdHgeq5OQEBAQMA6RDABDQgICAh4WRCRRuBHqMnqMjQh+RfRiJQXrceqBQwDETkeNQt9HDVbPQT4OHBOWPwFBAQE1B/CAjAgICAg4OWijJph/hwN/NEL/BN4TxRFi9ZnxQKGRS/wGTSvYh7N13g6cM76rFRAQEBAwPpBMAENCAgICAgICAgICAioE4QgMAEBAQEBAQEBAQEBAXWCsAAMCAgICAgICAgICAioE4QFYEBAQEBAQEBAQEBAQJ0gLAADAgICAgICAgICAgLqBGEBi
IbIFpE713c9AgL+GwjyHRAQEBAQEBAQ4AgLwICAgICA1wVE8RsRWSki967v+gQErC3WpSJORC4WkW+ti3utLURkhohEIhLSj9Uxwhi+4SB0xHUMERE0/UZlfdclIOC1RpDvgP8yZgFvA6ZGUdS7visTEBCgizs0t2Q2iqLS+q1NwAaOMIZvIKg7BlBEponI1SKyVESWi8jPhznnPBF5UUS6ROQBEdmn6tgeInK/HVssIj+y/Q0icqmV2SEi94nIRDt2m4h8W0T+BfQBm4rIZBH5s4isEJHnROTkqnvkReQnIrLQ/n4iInk7tr+IzBeRL4rIEhFZJCJHishhIvKMlXf6f7sdAzZMBPkOeINjOjD3lUwcNnTmYUOvX8CGgyArQwht8bpDGMM3ENTVAlBE0sB1wAvADGAK8IdhTr0P2AkYA1wGXCkiDXbsPOC8KIragM2AK2z/cUA7MA0YC3wc6K8q80PAR4FWu//lwHxgMvBu4Dsi8lY796vAm6wOOwJ7AF+rKmsS0GD1/zpwIXAssCu
|
||
|
"text/plain": [
|
||
|
"<Figure size 1152x360 with 10 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4AAAAFTCAYAAABoAWL/AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvXecHld1Pv6ct2/XanfVu+Qiy73JNi5ywYDBBAymxWCDcWwCCfni/ELoNjgQkm8oSX5JgBAbMKHYGDAYcAHLvVd1S7Z610ravvu2+f7xnDPvO6PdlWyMZOk9z+ezGr0zd2buzNx7Z+55znOOBEEAh8PhcDgcDofD4XAc+kgc6Ao4HA6Hw+FwOBwOh2P/wCeADofD4XA4HA6Hw1Ej8Amgw+FwOBwOh8PhcNQIfALocDgcDofD4XA4HDUCnwA6HA6Hw+FwOBwOR43AJ4AOh8PhcDgcDofDUSPwCaDD4XAcxBCR/xaRQES+dqDrciAgIseLyHUiMvZVPu4Vel9n7KXcQhFZ+Gqe2+FwOByOPyV8AuhwOBwHKUSkDsCl+vPPRSR1IOtzgHA8gC8AeFUngC8Df6l/DofD4XAcFPAJoMPhcBy8eDuAZgC/ATAOwBsPbHVqD0EQLA2CYOmBrofD4XA4HPsKnwA6HA7HwYvLAewCcAWAAQAfiBcQkZtEZM0w6/dwXRSRE0XkAREZFJH1IvJpEbleRIJYuUBEbhCRa0VkrYj0icgdIjJO/34qIl16jE8Oc+6ZIvJDEdkuIkMi8qyIvD1W5jo9z2F67F491+dFJKFlrgBwo+6yUsuHbpsikhKRT4nIcj3PJhH5FxHJxc41S8/Rr3X6JoDsKPd9xPsoIgu0Dm8TkW+JyE4R2SUiXxeRpIicIiIP6j1bIiJviB3vFBG5VUQ2iMiAiKwQkS8r21tdLqnPYLPW+w8icqSe+7pY2eNE5Hatx4CIPCQiZw1z3rtFpFOP95KI/Me+3AOHw+FwHFyoRXchh8PhOOghIpMAXADg20EQbBeRXwC4RERagyDY9QqO1w7g9wA2gRPJPID/A2DGCLu8H8Bi0P1xPIBvAPg+gCYAvwXwbdA99R9FZFEQBL/R80wF8BiAbXr87QDeDeBnIvK2IAhuj53n5+Ak7+sALgZwPYD1uu4OADcA+Kyea4Pus1mXN+s+XwXwMIC5AL6k1/QOrU8GwN0A6gB8VOt1NYBL9n7XRsU3ANym13a21jEFPrN/BrBR190mItODINih+00D8CyAmwD0AJgH4PMAZgF4T9XxrwfwaT3WPQBOBBC/dxCREwE8AOAZAFcB6AdwDYB7ROSMIAieEpFGAHcCeBw0JvSA9+iMP/IeOBwOh+O1iCAI/M///M///O8g+wPwSQABgNP19xv09zWxcjcBWDPM/gsBLKz6/WVw0jelal0dgK18VUT2DQC8ACBVte5ruv6zVetS4ITqxqp13wUnfW2xY94N4Nmq39fp8T4YK7cIwF1Vv6/QcnNi5c7S9R+Irf9zXX+8/r5Kf59WVSYBYImun7GX5xC/jwt0v/+JlXta159Zte5YXXf5CMcWvYeXASjbPQPQCqAXwH/Eyn9Cj3dd1brfA1gGIFO1LqnrfqG/T9b9jj3Q7dr//M///M///vR/7gLqcDgcByc+AGBlEASP6O97UGHvXglOA/BIEATGoiEIggGQZRsOdwdBUKz6vVyXd1btXwSwCsDUqnJvBDWLXeqimdLgNXcCOE5EmmPniZ9/MciS7Q1vBCe0P4ud5y7dfrYuTwewPgiCR6vqXQbw0304x2j4bez3cgB9QRA8GFsHVN0fEWkWka+KyIsAhgAUAPwAnAwepsWOAdAA4JbYOW6t/qFuo+douXLVPRCwvdg9WAlgN4BvichlytI6HA6H4xCFTwAdDofjIIOInALgKNB9cIyIjAFdL28DcLqIHP4KDjsRZOvi2DpC+bibaX6U9dWau3HgJLUQ+/tn3d4W239n7PdQ7HgjYRyADMiUVZ/HrtHOMxHDX+NI172vGO4+7K5eEQSB3bPq67kRdNH8VwCvB3AK6JpaXW6iLuPPK17nsSDb9znseb8/BqBVRBJBEHQBOBc0IPwHgHUislhE3rH3y3Q4HA7HwQbXADocDsfBh8t1+Un9i+MDoL4MAAbBiVAcbQA6q35vBidNcYx/hXUcCZ2gJu2rI2zf9CqeZxB0BR3tPJtBnV0cr/Z17xUanObPQBfOb1atPyZW1DSO40BXVUO8zrtB19H/H9Rn7gFlOxEEwbMA3qEM4ckAPgXgpyJyXBAEi1/ZFTkcDofjtQifADocDsdBBA1a8h4wkMrfD1Pk6wDeLyKfC4IgALAWwHgRaQ800IiIzAZwBBgYxfAogL8VkSnmBqouhG9+lS/hd6Db5RJ1Mf1jMaTLutj634GT45YgCH4/yv6PAPigiJxmbqAaZfRdr0LdXi6yIGNXiK2/IvZ7EYA+MPDNvVXrL60uFARBn4g8AOA4AE/bZG80qNvuoyLyOQBvBQPn+ATQ4XA4DiH4BNDhcDgOLrwFZO+uDYJgYXyjiHwLwH+CwUjuBfVfXwLwQxH5GoB2kN3ZEdv1awA+AuBOEbkenFh9QpcBXj18How2eb+I/DuANWBQk6MBzAqC4EMv83iWg++jIvI9cPL0fBAEC0XkRwBu1et+HGTDZgC4CMAngyB4AcD3wIn0bSLyadCt8howv+J+RRAEXSLyKIBrRWQz+Iw+BGByrNwuEfkGgE+LSA8qUUCv1CLVE71PALgffK7fBdnDdi2fDILg70XkLQD+AsAvAKwG9YV/DUYDfQQOh8PhOKTgGkCHw+E4uHA5+GEeDwBi+BGYE/ByAAiCYBWAd4KTiF8A+DtwUvBC9U7KDp4Pate+D2rB7gHTMHS9WpUPgmAd6GL4HBh59G5wwnoOgD+8guM9B0YMvRjAgwCeADBJN1+m294J4JdgkJSPgUFPtur+eVBr9yx4zd8DJ0E3vPyre1XwXgBPgW6bNwHYAuDjw5T7AoCvgM/5dgBvQoUpDJ9XEARPgzrCTlBXeBeAb4KBZO7XYivBNvM5MHjNjQCKAF5fHRTI4XA4HIcGhB5CDofD4XBEISJJMH3BjiAIzj/Q9XGMDhG5FIxeenYQBA8c6Po4HA6H47UJdwF1OBwOBwBARL4Epm1YC7qZfhjMVXfRgayXY0+IyHxQn/kYGOzmJNCV9VGQCXU4HA6HY1j4BNDhcDgchgDU6E3S/z8P4G1BEMRz2jkOPHrBPH4fBfWK20D271OBu/Y4HA6HYxS4C6jD4XA4HA6Hw+Fw1Ag8CIzD4XA4HA6Hw+Fw1Ah8AuhwOBwOh8PhcDgcNQKfADocDofD4XA4HA5HjcAngA6Hw+FwOBwOh8NRI/AJIAARuUJEPGy245CEt2+Hw+FwOBwOh8EngA6Hw+E4KCDEjSKyS0QeP9D1cTj2FfvTECciN4nIDfvjXPsKEZkhIoGIePqxGoaP4a8deEfczxARAdNvlA90XRyOVxvevh1/YpwJ4PUApgRB0
HegK+NwODi5A7AaQDoIguKBrY3jNQ4fw18jqDkGUESmishtIrJdRDpF5N+HKfNNEVkvIt0i8pSInFW17VQReVK3bRWRr+n6nIjcrMfcLSJPiMh43bZQRP5BRB4C0A9glohMEpHbRWSniKwSkauqzpEVkW+IyCb9+4aIZHXbAhHZICJ/JyLbRGSziLxNRC4SkRf0eJ/+U99Hx2sT3r4dhzimA1jzSj4cXuvMw2u9fo7XDrytVOD34qCDj+GvEdTUBFBEkgB+DWAtgBkAJgP48TBFnwBwPICxAP4XwC0iktNt3wTwzSAImgHMBvBTXX85gBYAUwG0AbgGwEDVMd8P4C8ANOn5fwRgA4BJAN4J4Msicr6W/QyA07QOxwE4FcBnq441AUBO6/95AN8BcBmAkwCcBeDzIjJrn2+M45CAt2/Haw3DGSREJCEinxWRtTrJ/76ItGh5cxO7XETWicgOEfmMbrsSwH8DOF1EekXkel3/FhF5Vg0TD4vIsVXnXyMinxSR5wH0iUhKjRM/0zqtFpG
|
||
|
"text/plain": [
|
||
|
"<Figure size 1152x360 with 10 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Display a few images and labels from the non-augmented and augmented generators\n",
|
||
|
"\n",
|
||
|
"batch = next(train_generator)\n",
|
||
|
"batch_images = np.array(batch[0])\n",
|
||
|
"batch_labels = np.array(batch[1])\n",
|
||
|
"\n",
|
||
|
"aug_batch = next(train_generator_aug)\n",
|
||
|
"aug_batch_images = np.array(aug_batch[0])\n",
|
||
|
"aug_batch_labels = np.array(aug_batch[1])\n",
|
||
|
"\n",
|
||
|
"plt.figure(figsize=(16,5))\n",
|
||
|
"plt.suptitle(\"Unaugmented images\", fontsize=16)\n",
|
||
|
"for n, i in enumerate(np.arange(10)):\n",
|
||
|
" ax = plt.subplot(2, 5, n+1)\n",
|
||
|
" plt.imshow(batch_images[i])\n",
|
||
|
" plt.title(lsun_classes[np.where(batch_labels[i] == 1.)[0][0]])\n",
|
||
|
" plt.axis('off')\n",
|
||
|
"plt.figure(figsize=(16,5))\n",
|
||
|
"plt.suptitle(\"Augmented images\", fontsize=16)\n",
|
||
|
"for n, i in enumerate(np.arange(10)):\n",
|
||
|
" ax = plt.subplot(2, 5, n+1)\n",
|
||
|
" plt.imshow(aug_batch_images[i])\n",
|
||
|
" plt.title(lsun_classes[np.where(aug_batch_labels[i] == 1.)[0][0]])\n",
|
||
|
" plt.axis('off')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 43,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Found 300 images belonging to 3 classes.\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Reset the augmented data generator\n",
|
||
|
"\n",
|
||
|
"train_generator_aug = get_generator(image_gen_aug, train_dir)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Train a new model on the augmented dataset"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 44,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Build and compile a new model\n",
|
||
|
"\n",
|
||
|
"lsun_new_model = get_model((64, 64, 3))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 45,
|
||
|
"metadata": {
|
||
|
"scrolled": true
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Train for 15 steps, validate for 6 steps\n",
|
||
|
"Epoch 1/50\n",
|
||
|
"15/15 [==============================] - 11s 761ms/step - loss: 1.0921 - categorical_accuracy: 0.3900 - val_loss: 1.0440 - val_categorical_accuracy: 0.4500\n",
|
||
|
"Epoch 2/50\n",
|
||
|
"15/15 [==============================] - 10s 700ms/step - loss: 0.9937 - categorical_accuracy: 0.5300 - val_loss: 0.9592 - val_categorical_accuracy: 0.6000\n",
|
||
|
"Epoch 3/50\n",
|
||
|
"15/15 [==============================] - 10s 700ms/step - loss: 0.9181 - categorical_accuracy: 0.5933 - val_loss: 0.8677 - val_categorical_accuracy: 0.6000\n",
|
||
|
"Epoch 4/50\n",
|
||
|
"15/15 [==============================] - 10s 683ms/step - loss: 0.8742 - categorical_accuracy: 0.6167 - val_loss: 0.8353 - val_categorical_accuracy: 0.6333\n",
|
||
|
"Epoch 5/50\n",
|
||
|
"15/15 [==============================] - 10s 694ms/step - loss: 0.7762 - categorical_accuracy: 0.6633 - val_loss: 0.7767 - val_categorical_accuracy: 0.6417\n",
|
||
|
"Epoch 6/50\n",
|
||
|
"15/15 [==============================] - 10s 700ms/step - loss: 0.7298 - categorical_accuracy: 0.6733 - val_loss: 0.8494 - val_categorical_accuracy: 0.6333\n",
|
||
|
"Epoch 7/50\n",
|
||
|
"15/15 [==============================] - 10s 689ms/step - loss: 0.7398 - categorical_accuracy: 0.6733 - val_loss: 0.8149 - val_categorical_accuracy: 0.6417\n",
|
||
|
"Epoch 8/50\n",
|
||
|
"15/15 [==============================] - 10s 680ms/step - loss: 0.7302 - categorical_accuracy: 0.6900 - val_loss: 0.7711 - val_categorical_accuracy: 0.6333\n",
|
||
|
"Epoch 9/50\n",
|
||
|
"15/15 [==============================] - 10s 693ms/step - loss: 0.6948 - categorical_accuracy: 0.7033 - val_loss: 0.8223 - val_categorical_accuracy: 0.6167\n",
|
||
|
"Epoch 10/50\n",
|
||
|
"15/15 [==============================] - 11s 713ms/step - loss: 0.7074 - categorical_accuracy: 0.6967 - val_loss: 0.8049 - val_categorical_accuracy: 0.6833\n",
|
||
|
"Epoch 11/50\n",
|
||
|
"15/15 [==============================] - 10s 693ms/step - loss: 0.6856 - categorical_accuracy: 0.6967 - val_loss: 0.7708 - val_categorical_accuracy: 0.6750\n",
|
||
|
"Epoch 12/50\n",
|
||
|
"15/15 [==============================] - 10s 693ms/step - loss: 0.6761 - categorical_accuracy: 0.7233 - val_loss: 0.7492 - val_categorical_accuracy: 0.7000\n",
|
||
|
"Epoch 13/50\n",
|
||
|
"15/15 [==============================] - 10s 693ms/step - loss: 0.6434 - categorical_accuracy: 0.7100 - val_loss: 0.7391 - val_categorical_accuracy: 0.7083\n",
|
||
|
"Epoch 14/50\n",
|
||
|
"15/15 [==============================] - 10s 686ms/step - loss: 0.6191 - categorical_accuracy: 0.7400 - val_loss: 0.8185 - val_categorical_accuracy: 0.6333\n",
|
||
|
"Epoch 15/50\n",
|
||
|
"15/15 [==============================] - 10s 660ms/step - loss: 0.6553 - categorical_accuracy: 0.7433 - val_loss: 0.7689 - val_categorical_accuracy: 0.6917\n",
|
||
|
"Epoch 16/50\n",
|
||
|
"15/15 [==============================] - 10s 673ms/step - loss: 0.6007 - categorical_accuracy: 0.7433 - val_loss: 0.7495 - val_categorical_accuracy: 0.7000\n",
|
||
|
"Epoch 17/50\n",
|
||
|
"15/15 [==============================] - 10s 693ms/step - loss: 0.6047 - categorical_accuracy: 0.7400 - val_loss: 0.7563 - val_categorical_accuracy: 0.6833\n",
|
||
|
"Epoch 18/50\n",
|
||
|
"15/15 [==============================] - 10s 680ms/step - loss: 0.5955 - categorical_accuracy: 0.7633 - val_loss: 0.7878 - val_categorical_accuracy: 0.6750\n",
|
||
|
"Epoch 19/50\n",
|
||
|
"15/15 [==============================] - 10s 680ms/step - loss: 0.6175 - categorical_accuracy: 0.7400 - val_loss: 0.7690 - val_categorical_accuracy: 0.6917\n",
|
||
|
"Epoch 20/50\n",
|
||
|
"15/15 [==============================] - 10s 680ms/step - loss: 0.6363 - categorical_accuracy: 0.7567 - val_loss: 0.7529 - val_categorical_accuracy: 0.6583\n",
|
||
|
"Epoch 21/50\n",
|
||
|
"15/15 [==============================] - 10s 673ms/step - loss: 0.6516 - categorical_accuracy: 0.7300 - val_loss: 0.8479 - val_categorical_accuracy: 0.6333\n",
|
||
|
"Epoch 22/50\n",
|
||
|
"15/15 [==============================] - 10s 693ms/step - loss: 0.5844 - categorical_accuracy: 0.7667 - val_loss: 0.8102 - val_categorical_accuracy: 0.6500\n",
|
||
|
"Epoch 23/50\n",
|
||
|
"15/15 [==============================] - 10s 674ms/step - loss: 0.6111 - categorical_accuracy: 0.7233 - val_loss: 0.7905 - val_categorical_accuracy: 0.7250\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Train the model\n",
|
||
|
"\n",
|
||
|
"history_augmented = train_model(lsun_new_model, train_generator_aug, valid_generator_aug, epochs=50)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Plot the learning curves"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 47,
|
||
|
"metadata": {
|
||
|
"scrolled": true
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4IAAAFNCAYAAABVKNEpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzs3Xd4VVX28PHvSq+0EAgQIHQCaUDo0hQQEFARBRQFFbGOOvZxxvn5zoxlbKPOOCoqRUSQARFRmgUFpIaaQOiEXkLohISU/f6xbzCEBEKSm5uyPs9zn+Sefcq6l3DPXWfvs7YYY1BKKaWUUkopVXm4uToApZRSSimllFKlSxNBpZRSSimllKpkNBFUSimllFJKqUpGE0GllFJKKaWUqmQ0EVRKKaWUUkqpSkYTQaWUUkoppZSqZDQRVEq5jIiEiYgREQ9Xx6KUUko5g4i8LCJfuDoOpfLSRFBVSCLyi4icEBFvV8eilFJKlWcikiQivV0dh1KqZGkiqCocEQkDugEGGFzKx9aeLaWUUkopVeZpIqgqonuAFcBEYFTuBhHxFZG3RWSPiJwSkaUi4utou05ElonISRHZJyKjHct/EZExufYxWkSW5npuRORREdkObHcse8+xj9MiskZEuuVa311EXhSRnSJyxtFeX0Q+EJG388Q7R0SezPsCReQjEXkrz7LZIvKU4/fnReSAY/9bReSGwrxxIlJXRGaKSLKI7BaRx3O1vSwiM0TkK8d+14pIdK72cMd7dVJENonI4FxtBb7vDneJyF4ROSYif861XQcRiXO8j0dE5J3CvA6llFKlQ0QeEJEdInJcRL4VkbqO5SIi/xKRo47P/Y0iEuFoGyAimx3nkgMi8kw++/V2nE8ici0LFpHzIlJLRGqKyHeOdY6LyBIRKdT3WhEZKCLrHdsuE5GoXG1JIvInR3wnRGSCiPhc7fU62lqLyA+OtiMi8mKuw3qJyOeO17xJRGJzbVekc7ZSxWaM0Yc+KtQD2AE8ArQDMoDaudo+AH4B6gHuQBfAG2gAnAFGAJ5AEBDj2OYXYEyufYwGluZ6boAfgBqAr2PZSMc+PICngcOAj6PtWSAeaAEIEO1YtwNwEHBzrFcTSM0df65jdgf2AeJ4Xh04D9R17HcfUNfRFgY0KcT75gasAf4KeAGNgV3AjY72lx3v51DHe/QMsNvxu6fjfX/Rse31jvezxVXe9zDH+/cJ4Ot4L9KBcMd2y4G7Hb8HAJ1c/felD33oQx+V7QEkAb3zWX49cAxo6/hM/zew2NF2o+OcUs1xrgsH6jjaDgHdHL9XB9oWcNzxwCu5nj8KzHf8/hrwUa5zULecc+JVXktb4CjQ0XE+GuV4fd65XmsCUB97Xv8N+EchXm+g43U9Dfg4nnd0tL0MpAEDHMd8DVjhaCvSOVsf+iiJh/YIqgpFRK4DGgLTjTFrgJ3AnY42N+A+4AljzAFjTJYxZpkxJh24C/jRGDPVGJNhjEkxxqy/hkO/Zow5bow5D2CM+cKxj0xjzNvYE0YLx7pjgL8YY7Yaa4Nj3VXAKSDnSuBw4BdjzJF8jrcEm0Dl9DQOBZYbYw4CWY7jtRIRT2NMkjFmZyFeQ3sg2BjzN2PMBWPMLmyCNjzXOmuMMTOMMRnAO9iTXSfHIwB43bHtz8B3wIirvO85/p8x5rwxZgOwAZsQgk08m4pITWPMWWPMikK8DqWUUqXjLmC8MWat4zP9T0BnsbdoZGCToZbYBC3RGHPIsV0G9hxVxRhzwhiztoD9f4m9QJvjTseynH3UARo6zttLjDGmEDE/AHxsjFnpOB9Nwl6A7JRrnf8YY/YZY44Dr+SK4UqvdyBw2BjztjEmzRhzxhizMtc+lxpj5hpjsoDJ/H6eK+o5W6li00RQVTSjgIXGmGOO51/y+/DQmtjEJb8P2PoFLC+sfbmfiMjTIpLoGA5zEqjqOP7VjjUJ25uI4+fk/FZynOym8fvJ6U5giqNtB/Ak9grkURGZlnvoyhU0BOo6hsqcdMT9IlA7v9dpjMkG9mN7IesC+xzLcuzB9gBe6X3PcTjX76nYpBLgfqA5sEVEVovIwEK8DqWUUqWjLvazHgBjzFkgBajnuCD4H+yIkCMiMk5EqjhWvQ3bO7ZHRH4Vkc4F7P9nwFdEOopIQyAGmOVoexM7EmWhiOwSkRcKGXND4Ok857r6jteSI/c5fU+utgJfL1f/HpH3POcjIh7FOGcrVWyaCKoKw3HP2R1ADxE5LCKHgT8C0Y572Y5hh2Y0yWfzfQUsBzgH+OV6HpLPOhevQoq9H/B5RyzVjTHVsD19UohjfQHc7Ig3HPimgPUApgJDHSfHjsDMi8EY86UxJqd31AD/vMJ+cuwDdhtjquV6BBpjBuRap36u1+kGhGKHsx4E6ue5P6MBcIArv+9XZIzZbowZAdRyvIYZIuJ/rftRSinlFAex5xkAHJ/PQdjPfowx7xtj2gGtsRf1nnUsX22MuRn72f4NMD2/nTsuLk7HXvS8E/jOGHPG0XbGGPO0MaYxMAh4qpD31u3DDjfNfa7zM8ZMzbVO/Vy/N3C8zqu93iud26+oiOdspYpNE0FVkdyCHWLRCnvVMAabTC0B7nGcUMYD74gtiuIuIp3FTjExBegtIneIiIeIBIlIjGO/64EhIuInIk2xvVRXEghkAsmAh4j8FaiSq/1T4O8i0sxxM32UiAQBGGP2A6uxPYEzc4aa5scYs85xjE+BBcaYkwAi0kJErne8rjTsvYNZV3/7WAWcdty07ut4fyJEpH2uddqJyBCx1VGfxA6nWQGsxCbMz4mIp4j0xJ6Yp13lfb8iERkpIsGOfZx0LC7Ma1FKKVWyPEXEJ9fDAzvq5l4RiXF8pr8KrDTGJIlIe0dPnif2/JAGZImIl4jcJSJVHbcZnObKn+tfAsOwwzJzhoXmFHxpKiKSax+FOT98AjzkiE1ExF9EbhKRwFzrPCoioSJSAzsy5qtcseT7erG3Q4SIyJNiC90EikjHqwVTjHO2UsWmiaCqSEYBE4wxe40xh3Me2KEpdzlOWs9gC7WsBo5jr7q5GWP2YoepPO1Yvp7fx+//C7gAHMEO3ZxylTgWAPOAbdghJGlcOszkHewVzoXYk9dn2EIpOSYBkRQwLDSPqUBvcp0csfcavI7tiTuMveL6IoDj5Lspvx057lsYhE2gdzu2/xQ7rDXHbOwJ+QRwNzDEcW/GBexUHf0d2/0Xm3xvcWyX7/teiNfXD9gkImeB94Dhxpi0QmynlFKqZM3FJik5j5eNMT8BL2FHpBzC9ojl3FdeBZt0ncCeC1OAnGrXdwNJInIaeIjfb4m4jOM+u3PYYZnzcjU1A34EzmILi/3XGPMLgIjMk0srdubeXxz2PsH/OGLbgS0Cl9uX2HP0LsfjH45tC3y9jp7KPtjz6GFsFfFeBb2uXAo8ZyvlbDkVB5VSZYSIdMcOEQ3Lc8+dS4nIy0BTY0yBJ2yllFKqPBORJGyl8B9dHYtSzqY9gkqVIY4hNE8An5alJFAppZRSS
lUsmggqVUaISDj2Prg6wLsuDkcppZRSSlVgOjRUKaWUUkoppSoZ7RFUSimllFJKqUpGE0GllFJKKaWUqmQ8XB1ASapZs6YJCwtzdRhKKaWcbM2aNceMMcGujqO80POjUkpVHoU9R1aoRDAsLIy4uDhXh6GUUsrJRGSPq2MoT/T8qJRSlUdhz5E6NFQppZRSSimlKhlNBJVSSimllFKqktFEUCmllFJKKaUqmQp1j6BSSilVnonIeGAgcNQYE5FPe0tgAtAW+LMx5q1SDlEppa5ZRkYG+/fvJy0tzdWhVCg+Pj6Ehobi6elZpO01EVRKKaXKjonAf4DPC2g/DjwO3FJaASmlVHHt37+fwMBAwsLCEBFXh1MhGGNISUlh//79NGrUqEj70KGhSimlVBlhjFmMTfYKaj9qjFkNZJReVEopVTxpaWkEBQVpEliCRISgoKBi9bJqIqiUUkpVQCIyVkTiRCQuOTnZ1eEopSo5TQJLXnHfU00ElVJKqQrIGDPOGBNrjIkNDr7qvMJKKVVhpaS
|
||
|
"text/plain": [
|
||
|
"<Figure size 1080x360 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Run this cell to plot accuracy vs epoch and loss vs epoch\n",
|
||
|
"\n",
|
||
|
"plt.figure(figsize=(15,5))\n",
|
||
|
"plt.subplot(121)\n",
|
||
|
"try:\n",
|
||
|
" plt.plot(history_augmented.history['accuracy'])\n",
|
||
|
" plt.plot(history_augmented.history['val_accuracy'])\n",
|
||
|
"except KeyError:\n",
|
||
|
" try:\n",
|
||
|
" plt.plot(history_augmented.history['acc'])\n",
|
||
|
" plt.plot(history_augmented.history['val_acc'])\n",
|
||
|
" except KeyError:\n",
|
||
|
" plt.plot(history_augmented.history['categorical_accuracy'])\n",
|
||
|
" plt.plot(history_augmented.history['val_categorical_accuracy'])\n",
|
||
|
"plt.title('Accuracy vs. epochs')\n",
|
||
|
"plt.ylabel('Accuracy')\n",
|
||
|
"plt.xlabel('Epoch')\n",
|
||
|
"plt.legend(['Training', 'Validation'], loc='lower right')\n",
|
||
|
"\n",
|
||
|
"plt.subplot(122)\n",
|
||
|
"plt.plot(history_augmented.history['loss'])\n",
|
||
|
"plt.plot(history_augmented.history['val_loss'])\n",
|
||
|
"plt.title('Loss vs. epochs')\n",
|
||
|
"plt.ylabel('Loss')\n",
|
||
|
"plt.xlabel('Epoch')\n",
|
||
|
"plt.legend(['Training', 'Validation'], loc='upper right')\n",
|
||
|
"plt.show() "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Do you see an improvement in the overfitting? This will of course vary based on your particular run and whether you have altered the hyperparameters."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Get predictions using the trained model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 48,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Found 300 images belonging to 3 classes.\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Get model predictions for the first 3 batches of test data\n",
|
||
|
"\n",
|
||
|
"num_batches = 3\n",
|
||
|
"seed = 25\n",
|
||
|
"test_generator = get_generator(image_gen_aug, test_dir, seed=seed)\n",
|
||
|
"predictions = lsun_new_model.predict_generator(test_generator, steps=num_batches)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 49,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Found 300 images belonging to 3 classes.\n",
|
||
|
"[26 14 27 55]\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAtUAAAK8CAYAAAAgSEP5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvXmcXXdd//98z75ln+yTrUm6b5Q2pQtYWpCCCqggqF8URSv+xOXrV1zRb1VUFBdUkP4QsVYF0QJaaxGxCKV0S7c0adqmadZJMkkm++wz936+f7zf73vvOZm5M9NJMqF5Px+P5M4953M+n88598ydc1/3dV5vSSkRBEEQBEEQBMHLp2a6JxAEQRAEQRAE3+7ERXUQBEEQBEEQTJG4qA6CIAiCIAiCKRIX1UEQBEEQBEEwReKiOgiCIAiCIAimSFxUB0EQBEEQBMEUiYvqIAiCl4GI9IjIeVPs404R+fAE264UkSQidfb8yyLyo1MZv6Lv14rICxXPd4jIG05F39bfsyJy06nq70wiIjeJSOcE294uIv9wuudUZfwkImvs5ztE5DdfZj9TPrenC1H+VkSOiMhj0z2f4NyibronEATBuYuI/BDwi8CFwAngaeD3UkoPTmDbBKxNKW09vbMcnZRS23SMWzH+myfSbiLHKaX0TeCCUzEvEbkT6Ewpfaii/0tORd8TGDsBB4ClKaURW1YH7AXmp5TkTMzjbCCl9P6JtBORrwP/kFL6dMW203puT5EbgTcCHSml3umeTHBuEUp1EATTgoj8IvAx4PeBhcBy4K+At03nvMbDleJXCq+0/QGOApUfON4CHJmmubxsRKR2uufwbcoKYMfLuaA+238Xzvb5BWfJRbV9BfqO09DvTSJy76nuN9f/9VXW95yusYPg2xkRmQX8DvAzKaUvppR6U0rDKaV/Tyl90NqsE5GHReSoiOwTkY+LSIOte8C62mBfVb/Lln+3iDxt2zwkIpdXjHmViDwlIidE5F9E5POV1gsR+UkR2Soih0XkHhFZUrEuicjPiMiLwIsVy/yr9mYR+RMR2Skix0TkQRFptnX/IiJdtvwBEZmQaisitSLyxyLSLSLbgO/Krf+6iPyE/bxGRL5hY3SLyOfHOk5uZxCRXxGRLuBvx7A4XCMim+1r9L8VkSbr870ikvkmwY+FiNwG/DDwyzbev9v6kp1ERBpF5GMistf+fUxEGm2dz+3/iMgBe91/bCLHq4K/B36k4vmPAHfl5rvEXuPD9pr/ZMW6ZvubdERENgPXjLLtF0TkoIhsF5Gfm8ikKvbt1+012iEiP1yx/k4R+aSI3CcivcDr7Vj9sYjsEpH9opaO5optPmjHaK+I/HhuvIy1SETeZr8bx0XkJRG5VUR+D3gt8HF7vT5ubSvP7Vkicpft704R+ZCI1Ni699q5/sd2vLaLyIS+QbHtl4nIF63vQxXj19g4O+08uEv0PaPSBvWjdly6ReQ3bN37gE8D19n+/LYtr/a+sMN+F54BekWkrtprLGrx+Web0wlRa9PV4+2TrftxEXnOjtVXRGTFBI7RaO8914vIetHf9/VScR0yzrl9u+j70T/Y3DeKyPki8mt2nHeLyHdO9PULRiGlNO3/gDuBd0xh+9oxlt8E3Hsa53078EtV1vecgjFG3bf4F/++nf8BtwIjQF2VNq8GXoPa1FYCzwG/ULE+AWsqnl+FfvV/LVAL/CiwA2gEGoCdwM8D9cD3AUPAh23bm4Fu66MR+EvggdxYXwXmAs358YFPAF8HltrY1wONtu7HgRnW78eApyv6vdPnMMr+vx94Hlhm4/6PjVln678O/IT9/DngN1ChpAm4scpxusmO/R/anJptWWdFmx3Apoqxv1VxrN4LPJiba+WxOGmfrL832M+/AzwCLADmAw8Bv5ub2+/Y6/QWoA+YM8HzKgGXAvuB2fZvvy1LFe2+gX4r0gRcCRwEbrF1HwG+afu9zI5Dp62rAZ4Afgs9p84DtgFvsvW3o1aK0ebm+/andty/A+gFLqg4bseAGypex48B99hcZgD/DvxBxe+Q71sr8NmxXgdgnfX9Rut7KXBh/jwa4/W8C/g3G38lsAV4X8W5MAz8JHre/zRqtRFb/6uM8TfY2m8A/szmXzpv0d+ZrXZ824AvAn9v61ba/P4aPXevAAaBi0Y7P6nyvlBxbj5tr3XzBF/jAfTcrAX+AHhkAvv0dtuni9D3tA8BD03wnC6999jjEeA91s8P2vN5Ezi3fe5vsm3vAraj7x319jpun86/Dd/u/6ZnUFUOnrGT7+/RX/6/QN9ct2EX2OQuioGPA++t+EX4LeBB4N3AGuC/rc8ngdW2/deBu9E/Tv+I/bKPMa9bgKeAjcBncr907fbz1dbnSqAL2GO/kK8FVgEPA+uB38UuqgEBPoq+OW8E3jXO8pvQP6CfBTZP90kS/+Lfqf6Hqpldk9zmF4AvVTzPXyx+Ers4q1j2Anrx8jr7XZWKdQ9Svuj4G+CPKta1oRcLKyvGujnXd7L3nRqgH7hiAvsw27abZc/vZOyL6q8B7694/p2MfVF9F/Ap1Eea72e0i+ohoCm3LH9RXTn2W4CX7Of3MrWL6peAt1SsexP6db3Po5+KD1voBdFrJniO+GvyaeCn0A8mf23LkrVZBhSAGRXb/QFwp/28Dbi1Yt1tlC+qrwV25cb8NeBv7efbGf+iurVi2T8Dv1lx3O6qWCfoRffqimXXYRc96N+oj1SsO3+s1wH4/4E/G2NepfNolONYi16wXlyx7qeAr1ecC1sr1rXYtosm8Fpdh17wnfTBGrgf+P8qnl+A/j76B+xExbkOPAa8e7TzkyrvCxXn5o9XrJvIa/zfFesuBvonsE9fxj6M2PMa9APjigmc0zdXPH8P8FiuzcO23+Od27cDX61Y9z1ADybeoR+cEjB7Ir9v8e/kf2fcnyP61edvADeklLpFZC76yX0xeoPBhegn87sn0N1ASulG6/dR9A3mS6JfU9agJ9irgEvQT8/fQlWAk26Csm3uRD/RbRGRu9BP3R8bbeCU0g4RuQO9cP5j6+Me4JMppbtE5Gcqmn8f+onxCqAdWC/6tez1YywHVRYuTSltn8BxCIJvNw4B7SJSl+yGsjwicj763nA1+se6DlWQxmIF8KMi8rMVyxqAJegfij3J/nIYuyt+XoJ+GAcgpdQjIodQRW/HKO0raUdVoZdG2Yda4PeAd6KqbLFim2NV9sXnVDnmziptfxn9IP+YiBwB/iSl9Jkq7Q+mlAbGGT8/9pKxGk6SJWT3Jd/3odw50Yd+yJkMd6EXEwL8yijjH04pncjN4eqK9WMd9xXAEhE5WrGsFlW2J8KRlPX65ve9ctz56Hn/hEjp/kqx8Xyelb8P1c6PZcB9E5xjJe2Uv+WpHGdpxfMu/yGl1GdzncjrtQzYOcbv/2jnSB1678VJ41L9HKn2vuDszrUf7zXOj90k6neutk8rgD8XkT+pWCbosaz22uXnlz82UH5Nxju3Qb/dcPqB7pRSoeI56LGs3
P9ggkyHp/pm4O6UUjdASumwLf/XlFIxpbSZ7C9ONdw3OAO92/tL1udASqnP2jyWUupMKRVRRXnlGH1dgCoAW+z536Hq1mS4Af0aFlSBd24EPpdSKqSU9qNfz1xTZbnPOy6og1cqD6NfQ769SptPot8wrU0pzQR+Hf0jNBa70eSQ2RX/WlJKnwP2AUul4uoE/QPo7EX/6AEgIq3APFTddiovyCvptn1ZPcq6H0JvvHwDMIvy+89EUij25ea4fKyGKaWulNJPppSWoEriX7kndqxNJjB+fuy99nMverEHgIgsmmTfmWOd6/tU8U1UqFnIySLKXmCu/d2onIO/1tWO+27070TlOTYjpfSWCc5rjp1blX1X7nvlsetGL3IuqRhrVionc0z4/LB5j3Z+5sfM040qxPnXa8/ozSfFbmC5jH7z3WjnyAjZC8LJjDPW+4KT/7D9cl/javu0G/ipXL/NKaWHJtBv5fzyxwbKr8l453ZwmpmOi2ph9F/iwVwb0F+iyjk25bbxT/zV/kBV9ltg7Bj
|
||
|
"text/plain": [
|
||
|
"<Figure size 1152x864 with 8 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Run this cell to view randomly selected images and model predictions\n",
|
||
|
"\n",
|
||
|
"# Get images and ground truth labels\n",
|
||
|
"test_generator = get_generator(image_gen_aug, test_dir, seed=seed)\n",
|
||
|
"batches = []\n",
|
||
|
"for i in range(num_batches):\n",
|
||
|
" batches.append(next(test_generator))\n",
|
||
|
" \n",
|
||
|
"batch_images = np.vstack([b[0] for b in batches])\n",
|
||
|
"batch_labels = np.concatenate([b[1].astype(np.int32) for b in batches])\n",
|
||
|
"\n",
|
||
|
"# Randomly select images from the batch\n",
|
||
|
"inx = np.random.choice(predictions.shape[0], 4, replace=False)\n",
|
||
|
"print(inx)\n",
|
||
|
"\n",
|
||
|
"fig, axes = plt.subplots(4, 2, figsize=(16, 12))\n",
|
||
|
"fig.subplots_adjust(hspace=0.4, wspace=-0.2)\n",
|
||
|
"\n",
|
||
|
"for n, i in enumerate(inx):\n",
|
||
|
" axes[n, 0].imshow(batch_images[i])\n",
|
||
|
" axes[n, 0].get_xaxis().set_visible(False)\n",
|
||
|
" axes[n, 0].get_yaxis().set_visible(False)\n",
|
||
|
" axes[n, 0].text(30., -3.5, lsun_classes[np.where(batch_labels[i] == 1.)[0][0]], \n",
|
||
|
" horizontalalignment='center')\n",
|
||
|
" axes[n, 1].bar(np.arange(len(predictions[i])), predictions[i])\n",
|
||
|
" axes[n, 1].set_xticks(np.arange(len(predictions[i])))\n",
|
||
|
" axes[n, 1].set_xticklabels(lsun_classes)\n",
|
||
|
" axes[n, 1].set_title(f\"Categorical distribution. Model prediction: {lsun_classes[np.argmax(predictions[i])]}\")\n",
|
||
|
" \n",
|
||
|
"plt.show()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Congratulations! This completes the first part of the programming assignment using the tf.keras image data processing tools."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Part 2: tf.data\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"#### The CIFAR-100 Dataset"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"In the second part of this assignment, you will use the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html). This image dataset has 100 classes with 500 training images and 100 test images per class. \n",
|
||
|
"\n",
|
||
|
"* A. Krizhevsky. \"Learning Multiple Layers of Features from Tiny Images\". April 2009 \n",
|
||
|
"\n",
|
||
|
"Your goal is to use the tf.data module preprocessing tools to construct a data ingestion pipeline including filtering and function mapping over the dataset to train a neural network to classify the images."
|
||
|
]
|
||
|
},
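{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick orientation, the pipeline you will build below chains `Dataset` transformations in the usual tf.data style. Here is a minimal sketch on a toy dataset (not graded, using the `tf` module imported at the top of the notebook and assuming eager execution, the TensorFlow 2 default):\n",
"\n",
"```python\n",
"ds = tf.data.Dataset.range(10)            # elements 0, 1, ..., 9\n",
"ds = ds.filter(lambda x: x % 2 == 0)      # keep only even elements\n",
"ds = ds.map(lambda x: x * 10)             # transform each remaining element\n",
"ds = ds.batch(2)                          # group elements into batches\n",
"\n",
"for batch in ds:\n",
"    print(batch.numpy())                  # [0 20], [40 60], [80]\n",
"```\n",
"\n",
"The CIFAR-100 pipeline below follows the same pattern: create the dataset, filter it, map functions over it, and finally batch it."
]
},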
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Load the dataset"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz\n",
|
||
|
"169009152/169001437 [==============================] - 3s 0us/step\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Load the data, along with the labels\n",
|
||
|
"\n",
|
||
|
"(train_data, train_labels), (test_data, test_labels) = cifar100.load_data(label_mode='fine')\n",
|
||
|
"with open('data/cifar100/cifar100_labels.json', 'r') as j:\n",
|
||
|
" cifar_labels = json.load(j)"
|
||
|
]
|
||
|
},
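{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building the pipeline it can help to confirm what `load_data` returned. A quick check (not graded); the expected values in the comments assume the standard Keras CIFAR-100 arrays and that `cifar100_labels.json` holds the list of 100 fine label names:\n",
"\n",
"```python\n",
"print(train_data.shape, train_data.dtype)      # (50000, 32, 32, 3) uint8\n",
"print(train_labels.shape, train_labels.dtype)  # (50000, 1) int64\n",
"print(test_data.shape, test_labels.shape)      # (10000, 32, 32, 3) (10000, 1)\n",
"print(len(cifar_labels))                       # 100\n",
"```"
]
},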
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Display sample images and labels from the training set"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {
|
||
|
"scrolled": false
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2cAAAHcCAYAAABS2dL6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvXe0ZdldHvjbN6cX64WqVzl1pc5BrW4UWllCCDRCApYJAzIsM3hh42XC2IZBY2Nje81gYxDGM8PAQgQhjwJYIim1ulFLndSpQldOr+rlcN/N8cwfe5/zfee9e6uqq25XvWp+31q1ar9zzz1nx9/ed/++/f2M53miUCgUCoVCoVAoFIpbi8itzoBCoVAoFAqFQqFQKPTHmUKhUCgUCoVCoVCsC+iPM4VCoVAoFAqFQqFYB9AfZwqFQqFQKBQKhUKxDqA/zhQKhUKhUCgUCoViHUB/nCkUCoVCoVAoFArFOoD+OOsRjDGeMWbPrc6H4vpgjHncGPOTtzofCsV6xuth54wxnzDG/FEvn/ka33/OGPPuW/X+9ZIHxbXBGPMHxphfu4b7rnmsXOleY8wPG2P+9rXmU/H6QG3g+oQxZodrm9itzksvoD/OFArFdeP1NOrGmMeMMZOvx7PfqHijTVAKxd93eJ73x57nvfdW5+N2gdpAxWvFetyc1x9nCoXidYFOjrcf3sht9kYum0Kh6A3eyHbidijb7ZDHm4E39I8zY8xWY8znjDFzxpgFY8xvG2MixphfNsacN8bMGmP+0Bgz4O5fs1PPngFjTNQY8y+NMaeNMQVjzPPGmK10+7uNMSeNMUvGmE8aY8xNLO5tjS5ttdsY8zX397wx5o+NMYP0nXPGmF8wxrxsjCkZY37PGDNujPkr1z5fMcYM0f1vNsY8ZYxZNsa8ZIx57JYU9g0CY8ynRGSbiPwPY0zRGPOLbsfyHxpjLojI19x932uMOeLq/XFjzAF6Rogi4lOGjDFZEfkrEZlwzy4aYyZuchE74jrsir+T+xPGmIvOPvy0MeYh13eXjTG/Tc//cWPMN40xv2WMyRtjXjXGvIs+D3krV1FinnD/L7s6e8Td83FjzDH37r8xxmyn73vGmH9sjDkpIidfQz28xZXnHe7v/caYLxtjFo0xx40xP+CuP2SMmeFJ1xjz/caYF+lxKWPMn7lx+x1jzD107wHXb5ZdP/pe+uyDxpgXjDErLi+foM/8el/dH3/UtdOCMeZfXWt5bwLudf0h7+oiJSJijPkpY8wpV69/4Y8D08FDYGgH2BizxxjzDfe8eWPMn9F9HdtKsRbGmPtcnyy4OkzRZx3bpsMz/sAY87uuzguuXbavuq3j+sHZg7+jZ3nOftyytYZRG+h/T21gD+Da85eMMS+LSMkYc9cVyps2xvyfLv95Y8zfGWPSHZ75/e65d7q/O67/jDH/VkTeKiK/7frLb69+1i2B53lvyH8iEhWRl0TkP4lIVqxBfYuIfFxETonILhHJicjnRORT7juPicjkquecE5F3u/QviMgrIrJPRIyI3CMiG9xnnoh8UUQGxS5Y50Tk/be6Hm6Hf1doqz0i8h4RSYrIqFij+59Xtc23RWRcRDaLyKyIfEdE7nPf+ZqI/Kq7d7OILIjId4vdlHiP+3vUff64iPzkra6L2+3fqvGxw42DP3TtmBaRO0Sk5Oo7LiK/6MZfwn3HE5E99Lw/EJFfc+k14/FW/7tOu+LXy++6+98rIlUR+YKIjFHffbu7/8dFpCki/8zV2Q+KSF5EhlfXufv7EyLyR6veFaPPP+zydkBEYiLyyyLyFH3uiciXRWRYRNJXKb/nxuX7ROSiiLzJXc+6v3/CveN+EZkXkUPu86Mi8gF6zudF5J9T/hsi8lFX3p8XkbMuHXd5/5cikhCRd4pIQUT2UR+5S+yYvltEZkTkw1fojwdFpCgibxNrI37D1fW7r1TumzSOnhGRCdcOx0Tkp1155119JkXkt0TkiSu09ePi7JiI/KmI/CtXNykRecu1tJX+C7VLQkTO01j8qOurv3altuGx4tJ/4Pqt3+9+U0T+btW9HdcPYu3BNd17k+pEbaDawF73qXMi8qKIbBWRvquU95Ni7dxmsX3xUVeOoN1dG5wSjL/bbv13yzPwOjb2I2KNVmzV9a+KyM/Q3/vcoIjJ1X+cHReR7+vyPk/c5Of+/oyI/K+3uh5uh3/d2qrDfR8WkRdWtc0P09+fFZH/Sn//rIh8waV/SdxEQZ//jYj8zy697gbn7fBPOv8420Wf/4qIfIb+jojIJRF5zP19u/04ux674tfLZvp8QUR+kP7+rIj8nEv/uIhcFhFDnz8jIj+6us7d35+QKy9M/kpE/uGqNiiLyHZqg3deY/k9EfkXYhesd9H1HxSRJ1fd+98EmyO/JCJ/7NLD7v2bKP/fXpW/KbG7mW8VkWkRidDnfyoin+iSv/8sIv/pCv3xfxORT9PfWRGpy/r4cfYj9Pd/FLuQ/T0R+Y90Pef61Y4ubf244MfZH4rI/yUiW1a964ptpf9C9fK2DmPxKbE/zrq2jft79Y+zT6+6tyUiW+nejusH6fzj7JatNURtoNrA3vepcyLycZfuWl5XLxURuafDM/yy/rzYH8Jb6LPbbv33RqY1bhWR857nNVddnxA7qHycF2s8xq/xmaev8Pk0pctiDbDi6ujYVsaYMWPMp40xl4wxKyLyRyIysuq7M5SudPjbb4PtIvIx59JeNsYsi93t29TLgihExO4e+giNN8/z2u7zzTc7Uz3CjdiVa+2rIiKXPDdr0POul9a5XUR+k/r9oljPP7fBxY7f7IyfE/uD+5VV73h41fj6YRHZ6D7/IxH5kDEmJyI/IHYRM9Xp/a6PTIot74SIXHTXfJz3826MedgY83VHr8qL9TatthGr+yO/qyR2kbge0Gn+WD1+imLzey3j5xfFtvMzjhr0cXf9am2lACak81j0P3stbXNx1b2LEh7Tr2X9cCvXGmoD1Qa+HvDzdKXyjoj1vF5pHf4LIvJJz/P4iNJtt/57I/84uygi28zaw4WXxTaUj21iXbozYulXGf8DY0xULJ2On7n7dcnt3290a6tfF7sTcrfnef0i8iNiDer1vuNTnucN0r+s53n//vqzrRDbPle6FhpvxhgjdnK/5C6VhcachBeInZ59q3E9duV6sNnVFT/vskuH7JRcvc4uisg/WtX3057nPXWV73XDx0Tkw8aYn1v1jm+sekfO87z/RUTE87xLIvItEfmfRORHReRTq54ZnN01xkREZIvY8l4Wka3umo9tgv7zJyLyF2I9EANivU2rbQSXbWrVuzIisuGaS37zsXr8ZMXm95LYfiDSpS94njfted5PeZ43ISL/SER+x9jznVdsK0UIU9J5LIpcuW06gftdTqz35HKXe9cz1AaqDXw94OfxSuWdF0uHvdI6/L0i8svGmO+na1db/627tcYb+cfZM2I74b83xmSNMSljzHeJdY/+M2PMTmcg/52I/JnbBToh9lDmB40xcbG85CQ98
/8RkX9jjNlrLO42xqyHTn27o1tb9YnlRi8bYzaL3RG5Xvi7Vu8zVtglZawAzJYbz/7fa8yIPWPQDZ8RkQ8aY97lxtQ/F5GaWGqQiOWZ/wPXJu8XkbevevYG4w6VrxNcj125HoyJyD8xxsSNMR8Te1biL91nL4rID7nPHhR7TsHHnIi0Jdwmvysi/8IYc0hExBgz4J55vbgsIu9y+fsZd+2LInKHsQfN4+7fQ4bEX8TS7H5R7PmIz6965gPGmI+4Bd/Pie0j3xaRp8UuxH7RPfMxEfmQiHzafa9PRBY9z6saY94kIv/gKnn//0Tke4w9yJ8QkX8t63se/BMR+QljzL3GmKTYfvW053nnPM+bE7tg+RE3fj4utGgxxnyM7NuS2AVIS66trRQW3xL7A+OfGGNixpiPiMib3Gdd26bLs76b+t2/cfe+Fm/NeoHaQLWBrye6ltd50/5fEfkNY8yEs3uPuPHn44iIvF9EPmkgJHK19d/V1jE3HeupQXoKz/NaYht0j4hcEOsi/kGxDfspseI
|
||
|
"text/plain": [
|
||
|
"<Figure size 1080x576 with 32 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Display a few images and labels\n",
|
||
|
"\n",
|
||
|
"plt.figure(figsize=(15,8))\n",
|
||
|
"inx = np.random.choice(train_data.shape[0], 32, replace=False)\n",
|
||
|
"for n, i in enumerate(inx):\n",
|
||
|
" ax = plt.subplot(4, 8, n+1)\n",
|
||
|
" plt.imshow(train_data[i])\n",
|
||
|
" plt.title(cifar_labels[int(train_labels[i])])\n",
|
||
|
" plt.axis('off')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Create Dataset objects for the train and test images\n",
|
||
|
"\n",
|
||
|
"You should now write a function to create a `tf.data.Dataset` object for each of the training and test images and labels. This function should take a numpy array of images in the first argument and a numpy array of labels in the second argument, and create a `Dataset` object. \n",
|
||
|
"\n",
|
||
|
"Your function should then return the `Dataset` object. Do not batch or shuffle the `Dataset` (this will be done later)."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"#### GRADED CELL ####\n",
|
||
|
"\n",
|
||
|
"# Complete the following function. \n",
|
||
|
"# Make sure to not change the function name or arguments.\n",
|
||
|
"\n",
|
||
|
"def create_dataset(data, labels):\n",
|
||
|
" \"\"\"\n",
|
||
|
" This function takes a numpy array batch of images in the first argument, and\n",
|
||
|
" a corresponding array containing the labels in the second argument.\n",
|
||
|
" The function should then create a tf.data.Dataset object with these inputs\n",
|
||
|
" and outputs, and return it.\n",
|
||
|
" \"\"\"\n",
|
||
|
" \n",
|
||
|
" dataset = tf.data.Dataset.from_tensor_slices((data,labels))\n",
|
||
|
" return dataset\n",
|
||
|
" "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Run the below cell to convert the training and test data and labels into datasets\n",
|
||
|
"\n",
|
||
|
"train_dataset = create_dataset(train_data, train_labels)\n",
|
||
|
"test_dataset = create_dataset(test_data, test_labels)"
|
||
|
]
|
||
|
},
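{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to sanity-check the unbatched datasets, complementing the `element_spec` check in the next cell, is to pull one concrete element with `take`. A minimal sketch (not graded):\n",
"\n",
"```python\n",
"for image, label in train_dataset.take(1):\n",
"    print(image.shape, image.dtype)   # (32, 32, 3) uint8\n",
"    print(label.shape, label.dtype)   # (1,) int64\n",
"    print(int(label.numpy()[0]))      # the integer class index\n",
"```"
]
},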
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"(TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), TensorSpec(shape=(1,), dtype=tf.int64, name=None))\n",
|
||
|
"(TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), TensorSpec(shape=(1,), dtype=tf.int64, name=None))\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Check the element_spec of your datasets\n",
|
||
|
"\n",
|
||
|
"print(train_dataset.element_spec)\n",
|
||
|
"print(test_dataset.element_spec)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Filter the Dataset\n",
|
||
|
"\n",
|
||
|
"Write a function to filter the train and test datasets so that they only generate images that belong to a specified set of classes. \n",
|
||
|
"\n",
|
||
|
"The function should take a `Dataset` object in the first argument, and a list of integer class indices in the second argument. Inside your function you should define an auxiliary function that you will use with the `filter` method of the `Dataset` object. This auxiliary function should take image and label arguments (as in the `element_spec`) for a single element in the batch, and return a boolean indicating if the label is one of the allowed classes. \n",
|
||
|
"\n",
|
||
|
"Your function should then return the filtered dataset.\n",
|
||
|
"\n",
|
||
|
"**Hint:** you may need to use the [`tf.equal`](https://www.tensorflow.org/api_docs/python/tf/math/equal), [`tf.cast`](https://www.tensorflow.org/api_docs/python/tf/dtypes/cast) and [`tf.math.reduce_any`](https://www.tensorflow.org/api_docs/python/tf/math/reduce_any) functions in your auxiliary function. "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"#### GRADED CELL ####\n",
|
||
|
"\n",
|
||
|
"# Complete the following function. \n",
|
||
|
"# Make sure to not change the function name or arguments.\n",
|
||
|
"\n",
|
||
|
"def filter_classes(dataset, classes):\n",
|
||
|
" \"\"\"\n",
|
||
|
" This function should filter the dataset by only retaining dataset elements whose\n",
|
||
|
" label belongs to one of the integers in the classes list.\n",
|
||
|
" The function should then return the filtered Dataset object.\n",
|
||
|
" \"\"\"\n",
|
||
|
" def filterer(image,label):\n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
" flag = tf.math.reduce_any(tf.equal(label,classes))\n",
|
||
|
" return flag\n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
" dataset = dataset.filter(filterer)\n",
|
||
|
" return dataset\n",
|
||
|
" "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Run the below cell to filter the datasets using your function\n",
|
||
|
"\n",
|
||
|
"cifar_classes = [0, 29, 99] # Your datasets should contain only classes in this list\n",
|
||
|
"\n",
|
||
|
"train_dataset = filter_classes(train_dataset, cifar_classes)\n",
|
||
|
"test_dataset = filter_classes(test_dataset, cifar_classes)"
|
||
|
]
|
||
|
},
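{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to confirm that the filter kept exactly the three requested classes, you can count the labels by iterating over the (still unbatched and still integer-labelled) datasets. A sketch (not graded, written without any extra imports); note that this full pass over the data takes a little while:\n",
"\n",
"```python\n",
"train_counts = {}\n",
"for _, label in train_dataset:\n",
"    c = int(label.numpy()[0])\n",
"    train_counts[c] = train_counts.get(c, 0) + 1\n",
"print(train_counts)   # expected: 500 examples each for classes 0, 29 and 99\n",
"```"
]
},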
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Apply map functions to the Dataset\n",
|
||
|
"\n",
|
||
|
"You should now write two functions that use the `map` method to process the images and labels in the filtered dataset. \n",
|
||
|
"\n",
|
||
|
"The first function should one-hot encode the remaining labels so that we can train the network using a categorical cross entropy loss. \n",
|
||
|
"\n",
|
||
|
"The function should take a `Dataset` object as an argument. Inside your function you should define an auxiliary function that you will use with the `map` method of the `Dataset` object. This auxiliary function should take image and label arguments (as in the `element_spec`) for a single element in the batch, and return a tuple of two elements, with the unmodified image in the first element, and a one-hot vector in the second element. The labels should be encoded according to the following:\n",
|
||
|
"\n",
|
||
|
"* Class 0 maps to `[1., 0., 0.]`\n",
|
||
|
"* Class 29 maps to `[0., 1., 0.]`\n",
|
||
|
"* Class 99 maps to `[0., 0., 1.]`\n",
|
||
|
"\n",
|
||
|
"Your function should then return the mapped dataset."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"#### GRADED CELL ####\n",
|
||
|
"\n",
|
||
|
"# Complete the following function. \n",
|
||
|
"# Make sure to not change the function name or arguments.\n",
|
||
|
"\n",
|
||
|
"def map_labels(dataset):\n",
|
||
|
" \"\"\"\n",
|
||
|
" This function should map over the dataset to convert the label to a \n",
|
||
|
" one-hot vector. The encoding should be done according to the above specification.\n",
|
||
|
" The function should then return the mapped Dataset object.\n",
|
||
|
" \"\"\"\n",
|
||
|
" def One_hot(image,label):\n",
|
||
|
"\n",
|
||
|
" if label== 0:\n",
|
||
|
"\n",
|
||
|
" label = [1., 0., 0.]\n",
|
||
|
"\n",
|
||
|
" elif label== 29:\n",
|
||
|
"\n",
|
||
|
" label =[0., 1., 0.]\n",
|
||
|
"\n",
|
||
|
" else:\n",
|
||
|
"\n",
|
||
|
" label = [0., 0., 1.]\n",
|
||
|
" \n",
|
||
|
"\n",
|
||
|
" \n",
|
||
|
" return (image,label)\n",
|
||
|
" \n",
|
||
|
" dataset = dataset.map(One_hot)\n",
|
||
|
" return dataset"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Run the below cell to one-hot encode the training and test labels.\n",
|
||
|
"\n",
|
||
|
"train_dataset = map_labels(train_dataset)\n",
|
||
|
"test_dataset = map_labels(test_dataset)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None),\n",
|
||
|
" TensorSpec(shape=(3,), dtype=tf.float32, name=None))"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"test_dataset.element_spec"
|
||
|
]
|
||
|
},
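{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Python `if`/`elif` in `map_labels` works here because the mapping function is traced when the dataset is built, but the same encoding can also be written with plain tensor ops and no branching. One equivalent alternative, sketched below (not the graded solution; it assumes the labels are still the original integer class indices of shape `(1,)`):\n",
"\n",
"```python\n",
"def one_hot_label(image, label):\n",
"    # Compare the scalar label against the allowed classes; the matching\n",
"    # position becomes 1.0 and the others 0.0, e.g. 29 -> [0., 1., 0.]\n",
"    allowed = tf.constant([0, 29, 99], dtype=tf.int64)\n",
"    onehot = tf.cast(tf.equal(tf.squeeze(label), allowed), tf.float32)\n",
"    return image, onehot\n",
"\n",
"# dataset = dataset.map(one_hot_label)\n",
"```"
]
},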
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"The second function should process the images according to the following specification:\n",
|
||
|
"\n",
|
||
|
"* Rescale the image pixel values by a factor of 1/255.\n",
|
||
|
"* Convert the colour images (3 channels) to black and white images (single channel) by computing the average pixel value across all channels. \n",
|
||
|
"\n",
|
||
|
"The function should take a `Dataset` object as an argument. Inside your function you should again define an auxiliary function that you will use with the `map` method of the `Dataset` object. This auxiliary function should take image and label arguments (as in the `element_spec`) for a single element in the batch, and return a tuple of two elements, with the processed image in the first element, and the unmodified label in the second argument.\n",
|
||
|
"\n",
|
||
|
"Your function should then return the mapped dataset.\n",
|
||
|
"\n",
|
||
|
"**Hint:** you may find it useful to use [`tf.reduce_mean`](https://www.tensorflow.org/api_docs/python/tf/math/reduce_mean?version=stable) since the black and white image is the colour-average of the colour images. You can also use the `keepdims` keyword in `tf.reduce_mean` to retain the single colour channel."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"#### GRADED CELL ####\n",
|
||
|
"\n",
|
||
|
"# Complete the following function. \n",
|
||
|
"# Make sure to not change the function name or arguments.\n",
|
||
|
"\n",
|
||
|
"def map_images(dataset):\n",
|
||
|
" \"\"\"\n",
|
||
|
" This function should map over the dataset to process the image according to the \n",
|
||
|
" above specification. The function should then return the mapped Dataset object.\n",
|
||
|
" \"\"\"\n",
|
||
|
"\n",
|
||
|
" \n",
|
||
|
" def rescale(image, label):\n",
|
||
|
"\n",
|
||
|
" image = tf.cast(image, tf.float32) / 255.0\n",
|
||
|
"\n",
|
||
|
" image = tf.reduce_mean(image,axis = 2,keepdims=True)\n",
|
||
|
"\n",
|
||
|
" return (image, label)\n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
" \n",
|
||
|
" dataset = dataset.map(rescale)\n",
|
||
|
" return dataset\n",
|
||
|
" "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Run the below cell to apply your mapping function to the datasets\n",
|
||
|
"\n",
|
||
|
"train_dataset_bw = map_images(train_dataset)\n",
|
||
|
"test_dataset_bw = map_images(test_dataset)"
|
||
|
]
|
||
|
},
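{
"cell_type": "markdown",
"metadata": {},
"source": [
"After the image mapping, the elements should be single-channel `float32` images with values in [0, 1]. A quick check (not graded):\n",
"\n",
"```python\n",
"print(train_dataset_bw.element_spec)\n",
"\n",
"for image, _ in train_dataset_bw.take(1):\n",
"    print(image.shape)                                                # (32, 32, 1)\n",
"    print(float(tf.reduce_min(image)), float(tf.reduce_max(image)))  # within [0, 1]\n",
"```"
]
},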
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Display a batch of processed images"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4EAAAFTCAYAAACHwwnBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvXeYHNd15n1OV+fuyQkZgwyCQQxikCiSEkVakQqUKJPSyqLW9q4trWzvynkdtLa16/WzsvzJtmxx7ZVoUZRF0gpUsEiTZhaYE0gihwEGwGBy6hzu90c3puo9A8wADEj9/p4HD+r0rVy3bt07dd561TknhBBCCCGEEEIag9DJ3gFCCCGEEEIIIScODgIJIYQQQgghpIHgIJAQQgghhBBCGggOAgkhhBBCCCGkgeAgkBBCCCGEEEIaCA4CCSGEEEIIIaSB4CCQEEJOYVT1G6raf5Syt6uqU9VrTvR+NSKq2ls/3zfPM98XVJX+S4QQQk5ZOAgkhBBCXl/+QUTecrJ3ghBCCDka4ZO9A4QQQk4vVDXmnCuc7P04VXHO9YvIEd/eEkIIIacCfBNICCFnEKq6R1VvU9UbVXWzqmZU9WlVfZuZ7xuq2q+qb1XVp1Q1X1/2c2a+m+spkFeq6p2qOi4iTwTK/4OqvlBfflhVv6mqC4+wX7+sqs+qak5Vx1T1IVV9a6A8qar/W1V3q2qx/v9/V9VQYJ60qv61qu5V1YKqHlLV+1R1fWCeX68f9+HtPK2qHzb7cr2qPq6qWVUdrx/XMjNPUlW/qqojqjqtqneLyJJjvAaz0kHr5/DPVPXzqtpXvy4/VtXu+r87VHVCVfep6u+YZbtU9Wuquq2+z/tU9XZVXXyEbd+kqlvq12OTqn5AVR9U1QfNfJ2q+nequr9+Lreo6n8y8yxQ1VtV9UB9noOq+iNV7T6W80AIIeTUhW8CCSHkzOMKEVknIn8oInkR+VMR+ZGq9jrnxgPzNYvId0Tkf4vIDhG5UUS+oqpTzrlvmHV+S0S+LSIflfqzoz5o+Fp9Hb8nIotE5H+KyKWqeqFzbro+3/8Rkc+LyD+KyB+LSFVELhORZSLyM1UNi8g9IrKhvq+b6uV/KCLt9WVFRL4sIh8Qkd8Xke0i0iEil4tIa307nxCRL4nIn4jIIyKSEJHz6uuQ+jy/IiJ/JyJfr8/XJCJfEJGHVPU859xUfdavicjPi8j/EJGnRORaEbl9jnN+LHxSRF4Skc+ISI+I/JWI/FN9H/5VRG4RkRtE5M9VdZNz7if15dqldh1/T0SGpHaePy8ij6nqeudcvn5s10rtOt1dL++sbyMuItsC56BZRB6rn58viMhuEXmXiPxd/S3vX9dn/aaILBeR3xKRffV9fqeIJF/jeSCEEHKycc7xH//xH//x3yn6T0S+ISL9Ryl7u4g4Ebkm8NseERkTkbbAb2+uz/dxs14nIjeadf6biPSJiNbjm+vzfdnM54nIIRF5wPz+tvr8v1aPV4tIRUT+co5j/GR9mSvN7/9dRIoi0l2PX5pnPX8jIs/OUZ4WkQkR+X/m9976dn6jHq+r7/Pvmvn+rr6fN89zzb5Qe7zCb05qA7Fw4Le/rP/+B4HfwiIyKCJfn2P9nogsrS/74cDvP6ufIw38dmF9vgcDvx3+48Aas97/KyLDh/dRRKYPX0f+4z/+4z/+O7P+MR2UEELOPDY658YC8ab6/8vMfBUR+Rfz2z/X57Opht8z8ToR6Zbam6cZnHOPSm0QeVX9p2ukJj24ZY79fXd9mZ+pavjwPxG5V0QiUnsrKFJ7I3ezqv6+qr5ZVT2znqdE5Px6yug1qmrfWL1Fam8/v2W20y8iW0Tkyvp8l9b3+Q6z/D/PcQzHwr8558qBeEv9/3sO/1Av3yG1Qd4Mqvqr9bTbaREpi8jeetG6erkntcH+vzjnXGB9z0rtTV+Qd0stpXe3OQ/3SO3t6ob6fE+JyG/VU2zPVVV9tQdOCCHk1IKDQEIIObUpS+3Nz5HwAvMEGQ0Gzv+IS9zMN+acK5nfDtX/t4PAgyZuP8rvIiIDgfKO+v9zfSilW2pphyXz70mzjs9JLU3zP0ptgDKoql8ODPb+SUR+VWqDuHtEZFRVv6uqvYHtiIjcd4RtnRvYzmFN4+FzIUeJj5cxExfn+H3mWmlNp/lVqe339SJyifgD48PzdUptwDx4hO3a/e6W2oDXnoM76+WHz8PPSy219LdF5EUR2a+qfxTUaRJCCDk9oSaQEEJObQZFpFNVo865oilbVP//1Q5O2lQ1YgaCPfX/95t5re/d4YHmgiOsd4GIPF2fHq7/v1hEth5lP0ak9rbqY0cp3yMi4moaw98Tkd9T1eVS0yf+udQGTb9TfwP2NRH5mqq2icjPSU0j+B2pDQxH6uu7WURePsJ2DusBDw9se0RkV6C8R04ON4rI/c65w9pIUdUVZp5hqQ3kjvTRlh7x3xyK1M7DoIj8+lG2t1VExDk3KCKfFZHPquo6EfmU1DSSQ1JLjSWEEHKawr/mEULIqc0DUvuD3QeOUPYRqQ1Yjja4mg+vvo4gN0ptwGAHgZatUht83hj8sf7Fz+Ui8lD9p/uk9iEY+PKk4adSS3+cds49fYR/w3YB51yfc+5LUkt1PecI5WPOue9ILaXzcPnPpDbQW32U7Rw+j0/U99kOSm+Uk0NSagO8IJ8OBs65itQG3h8Jpm2q6kUiYgeMPxWR9SKy9yjnYcrML865rc6535faW8tZ55sQQsjpBd8EEkLIqc19UvtYyzfqVghPSO1rkjeKyAdF5NPOueqrXPeUiPyFqnZK7WubN0lNw3dzUFd2JJxzFVX9I6m9dbtNRG6T2tu+L9bX9fX6fDtV9csi8t9UtUlq6YUVqaU0bqkP1L4ltUHN/ar6JRF5QUSiIrJKaoPfDznnsqq6sb78Jql9tOQqEXmTiNwqIqKqt9SPaaPU3nStldpHZ+6t78ukqv6WiPytqnZJ7YucE/X9vkpqH0+53Tm3VVVvF5E/qac+Hv466HtfzUl+HfipiPyOqv6+1FJkr5baW1DLH0vtWL9XPxedUvtIzYDUBrWH+bLUUj0fqV+brSKSktrA8Arn3AdVtUVqde9bUtMulqRW39rq2yCEEHIaw0EgIYScwjjnnKp+QET+QER+QWpfdiyKyPNSGxz94DWsflJqg8n/T2qauEMi8uvOuVuPcd9uUdWs1CwEfiC1gdlPROS366mbh+f7TVXdITVrhE+JSEZqGrPDg7OSqr5LRH5Xam8MV9Tn2SkiPxZfO/ew1N7O/a7Unl+7ROS/Oue+Ui9/TGqDyU+KSIuIHJDa4PSPA/vyNVXdV9/nj0tNR7e/vu7nA4f3n+vH85tSG5D+e33+R4/l3LzO/InUbDD+q9Q0gA9JzdIhmKoqzrl/q9tk/LHUPuSzQ2pWEX8ktcHu4fkm6m9s/0hEfkdqg+BxqQ0GD38oKC8iz4rIL0vtzW61Xv6J11jnCCGEnALoPH/sJYQQcgaiqt+QmrXEMRmgk9MTVV0itcHgF
51zf3qy94cQQsipAd8EEkIIIWcAqpqQmvfgfVL7UMxKqX3ZMysi/3ASd40QQsgpBgeBhBBCyJlBRWpfZv0bqdk8ZETkERG5wTl3JCsPQgghDQrTQQkhhBBCCCGkgaBFBCGEEEIIIYQ0EBwEEkIIIYQQQkgDwUEgIYQQQgghhDQQHAQSQgghhBBCSAPBQaBBVb+hqn+mqleo6taTvT+EnK6oaq+qOlXlV4gJIYQQckywL35iYOfsKDjnHhGRdSd7PwghhBBCCGk02Bd/Y+GbwDMEvm0hhJATg9bg85OckbB+E/LqON364g1/k6vqBar6rKpOqep3RCRe//3tqtofmG+Pqv6mqr6oqhOq+h1VjQfKf1lVd6jqqKreraqL6r+rqn5ZVQfry72oqufUy96nqs+p6qSq7lPVLwTWB9sP7MM19ekvqOpdqnqbqk6KyM1v3FkiZzqq+ruqurN+H7yiqh+u/36zqj6mqn9dr79bVPWdgeUeVNX/papP1st/oKrtR9lGi6r+o6oeVNX99VQP70Q
|
||
|
"text/plain": [
|
||
|
"<Figure size 1152x360 with 10 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4EAAAFTCAYAAACHwwnBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvXnYXlV5tn/dIjJPYchMAklIAmGSIUwhCChopS0KtkUUrFbF+pMe39fB1mrVarUjitraIkILVSYnkEFmfgwJMgRCyEQmMg8kTAGZZH1/PE9e1jrfvHvnISHDu6/zODjY17v2s8e11l4r+772HSklGWOMMcYYY4xpBm/b1AdgjDHGGGOMMWbj4UmgMcYYY4wxxjQITwKNMcYYY4wxpkF4EmiMMcYYY4wxDcKTQGOMMcYYY4xpEJ4EGmOMMcYYY0yD8CTQGGN6CRFxbkSk7L/nI+LRiPhsRLx9Ux/f5khEXBoR89ZhvRQRX37rj8gYY4x56/GgwBhjeh9nSlooaef28nck7SXpS5vyoLZwjlbrmhpjjDFbPJ4EGmNM7+ORlNKs9vLNETFc0p+ph0lgRISkrVNKr2ysA9zSSClN3NTHYIwxxmwoHA5qjDG9nwck7RQRe0lSRMyLiMsj4o8jYrqkVyT9Trusf0T8T0Q8FREvR8TkiDibG4yIfSLisohY2l5vTkR8G+uMj4jb2mGpL0TEryJiDNY5JSLui4hnI2J1RMyIiC9l5ftFxM8iYnlEvBQR8yPi6jy8NSL2iIj/iIhF7WOZHhGfXMsxnxQRD7e3MzsiPrWuF5DhoBHx5fbfRrXP64X2sX2sXf6R9nGsjog7ImIYtveHEXF7RKxorzMpIs5Zy373jIgfR8RzEfF0RFwSEb/b3vcJWPcDETExIl6MiGfa12lvrHNWe1+r29f8sU6ugzHGmN6B3wQaY0zvZx9Jv5W0OvvbuyQdIukrkpZLmhcRO0i6S9Jukv5G0gJJZ0u6LCK2Tyn9l9SaAEr6taQXJf2dpCckDZb0njUbj4jfkfQLSde3tyFJfyXp7og4KKW0ICL2lXStpGskfVWtyegISftmx/lLSc9IOk/SU5IGSnqf2v+IGRE7S7pX0naSvixprqRTJP1HRGyTUvpOe73Rkm6Q9KCkP5S0TXv9HdvX5s1ytaSLJP2LpM9I+mFEjJB0gqTPS9pa0rcl/UjS2Ox3+7bP+5uSXpd0vKQfRMR2KaXvZ+v9VNKBkv5a0ixJH1QrvLcgIj4t6T8kXaLWtdypfX53ta/38xFxnKTLJV0o6S/UuoajJO26HudvjDFmSySl5P/8n//zf/6vF/wn6VxJSdJItf6RbzdJn1JrkvPzbL15ak3g+uH3n23//gT8/Va1JopbtfX/qDWhHFBxLLMk3Ya/7azWRO5bbX1Ge38797CNPdrlv1uxny9KeknSCPz9ova+3t7W/9vWO2TrDFZr4jlvHa5tkvTlTH+5/bePZn/bTdJrklbm5yTpc+11h/Sw7be179dFkh7N/v6e9u8+hPWvze+TWhPZZyX9EOsNbZ/fn7X1n0tatanrqf/zf/7P//m/Tf+fw0GNMab3MV3Sq5JWSfp3tSZAf4x1JqaUluJvx0talFK6E3+/XNKekvZv6/dI+mVKafHadt5+EzZM0v9GxNvX/KfWxHNCez+S9Ej7OK+IiDPWhKtmrJQ0R9I3I+JP2tslp0q6X9Jc7OtXknbPjvloSTeklF5Y88OU0gK13iKuDzdm23tarcnyxJTSc9k609v/H7zmDxExoh3muUita/CqpE+oNYFfw1FqTeB/hn1eA320WhNsXu+F7X2vud4PSNqtHQr8/ojwG0BjjGkongQaY0zv43RJR6gV6rdDSumjKaVVWGfJWn7Xp4e/L83KpdbkqupLmWsmcxfrjQnOmv/e3/69UuvjNaeo9Sy6TNLSiLg/Isa3y5Okd6sVwvkNSTPb3sPzsK/j17Kfq7NjlaT+kpat5VjX9rdOeBr6lR7+JknbSlJE7CjpFkkHqxUyOk6t+/VDtcJU19Bf0tMppVdrjnnN9b5V3a/DgXrjet+l1tdiB6s1sVwREbdGxEHrcqLGGGN6D/YEGmNM72NKeuProD2R1vK3VSrfRK2hX/v/K9v/X+PN64k16/21WhMT0vUV0pTSHZLuiIhtJB2rlp/t+ogYmlJ6KqU0R9JHIyLUmjR9VtK/R8S8lNKN7X0tl3R+D8cyo/3/JZL6rqV8bX97qzla0hBJ41JK96z5Y3TP5bhErTd3W2MiyGNec73PlfT4Wvb3/JqFlNI1kq5pT0RPkPSPkm6KiEEppdffxLkYY4zZAvEk0BhjzBruknRmRBybUsrDJM9Sa6I1ra1vlvSBiOifUlrbm8MZavkOD0gpfXNddpxSelnS7e3JyS/U+pjNU1l5kvRIRPwfSR+XNEatUMybJP1/kuanlJZX7GKCpPdFxA5rQkIjYrBaE8+1hrW+hWzf/n/XxC4idpP0e1hvoqSt1Hqze1X29zOx3n1qTfSGp5T+e10OIKW0WtIv2x/n+bZabwtXrOsJGGOM2bLxJNAYY8waLlXrjdpPI+ILaoV8flitkMxPpZTWfEXz79RKKXFfRPyDWh+BGSjp1JTS2SmlFBF/KukXEfEOtSYwT6n1BusYtSZs/9b+ouXxan21c4FaH4L5a7UmZVPaYYrflnRlex9bqfW26zVJt7eP5QJJf6DWV0cvUGsCuoNaobDjUkprJlZfU2vydHNE/LOkd6j1ZdT1DQd9M9wn6TlJ34uIv2sf79+qdY12WbNSSunmiLhH0n9FxB5qXYMz1HojKrW+KqqU0nMR8Rft7e2p1uT4WbXuyXhJd6aUfhQRX1XrHtyh1jUepNZHax5JKXkCaIwxDcKTQGOMMZKklNILbT/eP6mVumAntSZVH0kpXZ6tNy8ixqo1sfpGe71Far3BW7PODRFxvKQvSPqBWikclqr1duvK9mqPSnpvext7qRWOeo+kD6eUfhMRSyXNl/R/1JqwvCTpMUnvTyk91N7PsxFxjKQvqZWCYqBaKSVmSPpJdjzTIuJ9kv65vf9FaoVCHq1WWORGI6W0IiJOl/Svan3kZbFak90+ak2wcz6gVkqIf1TrIzHXqvVF1EvVmuit2eZ/RsQCtVI/nKVWaopFkv5/tT7AI7U+oPM5tSbOfdR6u3tze3vGGGMaRLQibIwxxhizJRAR31PrjWifdhitMcYY0xF+E2iMMcZspkTEuWqFiD6uVgjrqZI+LemfPQE0xhjzZvEk0BhjjNl8eUHSn6mVd3EbSXMl/Y1aYa3GGGPMm8LhoMYYY4wxxhjTIJws3hhjjDHGGGMahCeBxhhjjDHGGNMgPAk0xhhjjDHGmAbhSaAxxhhjjDHGNAhPAkFEXBoRX4uIcRExY1MfjzFbKhExNCJSRPgrxMYYY4xZJzwW3zh4cNYDKaW7JY3c1MdhjDHGGGNM0/BY/K3FbwJ7CX7bYowxG4do4een6ZW4fhvz5tjSxuKNb+QRcWhEPBwRz0fEl
ZK2bf/9hIhYmK03LyL+PCImR8SzEXFlRGyblf9JRMyKiFURcW1EDGj/PSLigohY3v7d5IgY0y77nYiYFBHPRcSCiPhytr1i/9kxnNxe/nJEXBMRl0fEc5LOfeuukuntRMTnI2J2ux1MjYjT238/NyLujYjvtOvv9Ig4KfvdnRHxjYj4dbv8FxHRp4d97BIRF0fEkohY1A712GpjnaNpLhHxsYi4LtOzIuKqTC+IiEMi4piIeKBdlx+IiGOyde6MiK9HxL2SXpS0b/tvX4uI+yJidURcFxG7R8T/tvv1ByJi6MY8V9M8XL/Nlo7H4puGRk8CI+Idkn4u6TJJfSRdLemDFT/5kKRTJe0j6SC1b3ZEnCjpG+3y/pKelHRF+zfvkXS8pP0k7SrpDyStbJe9IOmj7b//jqTzIuL3OziF35N0Tfv3/9vB74whsyWNk7SLpK9Iujwi+rfLxkqaI2kPSX8n6aeY6H1U0h9LGiDpNUkX9rCP/26XD5d0qFpt4xMb9jSMWSt3SRoXEW9r1+utJR0
|
||
|
"text/plain": [
|
||
|
"<Figure size 1152x360 with 10 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Run this cell to view a selection of images before and after processing\n",
|
||
|
"\n",
|
||
|
"plt.figure(figsize=(16,5))\n",
|
||
|
"plt.suptitle(\"Unprocessed images\", fontsize=16)\n",
|
||
|
"for n, elem in enumerate(train_dataset.take(10)):\n",
|
||
|
" images, labels = elem\n",
|
||
|
" ax = plt.subplot(2, 5, n+1)\n",
|
||
|
" plt.title(cifar_labels[cifar_classes[np.where(labels == 1.)[0][0]]])\n",
|
||
|
" plt.imshow(np.squeeze(images), cmap='gray')\n",
|
||
|
" plt.axis('off')\n",
|
||
|
" \n",
|
||
|
"plt.figure(figsize=(16,5))\n",
|
||
|
"plt.suptitle(\"Processed images\", fontsize=16)\n",
|
||
|
"for n, elem in enumerate(train_dataset_bw.take(10)):\n",
|
||
|
" images_bw, labels_bw = elem\n",
|
||
|
" ax = plt.subplot(2, 5, n+1)\n",
|
||
|
" plt.title(cifar_labels[cifar_classes[np.where(labels_bw == 1.)[0][0]]])\n",
|
||
|
" plt.imshow(np.squeeze(images_bw), cmap='gray')\n",
|
||
|
" plt.axis('off')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"We will now batch and shuffle the Dataset objects."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Run the below cell to batch the training dataset and expand the final dimensinos\n",
|
||
|
"\n",
|
||
|
"train_dataset_bw = train_dataset_bw.batch(10)\n",
|
||
|
"train_dataset_bw = train_dataset_bw.shuffle(100)\n",
|
||
|
"\n",
|
||
|
"test_dataset_bw = test_dataset_bw.batch(10)\n",
|
||
|
"test_dataset_bw = test_dataset_bw.shuffle(100)"
|
||
|
]
|
||
|
},
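{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the provided cell applies `batch` before `shuffle`, so whole batches are shuffled rather than individual examples. The more common ordering shuffles first and then batches. A generic sketch of that ordering on a toy dataset (not graded, and deliberately not touching the datasets used for training below):\n",
"\n",
"```python\n",
"ds = tf.data.Dataset.range(100)\n",
"\n",
"# Shuffle individual elements (buffer_size controls how thoroughly they are mixed),\n",
"# then group the shuffled elements into batches of 10\n",
"ds = ds.shuffle(buffer_size=100).batch(10)\n",
"```"
]
},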
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Train a neural network model\n",
|
||
|
"\n",
|
||
|
"Now we will train a model using the `Dataset` objects. We will use the model specification and function from the first part of this assignment, only modifying the size of the input images."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Build and compile a new model with our original spec, using the new image size\n",
|
||
|
" \n",
|
||
|
"cifar_model = get_model((32, 32, 1))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<ShuffleDataset shapes: ((None, 32, 32, 1), (None, 3)), types: (tf.float32, tf.float32)>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 21,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"train_dataset_bw"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 22,
|
||
|
"metadata": {
|
||
|
"scrolled": true
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Epoch 1/15\n",
|
||
|
"150/150 [==============================] - 19s 127ms/step - loss: 1.0330 - categorical_accuracy: 0.4607 - val_loss: 0.0000e+00 - val_categorical_accuracy: 0.0000e+00\n",
|
||
|
"Epoch 2/15\n",
|
||
|
"150/150 [==============================] - 18s 122ms/step - loss: 0.9124 - categorical_accuracy: 0.5660 - val_loss: 0.8098 - val_categorical_accuracy: 0.6433\n",
|
||
|
"Epoch 3/15\n",
|
||
|
"150/150 [==============================] - 19s 124ms/step - loss: 0.8025 - categorical_accuracy: 0.6567 - val_loss: 0.7377 - val_categorical_accuracy: 0.7267\n",
|
||
|
"Epoch 4/15\n",
|
||
|
"150/150 [==============================] - 19s 126ms/step - loss: 0.7323 - categorical_accuracy: 0.6940 - val_loss: 0.6823 - val_categorical_accuracy: 0.7367\n",
|
||
|
"Epoch 5/15\n",
|
||
|
"150/150 [==============================] - 18s 122ms/step - loss: 0.6936 - categorical_accuracy: 0.7160 - val_loss: 0.6610 - val_categorical_accuracy: 0.7333\n",
|
||
|
"Epoch 6/15\n",
|
||
|
"150/150 [==============================] - 18s 123ms/step - loss: 0.6527 - categorical_accuracy: 0.7307 - val_loss: 0.6424 - val_categorical_accuracy: 0.7533\n",
|
||
|
"Epoch 7/15\n",
|
||
|
"150/150 [==============================] - 18s 121ms/step - loss: 0.6353 - categorical_accuracy: 0.7353 - val_loss: 0.6299 - val_categorical_accuracy: 0.7567\n",
|
||
|
"Epoch 8/15\n",
|
||
|
"150/150 [==============================] - 18s 121ms/step - loss: 0.5887 - categorical_accuracy: 0.7553 - val_loss: 0.6124 - val_categorical_accuracy: 0.7533\n",
|
||
|
"Epoch 9/15\n",
|
||
|
"150/150 [==============================] - 18s 121ms/step - loss: 0.5996 - categorical_accuracy: 0.7613 - val_loss: 0.6354 - val_categorical_accuracy: 0.7500\n",
|
||
|
"Epoch 10/15\n",
|
||
|
"150/150 [==============================] - 18s 121ms/step - loss: 0.5592 - categorical_accuracy: 0.7700 - val_loss: 0.6084 - val_categorical_accuracy: 0.7533\n",
|
||
|
"Epoch 11/15\n",
|
||
|
"150/150 [==============================] - 18s 123ms/step - loss: 0.5527 - categorical_accuracy: 0.7720 - val_loss: 0.6201 - val_categorical_accuracy: 0.7433\n",
|
||
|
"Epoch 12/15\n",
|
||
|
"150/150 [==============================] - 18s 121ms/step - loss: 0.5486 - categorical_accuracy: 0.7787 - val_loss: 0.5873 - val_categorical_accuracy: 0.7633\n",
|
||
|
"Epoch 13/15\n",
|
||
|
"150/150 [==============================] - 18s 122ms/step - loss: 0.5131 - categorical_accuracy: 0.7913 - val_loss: 0.5851 - val_categorical_accuracy: 0.7800\n",
|
||
|
"Epoch 14/15\n",
|
||
|
"150/150 [==============================] - 18s 123ms/step - loss: 0.5064 - categorical_accuracy: 0.7987 - val_loss: 0.5622 - val_categorical_accuracy: 0.7767\n",
|
||
|
"Epoch 15/15\n",
|
||
|
"150/150 [==============================] - 18s 122ms/step - loss: 0.4948 - categorical_accuracy: 0.8007 - val_loss: 0.5729 - val_categorical_accuracy: 0.7900\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Train the model for 15 epochs\n",
|
||
|
"\n",
|
||
|
"history = cifar_model.fit(train_dataset_bw, validation_data=test_dataset_bw, epochs=15)"
|
||
|
]
|
||
|
},
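{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `history` object above already records the validation metrics per epoch, but you can also evaluate the trained model directly on the batched test pipeline. A minimal sketch (not graded; it assumes the model was compiled with a single accuracy metric, as the training log above suggests):\n",
"\n",
"```python\n",
"test_loss, test_acc = cifar_model.evaluate(test_dataset_bw)\n",
"print(test_loss, test_acc)\n",
"```"
]
},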
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"#### Plot the learning curves"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 23,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4IAAAFNCAYAAABVKNEpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzs3Xd8leX9//HXJ3snkACBJBD2FGIIGxVEBK1bq+JAtNaqtdaqrWjbb62/ttraulpbtRUUtY6KtrZVESsKCAiEvWcgYYQdEiCQcf3+uA8QMECEc3Iy3s/H4zxOzj2u+3MfMVc+97XMOYeIiIiIiIg0HiHBDkBERERERERqlxJBERERERGRRkaJoIiIiIiISCOjRFBERERERKSRUSIoIiIiIiLSyCgRFBERERERaWSUCIpI0JhZppk5MwsLdiwiIiKBYGaPmtnrwY5D5HhKBKVBMrPPzWy3mUUGOxYREZH6zMzyzOyCYMchIv6lRFAaHDPLBM4BHHBZLV9bLVsiIiIiUucpEZSGaDQwC3gFuKXqDjOLNrM/mNkGMysys+lmFu3bN9jMZpjZHjPLN7Mxvu2fm9ntVcoYY2bTq3x2ZvZ9M1sNrPZte9ZXxl4zyzWzc6ocH2pmj5jZWjMr9u3PMLPnzewPx8X7bzO77/gbNLMXzOz3x237l5nd7/v5ITPb5Ct/pZkNq8kXZ2atzGyimW03s/Vmdm+VfY+a2btm9rav3Hlm1qvK/q6+72qPmS01s8uq7Dvh9+5zo5ltNLMdZvbTKuf1NbO5vu+x0Myeqsl9iIhI7TCz75rZGjPbZWYfmFkr33Yzs6fNbJvv9/4iM+vh23exmS3z1SWbzOzBasqN9NUnPapsa2ZmB8ysuZmlmNl/fMfsMrNpZlajv2vN7BIzW+A7d4aZ9ayyL8/MHvbFt9vMxptZ1Knu17evu5lN9u0rNLNHqlw2wswm+O55qZnlVDnvtOpskTPmnNNLrwb1AtYAdwO9gTKgRZV9zwOfA2lAKDAQiARaA8XAKCAcSAayfOd8DtxepYwxwPQqnx0wGWgKRPu23eQrIwx4ANgKRPn2/RhYDHQGDOjlO7YvsBkI8R2XAuyvGn+Va54L5APm+9wEOAC08pWbD7Ty7csE2tfgewsBcoH/AyKAdsA6YIRv/6O+7/Ma33f0ILDe93O473t/xHfu+b7vs/MpvvdM3/f3VyDa910cBLr6zpsJ3Oz7OQ7oH+x/X3rppZdeje0F5AEXVLP9fGAHkO37nf5HYKpv3whfnZLkq+u6Ai19+7YA5/h+bgJkn+C644BfV/n8feBj38+PAy9UqYPOOVwnnuJesoFtQD9ffXSL7/4iq9zrEiADr17/EvhVDe433ndfDwBRvs/9fPseBUqBi33XfByY5dt3WnW2Xnr546UWQWlQzGww0AZ4xzmXC6wFbvDtCwFuA37onNvknKtwzs1wzh0EbgQ+dc696Zwrc87tdM4t+AaXftw5t8s5dwDAOfe6r4xy59wf8CqMzr5jbwd+5pxb6TwLfcfOBoqAw08Crwc+d84VVnO9aXgJ1OGWxmuAmc65zUCF73rdzCzcOZfnnFtbg3voAzRzzj3mnDvknFuHl6BdX+WYXOfcu865MuApvMquv+8VBzzhO/cz4D/AqFN874f90jl3wDm3EFiIlxCCl3h2MLMU51yJc25WDe5DRERqx43AOOfcPN/v9IeBAeYN0SjDS4a64CVoy51zW3znleHVUQnOud3OuXknKP/veA9oD7vBt+1wGS2BNr56e5pzztUg5u8CLzrnvvLVR6/iPYDsX+WYPznn8p1zu4BfV4nhZPd7CbDVOfcH51ypc67YOfdVlTKnO+c+dM5VAK9xtJ473Tpb5IwpEZSG5hbgE+fcDt/nv3O0e2gKXuJS3S/YjBNsr6n8qh/M7AEzW+7rDrMHSPRd/1TXehWvNRHf+2vVHeSr7N7iaOV0A/CGb98a4D68J5DbzOytql1XTqIN0MrXVWaPL+5HgBbV3adzrhIowGuFbAXk+7YdtgGvBfBk3/thW6v8vB8vqQT4DtAJWGFmc8zskhrch4iI1I5WeL/rAXDOlQA7gTTfA8E/4fUIKTSzl8wswXfo1XitYxvM7AszG3CC8j8Dos2sn5m1AbKA9337nsTrifKJma0zs7E1jLkN8MBxdV2G714Oq1qnb6iy74T3y6n/jji+nosys7AzqLNFzpgSQWkwfGPOrgXOM7OtZrYV+BHQyzeWbQde14z21Zyef4LtAPuAmCqfU6s55shTSPPGAz7ki6WJcy4Jr6XPanCt14HLffF2Bf55guMA3gSu8VWO/YCJR4Jx7u/OucOtow747UnKOSwfWO+cS6ryinfOXVzlmIwq9xkCpON1Z90MZBw3PqM1sImTf+8n5Zxb7ZwbBTT33cO7Zhb7TcsREZGA2IxXzwDg+/2cjPe7H+fcc8653kB3vId6P/Ztn+Ocuxzvd/s/gXeqK9z3cPEdvIeeNwD/cc4V+/YVO+cecM61Ay4F7q/h2Lp8vO6mVeu6GOfcm1WOyajyc2vffZ7qfk9Wt5/UadbZImdMiaA0JFfgdbHohvfUMAsvmZoGjPZVKOOAp8ybFCXUzAaYt8TEG8AFZnatmYWZWbKZZfnKXQBcZWYxZtYBr5XqZOKBcmA7EGZm/wckVNn/N+D/mVlH32D6nmaWDOCcKwDm4LUETjzc1bQ6zrn5vmv8DZjknNsDYGadzex8332V4o0drDj118dsYK9v0Hq07/vpYWZ9qhzT28yuMm921PvwutPMAr7CS5h/YmbhZjYEr2J+6xTf+0mZ2U1m1sxXxh7f5prci4iI+Fe4mUVVeYXh9bq51cyyfL/TfwN85ZzLM7M+vpa8cLz6oRSoMLMIM7vRzBJ9wwz2cvLf638HrsPrlnm4W+jhCV86mJlVKaMm9cNfgTt9sZmZxZrZt8wsvsox3zezdDNritcz5u0qsVR7v3jDIVLN7D7zJrqJN7N+pwrmDOpskTOmRFAakluA8c65jc65rYdfeF1TbvRVWg/iTdQyB9iF99QtxDm3Ea+bygO+7Qs42n//aeAQUIjXdfONU8QxCfgIWIXXhaSUY7uZPIX3hPMTvMrrZbyJUg57FTiLE3QLPc6bwAVUqRzxxho8gdcStxXviesjAL7Kd2l1BfnGLVyKl0Cv953/N7xurYf9C69C3g3cDFzlG5txCG+pjot85/0ZL/le4Tuv2u+9Bvc3ElhqZiXAs8D1zrnSGpwnIiL+9SFeknL49ahz7n/Az/F6pGzBaxE7PK48AS/p2o1XF+4EDs92fTOQZ2Z7gTs5OiTia3zj7Pbhdcv8qMqujsCnQAnexGJ/ds59DmBmH9mxM3ZWLW8u3jjBP/liW4M3CVxVf8ero9f5Xr/ynXvC+/W1VA7Hq0e34s0iPvRE91XFCetskUA7POOgiNQRZnYuXhfRzOPG3AWVmT0KdHDOnbDCFhERqc/MLA9vpvBPgx2LSKCpRVCkDvF1ofkh8Le6lASKiIiISMOiR
FCkjjCzrnjj4FoCzwQ5HBERERFpwNQ1VEREREREpJFRi6CIiIiIiEgjo0RQRERERESkkQkLdgD+lJKS4jIzM4MdhoiIBFhubu4O51yzYMdRX6h+FBFpPGpaRzaoRDAzM5O5c+cGOwwREQkwM9sQ7BjqE9WPIiKNR03rSHUNFRERERERaWSUCIqIiIiIiDQySgRFREREREQamQY1RlBEREREROqWsrIyCgoKKC0tDXYoDUpUVBTp6emEh4ef1vlKBEVEREREJGAKCgqIj48nMzMTMwt2OA2Cc46dO3dSUFBA27ZtT6sMdQ0VEREREZGAKS0tJTk5WUmgH5kZycnJZ9TKqkRQREREREQCSkmg/53pdxrQRNDMRprZSjNbY2Zjq9mfaGb/NrOFZrbUzG6t6bkiIiIiIiKnsnPnTrKyssjKyiI1NZW0tLQjnw8dOlSjMm699VZWrlx50mOef/553njjDX+EXCsCNkbQzEKB54HhQAEwx8w+cM4tq3LY94FlzrlLzawZsNLM3gAqanCuiIi
|
||
|
"text/plain": [
|
||
|
"<Figure size 1080x360 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Run this cell to plot accuracy vs epoch and loss vs epoch\n",
|
||
|
"\n",
|
||
|
"plt.figure(figsize=(15,5))\n",
|
||
|
"plt.subplot(121)\n",
|
||
|
"try:\n",
|
||
|
" plt.plot(history.history['accuracy'])\n",
|
||
|
" plt.plot(history.history['val_accuracy'])\n",
|
||
|
"except KeyError:\n",
|
||
|
" try:\n",
|
||
|
" plt.plot(history.history['acc'])\n",
|
||
|
" plt.plot(history.history['val_acc'])\n",
|
||
|
" except KeyError:\n",
|
||
|
" plt.plot(history.history['categorical_accuracy'])\n",
|
||
|
" plt.plot(history.history['val_categorical_accuracy'])\n",
|
||
|
"plt.title('Accuracy vs. epochs')\n",
|
||
|
"plt.ylabel('Accuracy')\n",
|
||
|
"plt.xlabel('Epoch')\n",
|
||
|
"plt.legend(['Training', 'Validation'], loc='lower right')\n",
|
||
|
"\n",
|
||
|
"plt.subplot(122)\n",
|
||
|
"plt.plot(history.history['loss'])\n",
|
||
|
"plt.plot(history.history['val_loss'])\n",
|
||
|
"plt.title('Loss vs. epochs')\n",
|
||
|
"plt.ylabel('Loss')\n",
|
||
|
"plt.xlabel('Epoch')\n",
|
||
|
"plt.legend(['Training', 'Validation'], loc='upper right')\n",
|
||
|
"plt.show() "
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Create an iterable from the batched test dataset\n",
|
||
|
"\n",
|
||
|
"test_dataset = test_dataset.batch(10)\n",
|
||
|
"iter_test_dataset = iter(test_dataset)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2cAAAGtCAYAAACSpwyeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzsvXmwZVl21rf2OXd+85SZL18OlZWVlTVXV6u7Wj3RLbVaILVAstQIEJgQhOzADkI4wpjAER4wIcBEMDiwHWFs2QhLQgiEEAgQkpDcTavV3erqmofMqpznl/nG++58zznbf7zXItf6VnZmFVXv3az+fhEZkfu8dc+w99rDuXd/a4UYoxBCCCGEEEII2VuSvb4BQgghhBBCCCF8OSOEEEIIIYSQkYAvZ4QQQgghhBAyAvDljBBCCCGEEEJGAL6cEUIIIYQQQsgIwJczQgghhBBCCBkB+HJ2HxBCeCCEEEMIpb2+F3J/E0L4mRDCT4UQPhlCOL3X90PI3aDPEnJnuD4g5P0HOzMh34bEGL8kIif3+j4IuVfos4QQQr4d4C9nhJCRh98Kk/sN+iwhhGwTtuE7xz3CinqXCSH85RDC2RDCVgjh9RDCf7Jz/MdDCF8OIfyvIYTNEMKpEMJnbvvcF0IIfyOE8Hs7f/8XIYTZO1xjKoTwf4cQrocQru5s+Ul36xnJ/UMI4ZkQwvM7/viLIlLbOf7pEMKV2+wuhBD+Ygjh5R3/+8UQQu22v/9nIYQzIYS1EMK/DCEc3DkeQgh/N4Rwc+dzL4cQntj52+dCCC+EEJohhMshhL9y2/nU9W+7h+/Z+f9fCSH8Ugjh50IITRH58feulsgoQZ8l71e4PiD3CyGEPxNC+NXbymdCCP/ktvLlEMIHQggfCyF8fccvvx5C+NhtNl8IIfy1EMKXRaQjIg/uHPupEMLvhhBaIYRfDSHMhRB+fmfc/XoI4YHdfNZRhC9n7z5nReSTIjIlIv+TiPxcCGFx528fEZFzIjIvIv+jiPyyGWD/tIj8WRE5KCKZiPy9O1zjH+78/SEReUZEvldEfuLdfQxyvxNCqIjIr4jIz4rIrIj8UxH5kW/xkR8VkT8kIsdE5CnZWVyGEL5bRP7Gzt8XReSiiPzjnc98r4j8ARF5WESmReSPicjqzt/asu3T0yLyORH5L0IIP/Q2HuEHReSXdj7/82/jc+Q+hT5L3udwfUDuF74oIp8MISQ7PloWkY+LiIQQHhSRcRG5JCL/WrZ9cU5E/o6I/OsQwtxt5/lPReQ/F5EJ2R6HRUT++M7xJRE5LiJfEZF/INtj/huy7f/f1vDl7F0mxvhPY4zXYoxFjPEXReQtEXl25883ReR/iTEOd/52WrYXAN/kZ2OMr8YY2yLy34vIj9pvvEII+0Xk+0Tkv4oxtmOMN0Xk78q2sxNyO98p2wPqN33ul0Tk69/C/u/t+O6aiPyqiHxg5/ifFJH/J8b4fIyxLyL/rYh8dOfbraFsD7qPiEiIMb4RY7wuIhJj/EKM8ZWdvvCyiPyCiHzqbdz/V2KMv7Lz+e7b+By5f6HPkvctXB+Q+4UY4zkR2ZLtMfVTIvLrInI1hPDITvlLsu2fb8UYfzbGmMUYf0FETonIH77tVD8TY3xt5+/DnWP/IMZ4Nsa4KSK/JiJnY4z/LsaYyfYXcs/sykOOMHw5e5cJIfzpEMKLIYSNEMKGiDwh29+EiYhcjTHG28wvyva3YN/ksvlb+bbPfpOjO8ev33aNvy8i+97N5yDvCw6K73N34sZt/+/I9jdj3zzP738uxtiS7V8almKMvy0i/5uI/O8ishxC+D9DCJMiIiGEj4QQ/r8Qwq0QwqaI/DlBf/5WXL67CXmfQZ8l71u4PiD3GV8UkU/L9k6DL4rIF2T7xexTO2U1zu5wUbZ/Efsm3pi4fNv/u055XL7N4cvZu0gI4aiI/F8i8udFZC7GOC0ir4pI2DFZCiGE2z5yRESu3VY+bP42FJEVc5nLItIXkfkY4/TOv8kY4+Pv4qOQ9wfXxfe5t8s12Z70RUQkhDAm21sYroqIxBj/XozxO0TkcdneKvbf7Jj+IxH5lyJyOMY4JSL/h/yHvtAWkcZt50xFZMFcNwr5doM+S96XcH1A7kO++XL2yZ3/f1H0y5kaZ3c4Ijvj7A4cE98BfDl7dxmTbUe8JbItqJTtb8a+yT4R+ckQQjmE8EdF5FER+Te3/f1PhRAeCyE0ROSvisgvxRjz2y+ws/3mN0Tkb4cQJnf2Ax8PIbydrTfk24OvyLb24CdDCKUQwg/Lf9hC83b4RyLyZ3bEv1UR+esi8rUY44UQwod3fm0oy/bitSci3/TZCRFZizH2QgjPisiP3XbON0WkFrYDMJRF5L8Tkeo7ekryfoI+S96vcH1A7je+KCLfJSL1GOMV2d7K+Idk+4uuF2TbPx8OIfzYznj9x0TkMRH5V3t1w+8X+HL2LhJjfF1E/rZsLzCWReRJEfnybSZfE5ETsv1t118Tkc/HGFdv+/vPisjPyPZWnZqI/OQdLvWnRaQiIq+LyLpsC9AX72BLvk2JMQ5E5IdlO0jCumwHPvjld3Ce35JtjcM/k+1fNo7Lf9AwTMr2t8Hrsr2dYVVE/tbO3/5LEfmrIYQtEfkfROSf3HbOzZ2//7Rsf8vWFhEVCY98+0GfJe9XuD4g9xsxxjdFpCXbL2USY2zKdtCaL8cY8x3//AER+a9lexz9SyLyAzFG+4sueZsEvcWZvFeEEH5cRH4ixviJO/z9CyLyczHGn97N+yKEEELI3sH1ASHkdvjLGSGEEEIIIYSMAHw5I4QQQgghhJARgNsaCSGEEEIIIWQE4C9nhBBCCCGEEDIClHbzYr/2q78NP9NlQ5XgXkql1JpItzs05QHYFEWhykmC56k36qqcJvhu2m61zbW6d72WTk2yc+5UXz+keK0k3P3d2Ptl014/OmkkEvNs5RI2dWLuMc/xPHmuIvVKUeRgMxzq9nGzWgT9uWq1DCZz85P6/kp9sPn+z/1BrOz3kM/8+X8MT2NbLXXaPwm6jdLo+Eii66RcxrqtVPTnUscf6qLb8cQD2NZbnY4qr7XARApziyEWYDM+psv75yfApoxNK+eu6H714uke2PSG+vknq05TG9+/uZnhecyDDDOs18IcK3LHJupzFwWOO4lx/XSI7XPuyz+1qz6bJAm3Q5D/KIrCjgbvLX/r83/I8Vl9C3ZO2z6mbbx5NpTNeFjBASqpN1S5NIY5cEPZ3E+KtxzN/RTOFJ+YufjoMUzj9/ATD6vyzDzmQS+XMYtDGvUF0yGYCExFwRsu9DhbFDjORmODJxaxbfhOCfZaguN1bf9Hd9VnRUROPnoYKu/69ZuqnA1xHo2mrsrlCtiUzfwXnPXBWEPPv406+m23qyf7Tr+J56nox5ir4fq52dLPsbyGPlGqaJ90lnkyWcdO8bnv+rD+XILrg6lpXUcnHzsMNuWKnqPzooM2qa7XmlP3xkSqZayPakX34+j1EbNeGQywQ37nj/xz12/5yxkhh
BBCCCGEjAB8OSOEEEIIIYSQEYAvZ4QQQgghhBAyAvDljBBCCCGEEEJGgF0NCNJcR3Xg2k0ttLt+4wrYLC/fUOX11TbYpKk+99EHHgCbqSmtu9vc3ASbq1euq3KzieJJG5Cj5ATbqNV18JFKFUWHJROQo3CCPXgBQYZGVJgNUYhoBaY2GIqISKOhxZt5xPP0+1qE2uuiuLXT2dLncUSYqRFKe4FfHnn8kCp/6jMnwWa3saJPEfxGw/uGI5hgL6kjjLai0yTimfLcBG0pYRvlJmDOoIVBbOZNsJUrKyi4jeYeqyUUIJdL2q8q1ZpzHhS9xqCfwxPYBnP9PEdfy02cgmoVRfGSmyAqiXOeRAuHnZghktk280IWBO0fhWDdE0K+NYkbReruNjaWRSlzzpPrMdMLEDTsm0ACLZzDShUzPnurp0JfPxtgECHJ9Hh07uJVMNm6dEGVjz/+KNgccNY4Y7OzqhycwGipiVLihQ+K0QYEwTHUDyQCJ7q7zT0Q7LVGJAVUEtAJUlPnAy9QhLn9LMM
|
||
|
"text/plain": [
|
||
|
"<Figure size 1080x576 with 10 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Display model predictions for a sample of test images\n",
|
||
|
"\n",
|
||
|
"plt.figure(figsize=(15,8))\n",
|
||
|
"inx = np.random.choice(test_data.shape[0], 18, replace=False)\n",
|
||
|
"images, labels = next(iter_test_dataset)\n",
|
||
|
"probs = cifar_model(tf.reduce_mean(tf.cast(images, tf.float32), axis=-1, keepdims=True) / 255.)\n",
|
||
|
"preds = np.argmax(probs, axis=1)\n",
|
||
|
"for n in range(10):\n",
|
||
|
" ax = plt.subplot(2, 5, n+1)\n",
|
||
|
" plt.imshow(images[n])\n",
|
||
|
" plt.title(cifar_labels[cifar_classes[np.where(labels[n].numpy() == 1.0)[0][0]]])\n",
|
||
|
" plt.text(0, 35, \"Model prediction: {}\".format(cifar_labels[cifar_classes[preds[n]]]))\n",
|
||
|
" plt.axis('off')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Congratulations for completing this programming assignment! In the next week of the course we will learn to develop models for sequential data."
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"coursera": {
|
||
|
"course_slug": "tensor-flow-2-2",
|
||
|
"graded_item_id": "3hWzU",
|
||
|
"launcher_item_id": "AStQh"
|
||
|
},
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.7.1"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 2
|
||
|
}
|