import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
# Ask TensorFlow to grow GPU memory on demand instead of reserving it all up front.
# This works around "cuDNN launch failure: input shape ([4,64,8,8])"; see
# https://blog.csdn.net/qq_35037684/article/details/106734086
# It must take effect before TensorFlow initialises the GPU (i.e. before the first op runs).
os.environ.update(TF_FORCE_GPU_ALLOW_GROWTH='true')
# Dataset
def load_data():
    """Load the MNIST training images, scaled for a tanh-output generator.

    Labels and the test split are discarded.

    Returns:
        float32 array of shape (N, H, W, 1): pixel values mapped from
        [0, 255] to [-1, 1] via (x - 127.5) / 127.5, with a trailing
        channel axis appended.
    """
    (train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
    scaled = (train_images.astype(np.float32) - 127.5) / 127.5
    # Append a single channel axis: (N, H, W) -> (N, H, W, 1).
    return scaled[..., np.newaxis]
# Generator
def build_generator(noise_shape=(100,)):
    """Build the generator: noise vector -> 28x28x1 image in [-1, 1].

    The noise is projected to a 7x7x128 feature map, upsampled twice
    (7 -> 14 -> 28) through Conv2D/ReLU/BatchNorm stages, and reduced to a
    single tanh-activated channel.

    Args:
        noise_shape: shape of one noise sample, excluding the batch axis.

    Returns:
        An uncompiled ``tf.keras`` functional Model. Its summary is printed.
    """
    layers = tf.keras.layers
    noise = layers.Input(noise_shape)

    # Project and reshape the noise into a 7x7x128 feature map.
    hidden = layers.Dense(128 * 7 * 7, activation='relu')(noise)
    hidden = layers.Reshape((7, 7, 128))(hidden)
    hidden = layers.BatchNormalization(momentum=0.8)(hidden)

    # Each stage doubles the spatial size (UpSampling2D default factor 2).
    for filters in (128, 64):
        hidden = layers.UpSampling2D()(hidden)
        hidden = layers.Conv2D(filters, kernel_size=3, padding='same')(hidden)
        hidden = layers.Activation('relu')(hidden)
        hidden = layers.BatchNormalization(momentum=0.8)(hidden)

    # One output channel; tanh keeps values in [-1, 1], the range load_data scales to.
    hidden = layers.Conv2D(1, kernel_size=3, padding='same')(hidden)
    image = layers.Activation('tanh')(hidden)

    model = tf.keras.models.Model(noise, image)
    print('-- Generator --')
    model.summary()
    return model
# Discriminator
def build_discriminator(img_shape):
    """Build the discriminator: image -> sigmoid score in (0, 1).

    Four Conv2D/LeakyReLU/Dropout stages are followed by Flatten and a
    single sigmoid unit. Training elsewhere labels real images 1 and
    generated ones 0.

    Args:
        img_shape: shape of one image, excluding the batch axis.

    Returns:
        An uncompiled ``tf.keras`` functional Model. Its summary is printed.
    """
    layers = tf.keras.layers
    image = layers.Input(img_shape)
    hidden = image

    # (filters, strides, zero-pad after conv, BatchNorm after dropout)
    stages = (
        (32, 2, False, False),
        (64, 2, True, True),
        (128, 2, False, True),
        (256, 1, False, False),
    )
    for filters, strides, pad_after_conv, batch_norm in stages:
        hidden = layers.Conv2D(filters, kernel_size=3, strides=strides, padding='same')(hidden)
        if pad_after_conv:
            # Pad one row/column at the bottom/right (7x7 -> 8x8 for a 28x28 input).
            hidden = layers.ZeroPadding2D(padding=((0, 1), (0, 1)))(hidden)
        hidden = layers.LeakyReLU(alpha=0.2)(hidden)
        hidden = layers.Dropout(0.25)(hidden)
        if batch_norm:
            hidden = layers.BatchNormalization(momentum=0.8)(hidden)

    hidden = layers.Flatten()(hidden)
    score = layers.Dense(1, activation='sigmoid')(hidden)

    model = tf.keras.models.Model(image, score)
    print('-- Discriminator --')
    model.summary()
    return model
# Save a grid of sample images generated during training
def save_imgs(generator, epoch, batch, latent_dim=100, out_dir='test'):
    """Generate a 5x5 grid of samples and save it as a JPEG.

    The file is written to ``<out_dir>/mnist_<epoch>_<batch>.jpg``.
    ``out_dir`` is created if it does not exist yet.

    Args:
        generator: model whose ``predict`` maps noise of shape
            (n, latent_dim) to images of shape (n, H, W, C). Channel 0 is
            shown. Values are assumed to lie in [-1, 1] (tanh output).
        epoch: epoch index, used in the file name.
        batch: batch index, used in the file name.
        latent_dim: length of each noise vector (default 100, the value
            used throughout this script).
        out_dir: directory that receives the image (default ``'test'``).
    """
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    gen_imgs = generator.predict(noise)
    # Rescale from the tanh range [-1, 1] to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    # savefig does not create missing directories. Without this, the first
    # call raises FileNotFoundError when 'test/' is absent.
    os.makedirs(out_dir, exist_ok=True)

    fig, axs = plt.subplots(r, c)
    try:
        # axs.flat walks the grid row by row, matching sample order.
        for ax, img in zip(axs.flat, gen_imgs):
            ax.imshow(img[:, :, 0], cmap='gray')
            ax.axis('off')
        fig.savefig(os.path.join(out_dir, 'mnist_%d_%d.jpg' % (epoch, batch)))
    finally:
        # Close this specific figure (not just the "current" one), even if saving fails.
        plt.close(fig)
X_train = load_data()

# Build the three models: discriminator, generator, and the stacked
# generator -> discriminator ("combined") model used to train the generator.
discriminator = build_discriminator(img_shape=(28, 28, 1))
generator = build_generator()

# `learning_rate` is the supported keyword. Newer Keras releases warn about,
# ignore, or reject the old `lr` alias.
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
disc_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
discriminator.compile(loss='binary_crossentropy', optimizer=disc_optimizer, metrics=['accuracy'])
generator.compile(loss='binary_crossentropy', optimizer=gen_optimizer)

z = tf.keras.layers.Input(shape=(100,))
img = generator(z)
# Freeze the discriminator inside the combined model so generator updates do
# not also train D to call fakes "real". Check the spelling: assigning to a
# misspelled attribute silently does nothing.
# In tf.keras (Keras 2) the trainable state is captured at compile time:
# `discriminator` was compiled above while trainable, so its own
# train_on_batch still updates it; only `combined`, compiled below, sees it
# frozen. NOTE(review): Keras 3 handles this differently; confirm if upgrading.
discriminator.trainable = False
real = discriminator(img)
combined = tf.keras.models.Model(z, real)
combined.compile(loss='binary_crossentropy', optimizer=gen_optimizer)

batch_size = 8
epochs = 10
num_examples = X_train.shape[0]
# Floor division; same result as int(a / float(b)) for non-negative ints.
num_batches = num_examples // batch_size
# Each discriminator step sees half a batch of real and half of fake images.
half_batch = batch_size // 2
# Training. Each step first trains the discriminator on half a batch of real
# images (label 1) and half a batch of generated images (label 0). It then
# trains the generator through the combined model, labelling its samples as
# real, so the generator learns to produce images the discriminator accepts.
# range(epochs) runs exactly `epochs` epochs (range(epochs + 1) ran one extra).
for epoch in range(epochs):
    for batch in range(num_batches):
        # Fake images for this batch. Calling the model directly is much
        # cheaper than predict() for one small batch and prints nothing.
        # training=False keeps BatchNormalization in inference mode, as
        # predict() did.
        noise = np.random.normal(0, 1, (half_batch, 100)).astype(np.float32)
        fake_images = generator(noise, training=False).numpy()
        fake_labels = np.zeros((half_batch, 1))

        # Real images for this batch, sampled at random (with replacement).
        idx = np.random.randint(0, X_train.shape[0], half_batch)
        real_images = X_train[idx]
        real_labels = np.ones((half_batch, 1))

        # Train the discriminator (real classified as ones, generated as zeros).
        # Each call returns [loss, accuracy] because of metrics=['accuracy'].
        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train the generator: the combined model is asked to output "real" (1).
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))

        # Report progress.
        print('Epoch %d Batch %d/%d [D loss:%f,acc.: %.2f%%] [G loss:%f]' % (epoch, batch, num_batches, d_loss[0], 100 * d_loss[1], g_loss))
        if batch % 50 == 0:
            save_imgs(generator, epoch, batch)
# Reference: "Generative Adversarial Networks: Generate images using Keras GAN [Tutorial]"
# Created 2022-10-29 22:14; last modified 2022-10-29 22:15.