
Implementing a GAN in TensorFlow to generate MNIST-like images

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

# Prevent the error "cuDNN launch failure : input shape ([4,64,8,8])" by letting GPU memory grow on demand; see https://blog.csdn.net/qq_35037684/article/details/106734086
os.environ['TF_FORCE_GPU_ALLOW_GROWTH']='true'
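# An alternative (not in the original post) is to request memory growth programmatically
# through the TF 2.x API instead of the environment variable; uncomment to use it:
# for gpu in tf.config.list_physical_devices('GPU'):
#     tf.config.experimental.set_memory_growth(gpu,True)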

# Dataset: load MNIST, scale pixel values to [-1,1] to match the generator's tanh output, and add a channel axis
def load_data():
	(X_train,_),(_,_)=tf.keras.datasets.mnist.load_data()
	X_train=(X_train.astype(np.float32)-127.5)/127.5
	X_train=np.expand_dims(X_train,axis=3)
	return X_train
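# Quick sanity check (optional, not in the original): after load_data() the array should be
# shaped (60000,28,28,1) with values in [-1,1].
# X=load_data()
# print(X.shape,X.min(),X.max())  # expected: (60000, 28, 28, 1) -1.0 1.0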

# Generator: maps a 100-dim noise vector to a 28x28x1 image (7 -> 14 -> 28 via two UpSampling2D layers)
def build_generator(noise_shape=(100,)):
	input=tf.keras.layers.Input(noise_shape)
	x=tf.keras.layers.Dense(128*7*7,activation='relu')(input)
	x=tf.keras.layers.Reshape((7,7,128))(x)
	x=tf.keras.layers.BatchNormalization(momentum=0.8)(x)
	x=tf.keras.layers.UpSampling2D()(x)
	x=tf.keras.layers.Conv2D(128,kernel_size=3,padding='same')(x)
	x=tf.keras.layers.Activation('relu')(x)
	x=tf.keras.layers.BatchNormalization(momentum=0.8)(x)
	x=tf.keras.layers.UpSampling2D()(x)
	x=tf.keras.layers.Conv2D(64,kernel_size=3,padding='same')(x)
	x=tf.keras.layers.Activation('relu')(x)
	x=tf.keras.layers.BatchNormalization(momentum=0.8)(x)
	x=tf.keras.layers.Conv2D(1,kernel_size=3,padding='same')(x)
	out=tf.keras.layers.Activation('tanh')(x)
	model=tf.keras.models.Model(input,out)
	print('-- Generator --')
	model.summary()
	return model
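# Optional check (not in the original): the two UpSampling2D layers take the 7x7 feature map
# to 14x14 and then 28x28, so a batch of noise vectors maps to 28x28x1 images.
# g=build_generator()
# print(g.predict(np.random.normal(0,1,(1,100))).shape)  # expected: (1, 28, 28, 1)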

# Discriminator: maps a 28x28x1 image to a single real/fake probability
def build_discriminator(img_shape):
	input=tf.keras.layers.Input(img_shape)
	x=tf.keras.layers.Conv2D(32,kernel_size=3,strides=2,padding='same')(input)
	x=tf.keras.layers.LeakyReLU(alpha=0.2)(x)
	x=tf.keras.layers.Dropout(0.25)(x)
	x=tf.keras.layers.Conv2D(64,kernel_size=3,strides=2,padding='same')(x)
	x=tf.keras.layers.ZeroPadding2D(padding=((0,1),(0,1)))(x)
	x=tf.keras.layers.LeakyReLU(alpha=0.2)(x)
	x=tf.keras.layers.Dropout(0.25)(x)
	x=tf.keras.layers.BatchNormalization(momentum=0.8)(x)
	x=tf.keras.layers.Conv2D(128,kernel_size=3,strides=2,padding='same')(x)
	x=tf.keras.layers.LeakyReLU(alpha=0.2)(x)
	x=tf.keras.layers.Dropout(0.25)(x)
	x=tf.keras.layers.BatchNormalization(momentum=0.8)(x)
	x=tf.keras.layers.Conv2D(256,kernel_size=3,strides=1,padding='same')(x)
	x=tf.keras.layers.LeakyReLU(alpha=0.2)(x)
	x=tf.keras.layers.Dropout(0.25)(x)
	x=tf.keras.layers.Flatten()(x)
	out=tf.keras.layers.Dense(1,activation='sigmoid')(x)
	model=tf.keras.models.Model(input,out)
	print('-- Discriminator --')
	model.summary()
	return model

# Save a 5x5 grid of generated sample images during training
def save_imgs(generator,epoch,batch):
	r,c=5,5
	noise=np.random.normal(0,1,(r*c,100))
	gen_imgs=generator.predict(noise)
	gen_imgs=0.5*gen_imgs+0.5
	fig,axs=plt.subplots(r,c)
	cnt=0
	for i in range(r):
		for j in range(c):
			axs[i,j].imshow(gen_imgs[cnt,:,:,0],cmap='gray')
			axs[i,j].axis('off')
			cnt+=1
	fig.savefig('test/mnist_%d_%d.jpg'%(epoch,batch))
	plt.close()

X_train=load_data()
os.makedirs('test',exist_ok=True)  # make sure the directory for the sample images exists

# Build the three models: the discriminator, the generator, and the combined model used to train the generator
discriminator=build_discriminator(img_shape=(28,28,1))
generator=build_generator()
gen_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002,beta_1=0.5)
disc_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002,beta_1=0.5)
discriminator.compile(loss='binary_crossentropy',optimizer=disc_optimizer,metrics=['accuracy'])
generator.compile(loss='binary_crossentropy',optimizer=gen_optimizer)
# Combined model: noise -> generator -> discriminator. The discriminator was compiled above
# while still trainable, so train_on_batch keeps updating its weights; trainable=False here
# only freezes it inside the combined model that trains the generator.
z=tf.keras.layers.Input(shape=(100,))
img=generator(z)
discriminator.trainable=False
real=discriminator(img)
combined=tf.keras.models.Model(z,real)
combined.compile(loss='binary_crossentropy',optimizer=gen_optimizer)

batch_size=8
epochs=10

num_examples=X_train.shape[0]
num_batches=int(num_examples/float(batch_size))
half_batch=int(batch_size/2)

# Training loop: first train the discriminator on a half-batch of real images (labelled 1) and a half-batch of generated images (labelled 0); then, with the discriminator frozen inside the combined model, train the generator so that the images it produces are classified as real.
for epoch in range(epochs+1):
	for batch in range(num_batches):
		# fake images for this half-batch, generated from noise
		noise=np.random.normal(0,1,(half_batch,100))
		fake_images=generator.predict(noise)
		fake_labels=np.zeros((half_batch,1))
		# real images for this half-batch
		idx=np.random.randint(0,X_train.shape[0],half_batch)
		real_images=X_train[idx]
		real_labels=np.ones((half_batch,1))
		# Train the discriminator (real classified as ones and generated as zeros)
		d_loss_real=discriminator.train_on_batch(real_images,real_labels)
		d_loss_fake=discriminator.train_on_batch(fake_images,fake_labels)
		d_loss=0.5*np.add(d_loss_real,d_loss_fake)
		noise=np.random.normal(0,1,(batch_size,100))
		# Train the generator
		g_loss=combined.train_on_batch(noise,np.ones((batch_size,1)))
		# Plot the progress
		print('Epoch %d Batch %d/%d [D loss:%f,acc.: %.2f%%] [G loss:%f]'%(epoch,batch,num_batches,d_loss[0],100*d_loss[1],g_loss))
		if batch%50==0:
			save_imgs(generator,epoch,batch)
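
# Optional (not part of the original tutorial): save the trained generator and sample from it later;
# the file name is just an example.
# generator.save('mnist_generator.h5')
# samples=generator.predict(np.random.normal(0,1,(25,100)))  # 25 images with values in [-1,1]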

Reference: Generative Adversarial Networks: Generate images using Keras GAN [Tutorial]

Created on 2022.10.29 22:14, last modified on 2022.10.29 22:15