Generate Synthetic Images with DCGANs in Keras
Import Libraries¶
In [3]:
import sys
sys.path.insert(1, '/kaggle/input/module-of-plot-utils')
In [4]:
%matplotlib inline
import tensorflow as tf
from tensorflow import keras
import numpy as np
import plot_utils
import matplotlib.pyplot as plt
from tqdm import tqdm
print('Tensorflow version:', tf.__version__)
Tensorflow version: 2.13.0
Load and Preprocess the Data¶
In [5]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/ (Fashion-MNIST train/test images and labels)
In [6]:
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i], cmap=plt.cm.binary)
plt.show()
Create Batches of Training Data¶
In [7]:
batch_size = 32
# shuffle() fills a buffer with 1000 elements, then randomly samples from that
# buffer, replacing each drawn element with the next element from the data.
dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(1000)
# batch() combines consecutive elements into batches of 32 (dropping the final
# partial batch); prefetch(1) keeps one batch prepared ahead of training.
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)
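As a quick hedged sanity check (my addition, not in the original notebook), one element of this dataset should be a batch of 32 grayscale 28×28 images scaled to [0, 1]:
In [ ]:
# Pull a single batch and confirm its shape, dtype, and value range.
for X_batch in dataset.take(1):
    print(X_batch.shape, X_batch.dtype)                                  # (32, 28, 28) float32
    print(float(tf.reduce_min(X_batch)), float(tf.reduce_max(X_batch)))  # roughly 0.0 and 1.0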
Build the Generator Network for DCGAN¶
In [8]:
num_features = 100

generator = keras.models.Sequential([
    keras.layers.Dense(7 * 7 * 128, input_shape=[num_features]),
    keras.layers.Reshape([7, 7, 128]),
    keras.layers.BatchNormalization(),
    # Stride-2 transposed convolution: 7x7 -> 14x14
    keras.layers.Conv2DTranspose(64, (5, 5), (2, 2), padding="same", activation="selu"),
    keras.layers.BatchNormalization(),
    # 14x14 -> 28x28, tanh output in [-1, 1]
    keras.layers.Conv2DTranspose(1, (5, 5), (2, 2), padding="same", activation="tanh"),
])
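A hedged sanity check (my addition): the two stride-2 Conv2DTranspose layers should upsample the 7×7 feature map to 14×14 and then 28×28, so the summary should end with an output shape of (None, 28, 28, 1).
In [ ]:
# Inspect layer output shapes; the last layer should report (None, 28, 28, 1).
generator.summary()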
In [9]:
noise = tf.random.normal(shape=[1, num_features])
generated_images = generator(noise, training=False)
plot_utils.show(generated_images, 1)
Build the Discriminator Network for DCGAN¶
In [10]:
discriminator = keras.models.Sequential([
    keras.layers.Conv2D(64, (5, 5), (2, 2), padding="same", input_shape=[28, 28, 1]),
    keras.layers.LeakyReLU(0.2),
    keras.layers.Dropout(0.3),
    keras.layers.Conv2D(128, (5, 5), (2, 2), padding="same"),
    keras.layers.LeakyReLU(0.2),
    keras.layers.Dropout(0.3),
    keras.layers.Flatten(),
    keras.layers.Dense(1, activation='sigmoid')
])
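Another hedged check (my addition): the untrained discriminator should map a small batch of real images, once the channel axis is added, to one sigmoid score per image.
In [ ]:
# Feed four real images through the discriminator; expect shape (4, 1) with values in (0, 1).
real_sample = x_train[:4].reshape(-1, 28, 28, 1)
print(discriminator(real_sample).numpy())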
In [11]:
decision = discriminator(generated_images)
print(decision)
tf.Tensor([[0.50265145]], shape=(1, 1), dtype=float32)
Compile the Deep Convolutional Generative Adversarial Network (DCGAN)¶
In [12]:
# Compile the discriminator for its own training updates.
discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
# Freeze the discriminator inside the combined model so that training `gan`
# updates only the generator; the discriminator still learns via its own
# train_on_batch calls because it was compiled before being frozen.
discriminator.trainable = False
gan = keras.models.Sequential([generator, discriminator])
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")
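A hedged check (my addition, not in the original notebook) that the freezing behaves as intended: the combined model should expose only the generator's trainable weights.
In [ ]:
# The two counts should match, confirming the discriminator is frozen inside gan.
print(len(gan.trainable_weights), len(generator.trainable_weights))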
Define Training Procedure¶
In [13]:
from IPython import display
from tqdm import tqdm

# Fixed noise vectors, reused every epoch so the saved snapshots show the same
# samples evolving as training progresses.
seed = tf.random.normal(shape=[batch_size, 100])
In [14]:
def train_dcgan(gan, dataset, batch_size, num_features, epochs=5):
    generator, discriminator = gan.layers
    for epoch in tqdm(range(epochs)):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        for X_batch in dataset:
            # Phase 1: train the discriminator on a batch of generated images
            # (label 0) concatenated with a batch of real images (label 1).
            noise = tf.random.normal(shape=[batch_size, num_features])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real, y1)
            # Phase 2: train the generator through the frozen discriminator,
            # labelling its fakes as real (1) so it learns to fool the discriminator.
            noise = tf.random.normal(shape=[batch_size, num_features])
            y2 = tf.constant([[1.]] * batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y2)
        # Produce images for the GIF as we go
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)
In [15]:
## Source https://www.tensorflow.org/tutorials/generative/dcgan#create_a_gif
def generate_and_save_images(model, epoch, test_input):
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(10, 10))
    for i in range(25):
        plt.subplot(5, 5, i + 1)
        # Rescale the tanh output from [-1, 1] back to [0, 255] for display.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='binary')
        plt.axis('off')
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
Train DCGAN¶
In [16]:
# Add the channel dimension and rescale from [0, 1] to [-1, 1] so the real
# images match the range of the generator's tanh output.
x_train_dcgan = x_train.reshape(-1, 28, 28, 1) * 2. - 1.
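A small hedged check (my addition) that the rescaled data really spans the generator's tanh output range:
In [ ]:
# Expect shape (60000, 28, 28, 1) with values in [-1.0, 1.0].
print(x_train_dcgan.shape, x_train_dcgan.min(), x_train_dcgan.max())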
In [17]:
batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(x_train_dcgan)
dataset = dataset.shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)
In [18]:
%%time
train_dcgan(gan, dataset, batch_size, num_features, epochs=50)
CPU times: user 47min 39s, sys: 3min, total: 50min 39s
Wall time: 47min 14s
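After a run this long it can be worth persisting the generator so new samples can be drawn later without retraining. A minimal hedged sketch; the filename is my own choice and not part of the original notebook.
In [ ]:
# Save the trained generator in the native Keras format; reload later with keras.models.load_model.
generator.save('dcgan_generator.keras')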
Generate Synthetic Images with DCGAN¶
In [19]:
noise = tf.random.normal(shape=[batch_size, num_features])
generated_images = generator(noise)
plot_utils.show(generated_images, 8)
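plot_utils is a helper module attached to this Kaggle kernel, so its exact show() signature is not reproduced here. Outside that environment a plain matplotlib grid gives an equivalent view; this sketch assumes only that generated_images has shape (32, 28, 28, 1) with values in [-1, 1].
In [ ]:
# Fallback display without plot_utils: plot all 32 samples on a 4x8 grid.
plt.figure(figsize=(8, 4))
for i in range(generated_images.shape[0]):
    plt.subplot(4, 8, i + 1)
    plt.imshow(generated_images[i, :, :, 0] * 127.5 + 127.5, cmap='binary')
    plt.axis('off')
plt.show()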
In [20]:
## Source: https://www.tensorflow.org/tutorials/generative/dcgan#create_a_gif
import imageio.v2 as imageio  # v2 API keeps the legacy imread behaviour and avoids the ImageIO v3 deprecation warning
import glob

anim_file = 'dcgan.gif'
with imageio.get_writer(anim_file, mode='I') as writer:
    filenames = sorted(glob.glob('image*.png'))
    last = -1
    for i, filename in enumerate(filenames):
        frame = 2 * i
        if round(frame) > round(last):
            last = frame
        else:
            continue
        image = imageio.imread(filename)
        writer.append_data(image)
    # Append the final frame once more so the GIF lingers on the last epoch.
    image = imageio.imread(filename)
    writer.append_data(image)

display.Image(filename=anim_file)
Out[20]:
<IPython.core.display.Image object>