728x90
반응형
In [163]:
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR) # warning 출력 방지
from keras.applications.resnet50 import ResNet50, decode_predictions
resnet = ResNet50()
다음은 ResNet50의 구조이다.
In [96]:
resnet.summary()
__________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_2 (InputLayer) (None, 224, 224, 3) 0 __________________________________________________________________________________________________ conv1_pad (ZeroPadding2D) (None, 230, 230, 3) 0 input_2[0][0] __________________________________________________________________________________________________ conv1 (Conv2D) (None, 112, 112, 64) 9472 conv1_pad[0][0] __________________________________________________________________________________________________ bn_conv1 (BatchNormalization) (None, 112, 112, 64) 256 conv1[0][0] __________________________________________________________________________________________________ activation_50 (Activation) (None, 112, 112, 64) 0 bn_conv1[0][0] __________________________________________________________________________________________________ pool1_pad (ZeroPadding2D) (None, 114, 114, 64) 0 activation_50[0][0] __________________________________________________________________________________________________ max_pooling2d_2 (MaxPooling2D) (None, 56, 56, 64) 0 pool1_pad[0][0] __________________________________________________________________________________________________ res2a_branch2a (Conv2D) (None, 56, 56, 64) 4160 max_pooling2d_2[0][0] __________________________________________________________________________________________________ bn2a_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2a_branch2a[0][0] __________________________________________________________________________________________________ activation_51 (Activation) (None, 56, 56, 64) 0 bn2a_branch2a[0][0] __________________________________________________________________________________________________ res2a_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_51[0][0] __________________________________________________________________________________________________ bn2a_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2a_branch2b[0][0] __________________________________________________________________________________________________ activation_52 (Activation) (None, 56, 56, 64) 0 bn2a_branch2b[0][0] __________________________________________________________________________________________________ res2a_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_52[0][0] __________________________________________________________________________________________________ res2a_branch1 (Conv2D) (None, 56, 56, 256) 16640 max_pooling2d_2[0][0] __________________________________________________________________________________________________ bn2a_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2a_branch2c[0][0] __________________________________________________________________________________________________ bn2a_branch1 (BatchNormalizatio (None, 56, 56, 256) 1024 res2a_branch1[0][0] __________________________________________________________________________________________________ add_17 (Add) (None, 56, 56, 256) 0 bn2a_branch2c[0][0] bn2a_branch1[0][0] __________________________________________________________________________________________________ activation_53 (Activation) (None, 56, 56, 256) 0 add_17[0][0] __________________________________________________________________________________________________ res2b_branch2a (Conv2D) (None, 56, 56, 64) 16448 activation_53[0][0] __________________________________________________________________________________________________ bn2b_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2b_branch2a[0][0] __________________________________________________________________________________________________ activation_54 (Activation) (None, 56, 56, 64) 0 bn2b_branch2a[0][0] __________________________________________________________________________________________________ res2b_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_54[0][0] __________________________________________________________________________________________________ bn2b_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2b_branch2b[0][0] __________________________________________________________________________________________________ activation_55 (Activation) (None, 56, 56, 64) 0 bn2b_branch2b[0][0] __________________________________________________________________________________________________ res2b_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_55[0][0] __________________________________________________________________________________________________ bn2b_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2b_branch2c[0][0] __________________________________________________________________________________________________ add_18 (Add) (None, 56, 56, 256) 0 bn2b_branch2c[0][0] activation_53[0][0] __________________________________________________________________________________________________ activation_56 (Activation) (None, 56, 56, 256) 0 add_18[0][0] __________________________________________________________________________________________________ res2c_branch2a (Conv2D) (None, 56, 56, 64) 16448 activation_56[0][0] __________________________________________________________________________________________________ bn2c_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2c_branch2a[0][0] __________________________________________________________________________________________________ activation_57 (Activation) (None, 56, 56, 64) 0 bn2c_branch2a[0][0] __________________________________________________________________________________________________ res2c_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_57[0][0] __________________________________________________________________________________________________ bn2c_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2c_branch2b[0][0] __________________________________________________________________________________________________ activation_58 (Activation) (None, 56, 56, 64) 0 bn2c_branch2b[0][0] __________________________________________________________________________________________________ res2c_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_58[0][0] __________________________________________________________________________________________________ bn2c_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2c_branch2c[0][0] __________________________________________________________________________________________________ add_19 (Add) (None, 56, 56, 256) 0 bn2c_branch2c[0][0] activation_56[0][0] __________________________________________________________________________________________________ activation_59 (Activation) (None, 56, 56, 256) 0 add_19[0][0] __________________________________________________________________________________________________ res3a_branch2a (Conv2D) (None, 28, 28, 128) 32896 activation_59[0][0] __________________________________________________________________________________________________ bn3a_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2a[0][0] __________________________________________________________________________________________________ activation_60 (Activation) (None, 28, 28, 128) 0 bn3a_branch2a[0][0] __________________________________________________________________________________________________ res3a_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_60[0][0] __________________________________________________________________________________________________ bn3a_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2b[0][0] __________________________________________________________________________________________________ activation_61 (Activation) (None, 28, 28, 128) 0 bn3a_branch2b[0][0] __________________________________________________________________________________________________ res3a_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_61[0][0] __________________________________________________________________________________________________ res3a_branch1 (Conv2D) (None, 28, 28, 512) 131584 activation_59[0][0] __________________________________________________________________________________________________ bn3a_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3a_branch2c[0][0] __________________________________________________________________________________________________ bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512) 2048 res3a_branch1[0][0] __________________________________________________________________________________________________ add_20 (Add) (None, 28, 28, 512) 0 bn3a_branch2c[0][0] bn3a_branch1[0][0] __________________________________________________________________________________________________ activation_62 (Activation) (None, 28, 28, 512) 0 add_20[0][0] __________________________________________________________________________________________________ res3b_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_62[0][0] __________________________________________________________________________________________________ bn3b_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2a[0][0] __________________________________________________________________________________________________ activation_63 (Activation) (None, 28, 28, 128) 0 bn3b_branch2a[0][0] __________________________________________________________________________________________________ res3b_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_63[0][0] __________________________________________________________________________________________________ bn3b_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2b[0][0] __________________________________________________________________________________________________ activation_64 (Activation) (None, 28, 28, 128) 0 bn3b_branch2b[0][0] __________________________________________________________________________________________________ res3b_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_64[0][0] __________________________________________________________________________________________________ bn3b_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3b_branch2c[0][0] __________________________________________________________________________________________________ add_21 (Add) (None, 28, 28, 512) 0 bn3b_branch2c[0][0] activation_62[0][0] __________________________________________________________________________________________________ activation_65 (Activation) (None, 28, 28, 512) 0 add_21[0][0] __________________________________________________________________________________________________ res3c_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_65[0][0] __________________________________________________________________________________________________ bn3c_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2a[0][0] __________________________________________________________________________________________________ activation_66 (Activation) (None, 28, 28, 128) 0 bn3c_branch2a[0][0] __________________________________________________________________________________________________ res3c_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_66[0][0] __________________________________________________________________________________________________ bn3c_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2b[0][0] __________________________________________________________________________________________________ activation_67 (Activation) (None, 28, 28, 128) 0 bn3c_branch2b[0][0] __________________________________________________________________________________________________ res3c_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_67[0][0] __________________________________________________________________________________________________ bn3c_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3c_branch2c[0][0] __________________________________________________________________________________________________ add_22 (Add) (None, 28, 28, 512) 0 bn3c_branch2c[0][0] activation_65[0][0] __________________________________________________________________________________________________ activation_68 (Activation) (None, 28, 28, 512) 0 add_22[0][0] __________________________________________________________________________________________________ res3d_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_68[0][0] __________________________________________________________________________________________________ bn3d_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2a[0][0] __________________________________________________________________________________________________ activation_69 (Activation) (None, 28, 28, 128) 0 bn3d_branch2a[0][0] __________________________________________________________________________________________________ res3d_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_69[0][0] __________________________________________________________________________________________________ bn3d_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2b[0][0] __________________________________________________________________________________________________ activation_70 (Activation) (None, 28, 28, 128) 0 bn3d_branch2b[0][0] __________________________________________________________________________________________________ res3d_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_70[0][0] __________________________________________________________________________________________________ bn3d_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3d_branch2c[0][0] __________________________________________________________________________________________________ add_23 (Add) (None, 28, 28, 512) 0 bn3d_branch2c[0][0] activation_68[0][0] __________________________________________________________________________________________________ activation_71 (Activation) (None, 28, 28, 512) 0 add_23[0][0] __________________________________________________________________________________________________ res4a_branch2a (Conv2D) (None, 14, 14, 256) 131328 activation_71[0][0] __________________________________________________________________________________________________ bn4a_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2a[0][0] __________________________________________________________________________________________________ activation_72 (Activation) (None, 14, 14, 256) 0 bn4a_branch2a[0][0] __________________________________________________________________________________________________ res4a_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_72[0][0] __________________________________________________________________________________________________ bn4a_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2b[0][0] __________________________________________________________________________________________________ activation_73 (Activation) (None, 14, 14, 256) 0 bn4a_branch2b[0][0] __________________________________________________________________________________________________ res4a_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_73[0][0] __________________________________________________________________________________________________ res4a_branch1 (Conv2D) (None, 14, 14, 1024) 525312 activation_71[0][0] __________________________________________________________________________________________________ bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4a_branch2c[0][0] __________________________________________________________________________________________________ bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096 res4a_branch1[0][0] __________________________________________________________________________________________________ add_24 (Add) (None, 14, 14, 1024) 0 bn4a_branch2c[0][0] bn4a_branch1[0][0] __________________________________________________________________________________________________ activation_74 (Activation) (None, 14, 14, 1024) 0 add_24[0][0] __________________________________________________________________________________________________ res4b_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_74[0][0] __________________________________________________________________________________________________ bn4b_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2a[0][0] __________________________________________________________________________________________________ activation_75 (Activation) (None, 14, 14, 256) 0 bn4b_branch2a[0][0] __________________________________________________________________________________________________ res4b_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_75[0][0] __________________________________________________________________________________________________ bn4b_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2b[0][0] __________________________________________________________________________________________________ activation_76 (Activation) (None, 14, 14, 256) 0 bn4b_branch2b[0][0] __________________________________________________________________________________________________ res4b_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_76[0][0] __________________________________________________________________________________________________ bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4b_branch2c[0][0] __________________________________________________________________________________________________ add_25 (Add) (None, 14, 14, 1024) 0 bn4b_branch2c[0][0] activation_74[0][0] __________________________________________________________________________________________________ activation_77 (Activation) (None, 14, 14, 1024) 0 add_25[0][0] __________________________________________________________________________________________________ res4c_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_77[0][0] __________________________________________________________________________________________________ bn4c_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2a[0][0] __________________________________________________________________________________________________ activation_78 (Activation) (None, 14, 14, 256) 0 bn4c_branch2a[0][0] __________________________________________________________________________________________________ res4c_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_78[0][0] __________________________________________________________________________________________________ bn4c_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2b[0][0] __________________________________________________________________________________________________ activation_79 (Activation) (None, 14, 14, 256) 0 bn4c_branch2b[0][0] __________________________________________________________________________________________________ res4c_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_79[0][0] __________________________________________________________________________________________________ bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4c_branch2c[0][0] __________________________________________________________________________________________________ add_26 (Add) (None, 14, 14, 1024) 0 bn4c_branch2c[0][0] activation_77[0][0] __________________________________________________________________________________________________ activation_80 (Activation) (None, 14, 14, 1024) 0 add_26[0][0] __________________________________________________________________________________________________ res4d_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_80[0][0] __________________________________________________________________________________________________ bn4d_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2a[0][0] __________________________________________________________________________________________________ activation_81 (Activation) (None, 14, 14, 256) 0 bn4d_branch2a[0][0] __________________________________________________________________________________________________ res4d_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_81[0][0] __________________________________________________________________________________________________ bn4d_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2b[0][0] __________________________________________________________________________________________________ activation_82 (Activation) (None, 14, 14, 256) 0 bn4d_branch2b[0][0] __________________________________________________________________________________________________ res4d_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_82[0][0] __________________________________________________________________________________________________ bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4d_branch2c[0][0] __________________________________________________________________________________________________ add_27 (Add) (None, 14, 14, 1024) 0 bn4d_branch2c[0][0] activation_80[0][0] __________________________________________________________________________________________________ activation_83 (Activation) (None, 14, 14, 1024) 0 add_27[0][0] __________________________________________________________________________________________________ res4e_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_83[0][0] __________________________________________________________________________________________________ bn4e_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2a[0][0] __________________________________________________________________________________________________ activation_84 (Activation) (None, 14, 14, 256) 0 bn4e_branch2a[0][0] __________________________________________________________________________________________________ res4e_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_84[0][0] __________________________________________________________________________________________________ bn4e_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2b[0][0] __________________________________________________________________________________________________ activation_85 (Activation) (None, 14, 14, 256) 0 bn4e_branch2b[0][0] __________________________________________________________________________________________________ res4e_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_85[0][0] __________________________________________________________________________________________________ bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4e_branch2c[0][0] __________________________________________________________________________________________________ add_28 (Add) (None, 14, 14, 1024) 0 bn4e_branch2c[0][0] activation_83[0][0] __________________________________________________________________________________________________ activation_86 (Activation) (None, 14, 14, 1024) 0 add_28[0][0] __________________________________________________________________________________________________ res4f_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_86[0][0] __________________________________________________________________________________________________ bn4f_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2a[0][0] __________________________________________________________________________________________________ activation_87 (Activation) (None, 14, 14, 256) 0 bn4f_branch2a[0][0] __________________________________________________________________________________________________ res4f_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_87[0][0] __________________________________________________________________________________________________ bn4f_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2b[0][0] __________________________________________________________________________________________________ activation_88 (Activation) (None, 14, 14, 256) 0 bn4f_branch2b[0][0] __________________________________________________________________________________________________ res4f_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_88[0][0] __________________________________________________________________________________________________ bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4f_branch2c[0][0] __________________________________________________________________________________________________ add_29 (Add) (None, 14, 14, 1024) 0 bn4f_branch2c[0][0] activation_86[0][0] __________________________________________________________________________________________________ activation_89 (Activation) (None, 14, 14, 1024) 0 add_29[0][0] __________________________________________________________________________________________________ res5a_branch2a (Conv2D) (None, 7, 7, 512) 524800 activation_89[0][0] __________________________________________________________________________________________________ bn5a_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2a[0][0] __________________________________________________________________________________________________ activation_90 (Activation) (None, 7, 7, 512) 0 bn5a_branch2a[0][0] __________________________________________________________________________________________________ res5a_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_90[0][0] __________________________________________________________________________________________________ bn5a_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2b[0][0] __________________________________________________________________________________________________ activation_91 (Activation) (None, 7, 7, 512) 0 bn5a_branch2b[0][0] __________________________________________________________________________________________________ res5a_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_91[0][0] __________________________________________________________________________________________________ res5a_branch1 (Conv2D) (None, 7, 7, 2048) 2099200 activation_89[0][0] __________________________________________________________________________________________________ bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5a_branch2c[0][0] __________________________________________________________________________________________________ bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048) 8192 res5a_branch1[0][0] __________________________________________________________________________________________________ add_30 (Add) (None, 7, 7, 2048) 0 bn5a_branch2c[0][0] bn5a_branch1[0][0] __________________________________________________________________________________________________ activation_92 (Activation) (None, 7, 7, 2048) 0 add_30[0][0] __________________________________________________________________________________________________ res5b_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_92[0][0] __________________________________________________________________________________________________ bn5b_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2a[0][0] __________________________________________________________________________________________________ activation_93 (Activation) (None, 7, 7, 512) 0 bn5b_branch2a[0][0] __________________________________________________________________________________________________ res5b_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_93[0][0] __________________________________________________________________________________________________ bn5b_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2b[0][0] __________________________________________________________________________________________________ activation_94 (Activation) (None, 7, 7, 512) 0 bn5b_branch2b[0][0] __________________________________________________________________________________________________ res5b_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_94[0][0] __________________________________________________________________________________________________ bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5b_branch2c[0][0] __________________________________________________________________________________________________ add_31 (Add) (None, 7, 7, 2048) 0 bn5b_branch2c[0][0] activation_92[0][0] __________________________________________________________________________________________________ activation_95 (Activation) (None, 7, 7, 2048) 0 add_31[0][0] __________________________________________________________________________________________________ res5c_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_95[0][0] __________________________________________________________________________________________________ bn5c_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2a[0][0] __________________________________________________________________________________________________ activation_96 (Activation) (None, 7, 7, 512) 0 bn5c_branch2a[0][0] __________________________________________________________________________________________________ res5c_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_96[0][0] __________________________________________________________________________________________________ bn5c_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2b[0][0] __________________________________________________________________________________________________ activation_97 (Activation) (None, 7, 7, 512) 0 bn5c_branch2b[0][0] __________________________________________________________________________________________________ res5c_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_97[0][0] __________________________________________________________________________________________________ bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5c_branch2c[0][0] __________________________________________________________________________________________________ add_32 (Add) (None, 7, 7, 2048) 0 bn5c_branch2c[0][0] activation_95[0][0] __________________________________________________________________________________________________ activation_98 (Activation) (None, 7, 7, 2048) 0 add_32[0][0] __________________________________________________________________________________________________ avg_pool (GlobalAveragePooling2 (None, 2048) 0 activation_98[0][0] __________________________________________________________________________________________________ fc1000 (Dense) (None, 1000) 2049000 avg_pool[0][0] ================================================================================================== Total params: 25,636,712 Trainable params: 25,583,592 Non-trainable params: 53,120 __________________________________________________________________________________________________
In [252]:
import tqdm
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz Extracting MNIST_data/train-labels-idx1-ubyte.gz Extracting MNIST_data/t10k-images-idx3-ubyte.gz Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
In [295]:
X_train = mnist.train.images
y_train = mnist.train.labels
X_test = mnist.test.images
y_test = mnist.test.labels
In [287]:
# image preprocessing
X_train = train.values.astype('float32')/255
X_test = train.values.astype('float32')/255
In [296]:
# reshape
X_train = X_train.reshape(-1,28,28,1)
X_test = X_test.reshape(-1,28,28,1)
In [297]:
X = tf.placeholder(tf.float32, [None,28,28,1])
y = tf.placeholder(tf.float32,[None,10])
In [298]:
def resnet(input):
res = tf.nn.conv2d(input,tf.Variable(tf.random_normal([3,3,32,32]))
, strides=[1,1,1,1], padding='SAME')
res = tf.nn.relu(res)
res = tf.nn.conv2d(input,tf.Variable(tf.random_normal([3,3,32,32]))
, strides=[1,1,1,1], padding='SAME')
res = input + res
res = tf.nn.relu(res)
return res
In [299]:
layer = tf.nn.conv2d(X,tf.Variable(tf.random_normal([3,3,1,32]))
,strides=[1,1,1,1],padding='SAME')
for i in range(5):
layer = resnet(layer)
In [300]:
final = tf.nn.conv2d(layer,tf.Variable(tf.random_normal([3,3,32,1]))
, strides=[1,1,1,1], padding='SAME')
# 쭉 펴준다. (28*28 = 784)
flatten = tf.reshape(final,[-1,final.get_shape()[1:4].num_elements()])
W = tf.Variable(tf.random_normal([final.get_shape()[1:4].num_elements(),10]))
result = tf.matmul(flatten,W)
In [301]:
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=result,labels=y))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(result,1),tf.argmax(y,1)),dtype = tf.float32))
optimizer = tf.train.AdamOptimizer(0.01).minimize(cost)
In [302]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
In [ ]:
batch_size=50
cost_list=[]
acc_list=[]
for i in range(100):
for j in tqdm.tqdm_notebook(range(mnist.train.num_examples//batch_size)):
_,c = sess.run([optimizer,cost],feed_dict={X:X_train[j*batch_size:(j+1)*batch_size,:,:,:],y:y_train[j*batch_size:(j+1)*batch_size,:]})
cost_list.append(c)
print("Epoch:{0}, Cost:{1}".format(i+1, np.mean(cost_list)))
print("Accuracy :",sess.run(accuracy, feed_dict={X:X_test,y:y_test}))
Epoch:1, Cost:1958.700927734375 Accuracy : 0.9062
Epoch:2, Cost:1507.1817626953125 Accuracy : 0.8986
Epoch:3, Cost:1207.7093505859375 Accuracy : 0.9129
Epoch:4, Cost:987.8175048828125 Accuracy : 0.9225
Epoch:5, Cost:828.54638671875 Accuracy : 0.9228
Epoch:6, Cost:709.8893432617188 Accuracy : 0.9145
Epoch:7, Cost:618.9390869140625 Accuracy : 0.9111
Epoch:8, Cost:546.6304321289062 Accuracy : 0.8969
Epoch:9, Cost:488.52044677734375 Accuracy : 0.9159
Epoch:10, Cost:440.9261474609375 Accuracy : 0.9161
에포크 100까지 가야하지만 CPU로는 한 Epoch당 4분 이상이 걸려서 10-Epoch에서 끊었다. 그럼에도 ACcuracy가 91%까지 오르는 것을 볼 수 있었다. 간단한 ResNet모델인데도 이렇게 성능이 나오는 것을 보니 다른 dataset에도 network를 더 깊게해서 적용을 해보고 싶다.
728x90
반응형