[Paper Review] 09-1. LSGAN MNIST with Keras

1. Load Modules

In [13]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import (Input, Dense, Reshape, Flatten,
                          ReLU, LeakyReLU, BatchNormalization)
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt
import sys
import numpy as np
import tensorflow as tf
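
The imports above assume standalone Keras 2.x running on a TensorFlow backend. On TensorFlow 2 the same classes live under tf.keras; a hedged equivalent (not what this notebook ran) would be:

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import (Input, Dense, Reshape, Flatten,
                                     ReLU, LeakyReLU, BatchNormalization)
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam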

2. Build Network

In [59]:
class LSGAN():
    
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows,self.img_cols, self.channels)
        self.latent_dim = 100
        
        optimizer = Adam(0.0002,0.5)
        
        # Build and Compile the discriminator
        self.discriminator = self.build_discriminator()
        # Loss = Mean Squared Error
        self.discriminator.compile(loss='mse',
                                  optimizer=optimizer,
                                  metrics=['accuracy'])
        
        # Build Generator
        self.generator = self.build_generator()
        
        # Make Noise
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)
        
        # For the combined model, train only the generator
        self.discriminator.trainable = False
        
        # valid takes generated imgs as input and determines validity
        valid = self.discriminator(img)
        
        # The Combined model ( G + D)
        # Trains generator to fool discriminator
        self.combined = Model(z,valid)
        
        # Optimize with MSE
        self.combined.compile(loss='mse',optimizer=optimizer)
        
    def build_generator(self):
        
        model = Sequential()
        
        model.add(Dense(256,input_dim=self.latent_dim))
        model.add(ReLU())
        model.add(BatchNormalization(momentum=0.8))
        
        model.add(Dense(512))
        model.add(ReLU())
        model.add(BatchNormalization(momentum=0.8))
        
        model.add(Dense(1024))
        model.add(ReLU())
        model.add(BatchNormalization(momentum=0.8))
        
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))   # 28*28*1 = 784
        model.add(Reshape(self.img_shape))
        
        model.summary()
        
        noise = Input(shape=(self.latent_dim,))
        img = model(noise)
        
        return Model(noise,img)
    
    def build_discriminator(self):
        
        model = Sequential()
        
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(0.2))
        
        model.add(Dense(256))
        model.add(LeakyReLU(0.2))
        model.add(Dense(1))   # linear output: LSGAN drops the sigmoid at the end
        model.summary()
        
        img = Input(shape=self.img_shape)
        validity = model(img)
        
        return Model(img,validity)
    
    def train(self, epochs, batch_size=128,sample_interval=50):
        # Load MNIST
        
        (X_train,_),(_,_) = mnist.load_data()
        
        # Rescale
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train,axis=3) # add a channel axis: (N, 28, 28) -> (N, 28, 28, 1)
        
        # Adversarial ground-truth labels
        real = np.ones((batch_size,1))
        fake = np.zeros((batch_size,1))
        
        for epoch in range(epochs):
            
            """Train Discriminator"""
            
            # Select random batch of images
            idx = np.random.randint(0, X_train.shape[0],batch_size)  # batch_size random indices in [0, X_train.shape[0])
            imgs = X_train[idx]
            
            # Sample noise as generator input
            noise = np.random.normal(0,1,(batch_size,self.latent_dim))  # shape: (batch_size, latent_dim)
            
            # Generate a batch of new imgs
            gen_imgs = self.generator.predict(noise)
            
            # Train the Discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs,real)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs,fake)
            d_loss = 0.5 * np.add(d_loss_fake,d_loss_real)
            
            
            """Train Generator"""
            
            g_loss = self.combined.train_on_batch(noise,real)
            
            # Plot the progress
            print("Epoch : {0} / D_Loss : {1}, ACC : {2:.2f} / G_Loss : {3}".format(epoch,d_loss[0],100*d_loss[1],g_loss))
            
            # If at save interval => Save img samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
                
    def sample_images(self,epoch):
        r,c = 5,5
        noise = np.random.normal(0,1,(r*c,self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale
        gen_imgs = gen_imgs * 0.5 + 0.5

        fig, axs = plt.subplots(r,c)
        cnt = 0

        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt,:,:,0],cmap='gray')
                axs[i,j].axis('off')
                cnt+=1
        fig.savefig("images/mnist_%d.png" % epoch)
        plt.close()    
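
Why loss='mse' works here: LSGAN swaps the vanilla GAN's sigmoid cross-entropy for least-squares objectives. With the label choice a=0 (fake), b=1 (real), c=1 (generator target) from the paper, training on MSE against the real/fake label vectors used in train() below reproduces these objectives up to the constant 1/2. A minimal NumPy sketch of the two losses:

import numpy as np

def d_loss_ls(d_real, d_fake):
    # L_D = 1/2 E[(D(x) - b)^2] + 1/2 E[(D(G(z)) - a)^2], with a=0, b=1
    return 0.5 * np.mean((d_real - 1.0) ** 2) + 0.5 * np.mean(d_fake ** 2)

def g_loss_ls(d_fake):
    # L_G = 1/2 E[(D(G(z)) - c)^2], with c=1
    return 0.5 * np.mean((d_fake - 1.0) ** 2)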

3. Train on MNIST

In [60]:
if __name__ == '__main__':
    lsgan = LSGAN()
    lsgan.train(epochs=30000,batch_size=64, sample_interval=200)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_13 (Flatten)         (None, 784)               0         
_________________________________________________________________
dense_82 (Dense)             (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_31 (LeakyReLU)   (None, 512)               0         
_________________________________________________________________
dense_83 (Dense)             (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_32 (LeakyReLU)   (None, 256)               0         
_________________________________________________________________
dense_84 (Dense)             (None, 1)                 257       
=================================================================
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_85 (Dense)             (None, 256)               25856     
_________________________________________________________________
re_lu_29 (ReLU)              (None, 256)               0         
_________________________________________________________________
batch_normalization_34 (Batc (None, 256)               1024      
_________________________________________________________________
dense_86 (Dense)             (None, 512)               131584    
_________________________________________________________________
re_lu_30 (ReLU)              (None, 512)               0         
_________________________________________________________________
batch_normalization_35 (Batc (None, 512)               2048      
_________________________________________________________________
dense_87 (Dense)             (None, 1024)              525312    
_________________________________________________________________
re_lu_31 (ReLU)              (None, 1024)              0         
_________________________________________________________________
batch_normalization_36 (Batc (None, 1024)              4096      
_________________________________________________________________
dense_88 (Dense)             (None, 784)               803600    
_________________________________________________________________
reshape_12 (Reshape)         (None, 28, 28, 1)         0         
=================================================================
Total params: 1,493,520
Trainable params: 1,489,936
Non-trainable params: 3,584
_________________________________________________________________
Epoch : 0 / D_Loss : 1.38656747341156, ACC : 50.00 / G_Loss : 1.2923169136047363
Epoch : 1 / D_Loss : 1.471015453338623, ACC : 37.50 / G_Loss : 1.1586521863937378
Epoch : 2 / D_Loss : 0.24263668060302734, ACC : 68.75 / G_Loss : 1.1366199254989624
Epoch : 3 / D_Loss : 0.19426219165325165, ACC : 75.78 / G_Loss : 1.0721166133880615
Epoch : 4 / D_Loss : 0.19625677168369293, ACC : 71.88 / G_Loss : 1.0149812698364258
Epoch : 5 / D_Loss : 0.1108073964715004, ACC : 85.94 / G_Loss : 0.9186646938323975
Epoch : 6 / D_Loss : 0.16535161435604095, ACC : 79.69 / G_Loss : 0.9249618053436279
Epoch : 7 / D_Loss : 0.19682630896568298, ACC : 69.53 / G_Loss : 1.0650662183761597
Epoch : 8 / D_Loss : 0.12016656249761581, ACC : 89.06 / G_Loss : 1.1257177591323853
Epoch : 9 / D_Loss : 0.1623956561088562, ACC : 80.47 / G_Loss : 1.0617823600769043
Epoch : 10 / D_Loss : 0.18658095598220825, ACC : 74.22 / G_Loss : 0.9005378484725952
Epoch : 11 / D_Loss : 0.18756097555160522, ACC : 71.09 / G_Loss : 1.1381738185882568
Epoch : 12 / D_Loss : 0.2246052324771881, ACC : 64.06 / G_Loss : 0.9778075218200684
Epoch : 13 / D_Loss : 0.3742966651916504, ACC : 46.09 / G_Loss : 1.132008671760559
Epoch : 14 / D_Loss : 0.504763662815094, ACC : 42.19 / G_Loss : 1.0361526012420654
Epoch : 15 / D_Loss : 0.49995723366737366, ACC : 49.22 / G_Loss : 1.0365160703659058
Epoch : 16 / D_Loss : 0.33904188871383667, ACC : 49.22 / G_Loss : 0.928471028804779
Epoch : 17 / D_Loss : 0.15023961663246155, ACC : 79.69 / G_Loss : 1.0730818510055542
Epoch : 18 / D_Loss : 0.12589681148529053, ACC : 85.94 / G_Loss : 0.9365036487579346
Epoch : 19 / D_Loss : 0.08975698053836823, ACC : 89.06 / G_Loss : 0.9358891248703003
Epoch : 20 / D_Loss : 0.12051121145486832, ACC : 84.38 / G_Loss : 1.0348204374313354
Epoch : 21 / D_Loss : 0.1128842681646347, ACC : 85.94 / G_Loss : 1.1916886568069458
Epoch : 22 / D_Loss : 0.09648407995700836, ACC : 88.28 / G_Loss : 1.0989240407943726
Epoch : 23 / D_Loss : 0.10530316084623337, ACC : 85.16 / G_Loss : 1.0076613426208496
Epoch : 24 / D_Loss : 0.08209355175495148, ACC : 92.97 / G_Loss : 0.9922367930412292
Epoch : 25 / D_Loss : 0.09474079310894012, ACC : 88.28 / G_Loss : 1.0379505157470703
Epoch : 26 / D_Loss : 0.12563005089759827, ACC : 82.81 / G_Loss : 1.0778446197509766
Epoch : 27 / D_Loss : 0.17440742254257202, ACC : 74.22 / G_Loss : 1.0659027099609375
Epoch : 28 / D_Loss : 0.23970657587051392, ACC : 75.00 / G_Loss : 1.1508634090423584
Epoch : 29 / D_Loss : 0.2289704531431198, ACC : 64.06 / G_Loss : 1.2160108089447021
Epoch : 30 / D_Loss : 0.1998881995677948, ACC : 73.44 / G_Loss : 1.0289862155914307
Epoch : 31 / D_Loss : 0.2103518545627594, ACC : 64.06 / G_Loss : 0.9708253145217896
Epoch : 32 / D_Loss : 0.2507723569869995, ACC : 66.41 / G_Loss : 0.9974524974822998
Epoch : 33 / D_Loss : 0.17648330330848694, ACC : 75.00 / G_Loss : 1.2036888599395752
Epoch : 34 / D_Loss : 0.16105754673480988, ACC : 78.12 / G_Loss : 1.0303584337234497
Epoch : 35 / D_Loss : 0.14140692353248596, ACC : 82.81 / G_Loss : 1.200460433959961
Epoch : 36 / D_Loss : 0.1378239393234253, ACC : 84.38 / G_Loss : 1.0274330377578735
Epoch : 37 / D_Loss : 0.11803567409515381, ACC : 84.38 / G_Loss : 1.062586784362793
Epoch : 38 / D_Loss : 0.09319763630628586, ACC : 90.62 / G_Loss : 1.0679075717926025
Epoch : 39 / D_Loss : 0.09631732106208801, ACC : 90.62 / G_Loss : 1.170445442199707
Epoch : 40 / D_Loss : 0.11304035782814026, ACC : 90.62 / G_Loss : 1.0799174308776855
Epoch : 41 / D_Loss : 0.09875467419624329, ACC : 89.84 / G_Loss : 1.0230785608291626
Epoch : 42 / D_Loss : 0.09628114104270935, ACC : 87.50 / G_Loss : 1.0461076498031616
Epoch : 43 / D_Loss : 0.08089073747396469, ACC : 92.19 / G_Loss : 1.0617423057556152
Epoch : 44 / D_Loss : 0.11302421987056732, ACC : 82.81 / G_Loss : 1.0359588861465454
Epoch : 45 / D_Loss : 0.0888390988111496, ACC : 92.97 / G_Loss : 1.0119073390960693
Epoch : 46 / D_Loss : 0.09964063763618469, ACC : 87.50 / G_Loss : 0.9884819984436035
Epoch : 47 / D_Loss : 0.08516378700733185, ACC : 89.84 / G_Loss : 1.1223119497299194
Epoch : 48 / D_Loss : 0.08760619163513184, ACC : 87.50 / G_Loss : 1.0157421827316284
Epoch : 49 / D_Loss : 0.08505186438560486, ACC : 89.84 / G_Loss : 1.1687366962432861
Epoch : 50 / D_Loss : 0.07813139259815216, ACC : 92.19 / G_Loss : 1.0675700902938843
Epoch : 51 / D_Loss : 0.0974535271525383, ACC : 90.62 / G_Loss : 1.0866228342056274
Epoch : 52 / D_Loss : 0.10546877980232239, ACC : 87.50 / G_Loss : 1.075911283493042
Epoch : 53 / D_Loss : 0.07595279812812805, ACC : 91.41 / G_Loss : 1.028899073600769
Epoch : 54 / D_Loss : 0.092352956533432, ACC : 89.84 / G_Loss : 1.049062967300415
Epoch : 55 / D_Loss : 0.0693572610616684, ACC : 92.19 / G_Loss : 1.0601452589035034
Epoch : 56 / D_Loss : 0.06908569484949112, ACC : 90.62 / G_Loss : 1.1607345342636108
Epoch : 57 / D_Loss : 0.09064450114965439, ACC : 89.06 / G_Loss : 1.1646032333374023
Epoch : 58 / D_Loss : 0.07549747824668884, ACC : 91.41 / G_Loss : 1.079940915107727
Epoch : 59 / D_Loss : 0.11252069473266602, ACC : 86.72 / G_Loss : 0.7953663468360901
Epoch : 60 / D_Loss : 0.15162590146064758, ACC : 81.25 / G_Loss : 0.9283652305603027
Epoch : 61 / D_Loss : 0.2712584137916565, ACC : 52.34 / G_Loss : 1.0047624111175537
Epoch : 62 / D_Loss : 0.3261421322822571, ACC : 48.44 / G_Loss : 1.1719380617141724
Epoch : 63 / D_Loss : 0.39710119366645813, ACC : 36.72 / G_Loss : 1.0852932929992676
Epoch : 64 / D_Loss : 0.31775766611099243, ACC : 50.00 / G_Loss : 1.138929009437561
Epoch : 65 / D_Loss : 0.18907782435417175, ACC : 73.44 / G_Loss : 0.9692761898040771
Epoch : 66 / D_Loss : 0.09923519939184189, ACC : 90.62 / G_Loss : 1.0372819900512695
Epoch : 67 / D_Loss : 0.05164875090122223, ACC : 97.66 / G_Loss : 1.1107170581817627
Epoch : 68 / D_Loss : 0.0822988897562027, ACC : 86.72 / G_Loss : 0.9042568802833557
Epoch : 69 / D_Loss : 0.1073434054851532, ACC : 89.06 / G_Loss : 1.0075749158859253
Epoch : 70 / D_Loss : 0.07640635967254639, ACC : 92.97 / G_Loss : 1.0619683265686035
Epoch : 71 / D_Loss : 0.056963831186294556, ACC : 92.97 / G_Loss : 1.0264363288879395
Epoch : 72 / D_Loss : 0.04847197234630585, ACC : 96.88 / G_Loss : 1.0342180728912354
Epoch : 73 / D_Loss : 0.06820496171712875, ACC : 93.75 / G_Loss : 1.2274476289749146
Epoch : 74 / D_Loss : 0.06770201772451401, ACC : 92.97 / G_Loss : 1.0010666847229004
Epoch : 75 / D_Loss : 0.07353506237268448, ACC : 89.06 / G_Loss : 0.9656128883361816
Epoch : 76 / D_Loss : 0.05640412122011185, ACC : 93.75 / G_Loss : 0.9802725911140442
Epoch : 77 / D_Loss : 0.05597372353076935, ACC : 96.09 / G_Loss : 1.050729513168335
Epoch : 78 / D_Loss : 0.05323430523276329, ACC : 95.31 / G_Loss : 1.092334508895874
Epoch : 79 / D_Loss : 0.05248890444636345, ACC : 95.31 / G_Loss : 1.113814353942871
Epoch : 80 / D_Loss : 0.05707560479640961, ACC : 94.53 / G_Loss : 1.0364426374435425
Epoch : 81 / D_Loss : 0.06313268095254898, ACC : 94.53 / G_Loss : 1.0810896158218384
Epoch : 82 / D_Loss : 0.06065795570611954, ACC : 93.75 / G_Loss : 1.017265796661377
Epoch : 83 / D_Loss : 0.06461706757545471, ACC : 94.53 / G_Loss : 1.0428777933120728
Epoch : 84 / D_Loss : 0.074702188372612, ACC : 89.84 / G_Loss : 0.8876435160636902
Epoch : 85 / D_Loss : 0.05921899527311325, ACC : 95.31 / G_Loss : 1.1065373420715332
Epoch : 86 / D_Loss : 0.048591092228889465, ACC : 96.09 / G_Loss : 1.0538147687911987
Epoch : 87 / D_Loss : 0.0629400759935379, ACC : 94.53 / G_Loss : 0.9884475469589233
Epoch : 88 / D_Loss : 0.06099383160471916, ACC : 93.75 / G_Loss : 0.9551231861114502
Epoch : 89 / D_Loss : 0.0605444461107254, ACC : 92.97 / G_Loss : 0.9171189069747925
Epoch : 90 / D_Loss : 0.0437452495098114, ACC : 97.66 / G_Loss : 1.1051135063171387
Epoch : 91 / D_Loss : 0.0673409104347229, ACC : 93.75 / G_Loss : 0.9254881143569946
Epoch : 92 / D_Loss : 0.06648625433444977, ACC : 91.41 / G_Loss : 0.9525142908096313
Epoch : 93 / D_Loss : 0.07335928827524185, ACC : 92.19 / G_Loss : 1.0110105276107788
Epoch : 94 / D_Loss : 0.06168702244758606, ACC : 94.53 / G_Loss : 1.1031090021133423
Epoch : 95 / D_Loss : 0.05616850033402443, ACC : 96.88 / G_Loss : 1.0349249839782715
Epoch : 96 / D_Loss : 0.05545426160097122, ACC : 92.97 / G_Loss : 1.0943102836608887
Epoch : 97 / D_Loss : 0.04556863382458687, ACC : 96.09 / G_Loss : 0.9863028526306152
Epoch : 98 / D_Loss : 0.0656975507736206, ACC : 92.97 / G_Loss : 1.1613688468933105
Epoch : 99 / D_Loss : 0.06176591292023659, ACC : 94.53 / G_Loss : 1.0320043563842773
Epoch : 100 / D_Loss : 0.05207807943224907, ACC : 96.09 / G_Loss : 0.9393661022186279
...
Epoch : 29994 / D_Loss : 0.231647327542305, ACC : 65.62 / G_Loss : 0.30790817737579346
Epoch : 29995 / D_Loss : 0.23614759743213654, ACC : 63.28 / G_Loss : 0.319231241941452
Epoch : 29996 / D_Loss : 0.2510605454444885, ACC : 52.34 / G_Loss : 0.3026719391345978
Epoch : 29997 / D_Loss : 0.24431605637073517, ACC : 57.81 / G_Loss : 0.32631996273994446
Epoch : 29998 / D_Loss : 0.24833302199840546, ACC : 51.56 / G_Loss : 0.32033103704452515
Epoch : 29999 / D_Loss : 0.2420063018798828, ACC : 53.91 / G_Loss : 0.3198980987071991

4. Make GIF

In [153]:
# Zero-pad the number in each filename to five digits so the files sort correctly
import os
import glob
import re

path = os.getcwd()
filename = glob.glob(path + '/images/mnist_*.png')
for i in filename:
    a = os.path.basename(i)
    num = re.findall(r'\d+', a)[0]          # e.g. '200' from 'mnist_200.png'
    os.rename(i, os.path.join(path, 'images', num.zfill(5) + '.png'))
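
Renaming works, but a lighter alternative (a sketch, not what this notebook ran) is to sort the original mnist_*.png files by their numeric suffix and skip the renames entirely:

sorted_paths = sorted(
    glob.glob(os.path.join(path, 'images', 'mnist_*.png')),
    key=lambda p: int(re.findall(r'\d+', os.path.basename(p))[0]),
)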
In [160]:
filename_new = glob.glob(path + '/images/*.png')
filename_new = sorted(filename_new)
filename_new[:5]
# The list is now sorted in order.
Out[160]:
['/Users/charming/Python/0_Paper_Review/09. LSGAN/images/00000.png',
 '/Users/charming/Python/0_Paper_Review/09. LSGAN/images/00200.png',
 '/Users/charming/Python/0_Paper_Review/09. LSGAN/images/00400.png',
 '/Users/charming/Python/0_Paper_Review/09. LSGAN/images/00600.png',
 '/Users/charming/Python/0_Paper_Review/09. LSGAN/images/00800.png']
In [162]:
import imageio

generated_image_array = [imageio.imread(generated_image) for generated_image in filename_new]
imageio.mimsave('LSGAN_MNIST.gif', generated_image_array, fps=15)
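
One caveat: recent imageio releases removed the fps= keyword for GIF writing. If your version rejects it, passing a per-frame duration through the legacy v2 API should be equivalent (an assumption; check the docs for your installed version):

import imageio.v2 as imageio  # legacy API, available in imageio >= 2.16

imageio.mimsave('LSGAN_MNIST.gif', generated_image_array, duration=1/15)  # seconds per frame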





[GIF: LSGAN_MNIST.gif]

< The process of generating MNIST with LSGAN >
