├── README.md
├── counter.txt
├── data.rar
├── getdata.py
├── getkeys.py
├── keras
│   ├── collect_sample.py
│   └── model_keras.py
├── pytorch
│   └── readme.md.
└── tensorflow
    ├── load_data.py
    ├── loaddata_2.py
    ├── loaddata_3.py
    └── model_tf.py
/README.md:
--------------------------------------------------------------------------------
1 | # chrome_Trex
2 |
3 | This program automates the Google Chrome T-Rex game.
4 | You can use this URL for collecting data and testing.
5 | http://wayou.github.io/t-rex-runner/
6 |
7 | A convolutional neural network is used to predict the keyboard input from screen captures of the game.
8 |
9 | ## getdata.py:
10 | Collects training images and stores them in data/
11 | ## getkeys.py:
12 | Contains a helper function that reports which keys are currently pressed
13 | (Source: https://github.com/Sentdex/pygta5/blob/master/Versions/v0.02/getkeys.py)
14 |
15 | Different models are implemented in Keras and TensorFlow.
16 |
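17 | ## Workflow:
18 | 1. Run getdata.py while playing the game to collect labelled frames into data/.
19 | 2. Train a model, e.g. keras/model_keras.py (saves whole_model.h5).
20 | 3. Run keras/collect_sample.py to let the trained model play.
21 |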
--------------------------------------------------------------------------------
/counter.txt:
--------------------------------------------------------------------------------
1 | 401
2 | 3
3 | 8024
4 |
--------------------------------------------------------------------------------
/data.rar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SouravSharan/chrome_Trex/e5efc3443886f01c82c5272e8ad504fd8a501784/data.rar
--------------------------------------------------------------------------------
/getdata.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import ImageGrab
3 | import cv2
4 | import time
5 | from getkeys import key_check
6 |
7 | # counter.txt holds one running image counter per class (up, down, null),
8 | # one integer per line, so file names keep incrementing across sessions.
9 | file = open("G:/Works/Chrome T-rex/counter.txt", 'r')
10 | items = file.readlines()
11 | file.close()
12 | counter = list(map(int, items))
13 | print(counter)
14 |
15 | # Frames are written to data/up/, data/down/ and data/null/.
16 | path = 'G:/Works/Chrome T-rex/data/'
17 |
18 | # Windows virtual-key codes for the arrow keys.
19 | up = 38
20 | down = 40
21 |
22 | # Countdown so there is time to focus the game window.
23 | for i in list(range(4))[::-1]:
24 |     print(i+1)
25 |     time.sleep(1)
26 |
27 | while True:
28 |     # Grab the game region and label the frame by the key held right now.
29 |     screen = np.array(ImageGrab.grab(bbox=(360,100,700,440)))
30 |     screen = cv2.cvtColor(screen, cv2.COLOR_RGB2BGR)
31 |     keys = key_check()
32 |
33 |     if up in keys:
34 |         cv2.imwrite(path + 'up/' + str(counter[0]) + ".jpg", screen)
35 |         counter[0] += 1
36 |         time.sleep(0.5)
37 |     elif down in keys:
38 |         cv2.imwrite(path + 'down/' + str(counter[1]) + ".jpg", screen)
39 |         counter[1] += 1
40 |         time.sleep(0.5)
41 |     else:
42 |         cv2.imwrite(path + 'null/' + str(counter[2]) + ".jpg", screen)
43 |         counter[2] += 1
44 |
45 |     # Press 'E' to stop recording.
46 |     if ord('E') in keys:
47 |         break
48 |
49 | # Persist the counters for the next run.
50 | file = open("G:/Works/Chrome T-rex/counter.txt", 'w')
51 | for ch in counter:
52 |     print(ch)
53 |     file.write(str(ch) + "\n")
54 | file.close()
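55 |
56 | # Note (a sketch, not in the original): cv2.imwrite fails silently when the
57 | # target directory is missing, so create the class folders before the first run:
58 | #   import os
59 | #   for d in ('up', 'down', 'null'):
60 | #       os.makedirs(path + d, exist_ok=True)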
--------------------------------------------------------------------------------
/getkeys.py:
--------------------------------------------------------------------------------
1 | # Citation: Box Of Hats (https://github.com/Box-Of-Hats )
2 |
3 | import win32api as wapi
4 | import time
5 |
6 | keyList = []
7 |
8 | for char in "ABCDEFGHIJKLMNOPQRSTUVWXYZ 123456789,.'£$/\\":
9 | keyList.append(ord(char))
10 |
11 | keyList.append(13)  # 0x0D, enter
12 | keyList.append(37)  # left arrow
13 | keyList.append(38)  # up arrow
14 | keyList.append(39)  # right arrow
15 | keyList.append(40)  # down arrow
16 |
17 | def key_check():
18 | keys = []
19 | for key in keyList:
20 | if wapi.GetAsyncKeyState(int(key)):
21 | keys.append(key)
22 |
23 | return keys
24 |
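25 | if __name__ == '__main__':
26 |     # Minimal self-test (a sketch, not in the original): print whatever keys
27 |     # are currently down; press 'Q' to quit.
28 |     while True:
29 |         keys = key_check()
30 |         if keys:
31 |             print(keys)
32 |         if ord('Q') in keys:
33 |             break
34 |         time.sleep(0.01)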
--------------------------------------------------------------------------------
/keras/collect_sample.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import pyautogui
4 | from keras.models import load_model
5 | from PIL import ImageGrab
6 |
7 | # Model trained by model_keras.py.
8 | model = load_model('whole_model.h5')
9 |
10 | while True:
11 |     # Grab the same game region used during data collection.
12 |     screen = np.array(ImageGrab.grab(bbox=(360,100,700,440)))
13 |     screen = cv2.cvtColor(screen, cv2.COLOR_RGB2BGR)
14 |
15 |     # Rescale to match the training preprocessing (1./255) and add a batch axis.
16 |     screen = np.expand_dims(screen.astype(np.float32) / 255., axis=0)
17 |     key = model.predict(screen, batch_size=1, verbose=0)
18 |     k = key[0]
19 |     print(k[1])
20 |     # Jump when the softmax column for 'up' wins (class indices are assigned
21 |     # alphabetically by folder name: null=0, up=1).
22 |     if k[1] > 0.5:
23 |         pyautogui.press('space')
--------------------------------------------------------------------------------
/keras/model_keras.py:
--------------------------------------------------------------------------------
1 | import keras
2 | from keras.preprocessing.image import ImageDataGenerator
3 | from keras.models import Sequential
4 | from keras.layers import Conv2D, MaxPooling2D
5 | from keras.layers import Activation, Dropout, Flatten, Dense
6 | from keras import backend as K
7 | from keras import utils
8 | #import h5py
9 |
10 | # dimensions of our images.
11 | img_width, img_height = 340, 340
12 |
13 | train_data_dir = 'G:/Works/Chrome T-rex/data'
14 | #validation_data_dir = 'data/validation'
15 | nb_train_samples = 768
16 | #nb_validation_samples = 800
17 | epochs = 30
18 | batch_size = 48
19 |
20 | if K.image_data_format() == 'channels_first':
21 | input_shape = (3, img_width, img_height)
22 | else:
23 | input_shape = (img_width, img_height, 3)
24 |
25 | model = Sequential()
26 | model.add(Conv2D(32, (3, 3), input_shape=input_shape))
27 | model.add(Activation('relu'))
28 | model.add(MaxPooling2D(pool_size=(2, 2)))
29 |
30 | model.add(Conv2D(64, (3, 3)))
31 | model.add(Activation('relu'))
32 | model.add(MaxPooling2D(pool_size=(2, 2)))
33 |
34 | model.add(Conv2D(64, (3, 3)))
35 | model.add(Activation('relu'))
36 | model.add(MaxPooling2D(pool_size=(2, 2)))
37 |
38 | model.add(Conv2D(128, (3, 3)))
39 | model.add(Activation('relu'))
40 | model.add(MaxPooling2D(pool_size=(2, 2)))
41 |
42 | model.add(Flatten())
43 | model.add(Dense(128))
44 | model.add(Activation('relu'))
45 | model.add(Dropout(0.5))
46 | model.add(Dense(2))
47 | model.add(Activation('softmax'))
48 |
49 | model.compile(loss=keras.losses.categorical_crossentropy,
50 | optimizer=keras.optimizers.Adadelta(),
51 | metrics=['accuracy'])
52 |
53 | # this is the augmentation configuration we will use for training
54 | train_datagen = ImageDataGenerator(
55 | rescale=1. / 255,
56 | shear_range=0.2,
57 | zoom_range=0.2,
58 | horizontal_flip=True)
59 |
60 | # this is the augmentation configuration we will use for testing:
61 | # only rescaling
62 | test_datagen = ImageDataGenerator(rescale=1. / 255)
63 |
64 | train_generator = train_datagen.flow_from_directory(
65 | train_data_dir,
66 | target_size=(img_width, img_height),
67 | batch_size=batch_size,
68 | class_mode='categorical')
69 | '''
70 | validation_generator = test_datagen.flow_from_directory(
71 | validation_data_dir,
72 | target_size=(img_width, img_height),
73 | batch_size=batch_size,
74 | class_mode='categorical')
75 | '''
76 | model.fit_generator(
77 | train_generator,
78 | steps_per_epoch=nb_train_samples // batch_size,
79 | epochs=epochs)
80 | #validation_data=validation_generator,
81 | #validation_steps=nb_validation_samples // batch_size)
82 | #model.save_weights('./model_weights.h5')
83 | model.save('./whole_model.h5')
84 |
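85 | # Sanity check (a sketch): flow_from_directory assigns class indices in
86 | # alphabetical folder order; print the mapping so collect_sample.py reads
87 | # the right softmax column.
88 | print(train_generator.class_indices)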
--------------------------------------------------------------------------------
/pytorch/readme.md.:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tensorflow/load_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import hashlib
6 | import os.path
7 | import re
8 |
9 | import cv2
10 | import tensorflow as tf
11 | from tensorflow.python.platform import gfile
12 | from tensorflow.python.util import compat
13 |
14 | # Cap used when hashing file names into the train/test split (the same
15 | # constant as TensorFlow's retrain.py example).
16 | MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # ~134M
17 |
18 | def create_image_lists():
19 |     image_dir = 'G:/Works/Chrome T-rex/data'
20 |     testing_percentage = 20
21 |
22 |     result = {}
23 |     counter_for_result_label = 0
24 |
25 |     # One sub-directory per class; the root directory comes first, so skip it.
26 |     sub_dirs = [x[0] for x in gfile.Walk(image_dir)]
27 |     is_root_dir = True
28 |     for sub_dir in sub_dirs:
29 |         if is_root_dir:
30 |             is_root_dir = False
31 |             continue
32 |
33 |         dir_name = os.path.basename(sub_dir)
34 |         extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
35 |         file_list = []
36 |         tf.logging.info("Looking for images in '" + dir_name + "'")
37 |         for extension in extensions:
38 |             file_glob = os.path.join(image_dir, dir_name, '*.' + extension)
39 |             file_list.extend(gfile.Glob(file_glob))  # all files of this class
40 |
41 |         # Split deterministically: hashing the file name means a given image
42 |         # always lands in the same set, even as more images are added later.
43 |         training_images = []
44 |         testing_images = []
45 |         for file_name in file_list:
46 |             hash_name = re.sub(r'_nohash_.*$', '', file_name)
47 |             hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
48 |             percentage_hash = ((int(hash_name_hashed, 16) %
49 |                                 (MAX_NUM_IMAGES_PER_CLASS + 1)) *
50 |                                (100.0 / MAX_NUM_IMAGES_PER_CLASS))
51 |             if percentage_hash < testing_percentage:
52 |                 testing_images.append(cv2.imread(file_name))
53 |             else:
54 |                 training_images.append(cv2.imread(file_name))
55 |
56 |         result[counter_for_result_label] = {
57 |             'training_label': [counter_for_result_label] * len(training_images),
58 |             'testing_label': [counter_for_result_label] * len(testing_images),
59 |             'training': training_images,
60 |             'testing': testing_images,
61 |         }
62 |         counter_for_result_label += 1
63 |     return result
64 |
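65 | if __name__ == '__main__':
66 |     # Quick sanity check (a sketch): report the per-class split sizes.
67 |     result = create_image_lists()
68 |     for label, split in result.items():
69 |         print(label, len(split['training']), 'train /', len(split['testing']), 'test')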
--------------------------------------------------------------------------------
/tensorflow/loaddata_2.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 |
4 | # Dataset Parameters - CHANGE HERE
5 | MODE = 'folder' # or 'file', if you choose a plain text file (see above).
6 | dataset_path = 'G:/Works/Chrome T-rex/data' # the dataset file or root folder path.
7 | BATCH_SIZE = 48
8 | # Image Parameters
9 | N_CLASSES = 2 # CHANGE HERE, total number of classes
10 | IMG_HEIGHT = 64 # CHANGE HERE, the image height to be resized to
11 | IMG_WIDTH = 64 # CHANGE HERE, the image width to be resized to
12 | CHANNELS = 3 # The 3 color channels, change to 1 if grayscale
13 |
14 | # Reading the dataset
15 | # 2 modes: 'file' or 'folder'
16 | def read_images(mode, batch_size):
17 | global dataset_path
18 | imagepaths, labels = list(), list()
19 | if mode == 'file':
20 | # Read dataset file
21 | data = open(dataset_path, 'r').read().splitlines()
22 | for d in data:
23 | imagepaths.append(d.split(' ')[0])
24 | labels.append(int(d.split(' ')[1]))
25 | elif mode == 'folder':
26 |         # An ID will be assigned to each sub-folder in alphabetical order
27 | label = 0
28 | # List the directory
29 | try: # Python 2
30 | classes = sorted(os.walk(dataset_path).next()[1])
31 | except Exception: # Python 3
32 | classes = sorted(os.walk(dataset_path).__next__()[1])
33 | # List each sub-directory (the classes)
34 | for c in classes:
35 | c_dir = os.path.join(dataset_path, c)
36 | try: # Python 2
37 | walk = os.walk(c_dir).next()
38 | except Exception: # Python 3
39 | walk = os.walk(c_dir).__next__()
40 | # Add each image to the training set
41 | for sample in walk[2]:
42 | # Only keeps jpeg images
43 | if sample.endswith('.jpg') or sample.endswith('.jpeg'):
44 | imagepaths.append(os.path.join(c_dir, sample))
45 | labels.append(label)
46 | label += 1
47 | else:
48 | raise Exception("Unknown mode.")
49 |
50 | # Convert to Tensor
51 | imagepaths = tf.convert_to_tensor(imagepaths, dtype=tf.string)
52 | labels = tf.convert_to_tensor(labels, dtype=tf.int32)
53 | # Build a TF Queue, shuffle data
54 | image, label = tf.train.slice_input_producer([imagepaths, labels],
55 | shuffle=True)
56 |
57 | # Read images from disk
58 | image = tf.read_file(image)
59 | image = tf.image.decode_jpeg(image, channels=CHANNELS)
60 |
61 | # Resize images to a common size
62 | image = tf.image.resize_images(image, [IMG_HEIGHT, IMG_WIDTH])
63 |
64 | # Normalize
65 | image = image * 1.0/127.5 - 1.0
66 |
67 | # Create batches
68 | X, Y = tf.train.batch([image, label], batch_size=batch_size,
69 | capacity=batch_size * 8,
70 | num_threads=4)
71 |
72 | return X, Y
73 |
74 | #print(read_images('folder', 48))
75 |
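76 | if __name__ == '__main__':
77 |     # Example usage (a sketch): TF1 queue ops only yield data once the queue
78 |     # runners are started inside a session.
79 |     X, Y = read_images(MODE, BATCH_SIZE)
80 |     with tf.Session() as sess:
81 |         coord = tf.train.Coordinator()
82 |         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
83 |         imgs, lbls = sess.run([X, Y])
84 |         print(imgs.shape, lbls.shape)
85 |         coord.request_stop()
86 |         coord.join(threads)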
--------------------------------------------------------------------------------
/tensorflow/loaddata_3.py:
--------------------------------------------------------------------------------
1 | # reference : https://github.com/sjchoi86/tensorflow-101/blob/master/notebooks/basic_gendataset.ipynb
2 | import numpy as np
3 | import os
4 | from scipy.misc import imread, imresize
5 |
6 | cwd = os.getcwd()
7 | print ("Current folder is %s" % (cwd) )
8 |
9 | # Training set folder
10 | paths = ["G:/Works/Chrome T-rex/data/null", "G:/Works/Chrome T-rex/data/up"]  # a list, so class indices are deterministic: null=0, up=1
11 |
12 | # The reshape size
13 | imgsize = [340, 340]
14 |
15 | # Grayscale
16 | use_gray = 1
17 |
18 | # Save name
19 | data_name = "custom_data"
20 |
21 | def rgb2gray(rgb):
22 |     if len(rgb.shape) == 3:
23 | return np.dot(rgb[...,:3], [0.299, 0.587, 0.114])
24 | else:
25 | return rgb
26 |
27 | nclass = len(paths)
28 | valid_exts = [".jpg",".gif",".png",".tga", ".jpeg"]
29 | imgcnt = 0
30 | for i, relpath in zip(range(nclass), paths):
31 | path = relpath
32 | flist = os.listdir(path)
33 | for f in flist:
34 | if os.path.splitext(f)[1].lower() not in valid_exts:
35 | continue
36 | fullpath = os.path.join(path, f)
37 | print(fullpath)
38 | currimg = imread(fullpath)
39 | # Convert to grayscale
40 | if use_gray:
41 | grayimg = rgb2gray(currimg)
42 | else:
43 | grayimg = currimg
44 | # Reshape
45 | graysmall = imresize(grayimg, [imgsize[0], imgsize[1]])/255.
46 | grayvec = np.reshape(graysmall, (1, -1))
47 | # Save
48 | curr_label = np.eye(nclass, nclass)[i:i+1, :]
49 |     if imgcnt == 0:
50 | totalimg = grayvec
51 | totallabel = curr_label
52 | else:
53 | totalimg = np.concatenate((totalimg, grayvec), axis=0)
54 | totallabel = np.concatenate((totallabel, curr_label), axis=0)
55 | imgcnt = imgcnt + 1
56 |
57 | print ("Total %d images loaded." % (imgcnt))
58 |
59 | savepath = "G:/Works/Chrome T-rex/tensorflow/" + data_name + ".npz"
60 |
61 | np.savez(savepath, trainimg=totalimg, trainlabel=totallabel , imgsize=imgsize, use_gray=use_gray)
62 |
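63 | # Quick check (a sketch): reload the archive and confirm the saved shapes.
64 | l = np.load(savepath)
65 | print(l['trainimg'].shape, l['trainlabel'].shape)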
--------------------------------------------------------------------------------
/tensorflow/model_tf.py:
--------------------------------------------------------------------------------
1 | from datetime import timedelta
2 | import tensorflow as tf
3 | import numpy as np
4 | import time
5 |
6 | ########Load data#########
7 |
8 | loadpath = "G:/Works/Chrome T-rex/tensorflow/custom_data.npz"
9 | l = np.load(loadpath)
10 |
11 | print(l.files)  # keys stored in the archive
12 |
13 | #Parse data
14 | trainimg = l["trainimg"]
15 | trainlabel = l["trainlabel"]
16 | ntrain = trainimg.shape[0]
17 | nclass = trainlabel.shape[1]
18 | dim = trainimg.shape[1]
19 |
20 | print ("%d train images loaded" % (ntrain))
21 | print ("%d dimensional input" % (dim))
22 | print ("%d classes" % (nclass))
23 | print ("shape of 'trainimg' is %s" % (trainimg.shape,))
24 |
25 | '''
26 | trainimg_tensor = np.ndarray((ntrain, 340, 340, 1))
27 | for i in range(ntrain):
28 | currimg = trainimg[i, :]
29 | currimg = np.reshape(currimg, [340, 340, 1])
30 | trainimg_tensor[i, :, :, :] = currimg
31 |
32 | print ("shape of trainimg_tensor is %s" % (trainimg_tensor.shape,))
33 | '''
34 | ##########################
35 |
36 | ########## CNN ##########
37 |
38 | # Convolutional Layer 1
39 | filterSize1 = 5
40 | numFilters1 = 16
41 | stride1_x = 1
42 | stride1_y = 1
43 |
44 | # Convolutional Layer 2
45 | filterSize2 = 5
46 | numFilters2 = 16
47 | stride2_x = 2
48 | stride2_y = 2
49 |
50 | # Convolutional Layer 3
51 | filterSize3 = 5
52 | numFilters3 = 32
53 | stride3_x = 2
54 | stride3_y = 2
55 |
56 | # Convolutional Layer 4
57 | filterSize4 = 3
58 | numFilters4 = 64
59 | stride4_x = 2
60 | stride4_y = 2
61 |
62 | #FC 1
63 | fc_size = 128
64 |
65 | # Image dimensions
66 | img_w = 340
67 | img_h = 340
68 | img_size_flat = img_h*img_w
69 | num_channels = 1
70 |
71 | num_classes = 2
72 |
73 | def new_weights(shape):
74 | return tf.Variable(tf.truncated_normal(shape, stddev=0.05))
75 |
76 |
77 |
78 | def new_biases(length):
79 | return tf.Variable(tf.constant(0.05, shape=[length]))
80 |
81 |
82 |
83 | def new_conv_layer(input, # The previous layer.
84 | num_input_channels, # Num. channels in prev. layer.
85 | filter_size, # Width and height of each filter.
86 | num_filters, # Number of filters.
87 | stride_x,
88 | stride_y):
89 |
90 | # Shape of the filter-weights for the convolution.
91 | shape = [filter_size, filter_size, num_input_channels, num_filters]
92 |
93 | # Create new weight (filters)
94 | weights = new_weights(shape=shape)
95 |
96 | # Create new biases, one for each filter.
97 | biases = new_biases(length=num_filters)
98 |
99 | layer = tf.nn.conv2d(input=input,
100 | filter=weights,
101 | strides=[1, stride_y, stride_x, 1],
102 | padding='SAME')
103 |
104 | # A bias-value is added to each filter-channel.
105 | layer += biases
106 |
107 | # Rectified Linear Unit (ReLU).
108 | layer = tf.nn.relu(layer)
109 |
110 | return layer, weights
111 |
112 | def flatten_layer(layer):
113 | layer_shape = layer.get_shape()
114 | num_features = layer_shape[1:4].num_elements()
115 | layer_flat = tf.reshape(layer, [-1, num_features])
116 | return layer_flat, num_features
117 |
118 | def new_fc_layer(input, num_inputs, num_outputs):
119 | weights = new_weights(shape=[num_inputs, num_outputs])
120 | biases = new_biases(length=num_outputs)
121 | layer = tf.matmul(input, weights) + biases
122 | layer = tf.nn.relu(layer)
123 | return layer
124 |
125 | x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')
126 | x_image = tf.reshape(x, [-1, img_h, img_w, num_channels])
127 | y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')
128 | y_true_cls = tf.argmax(y_true, axis=1)
129 |
130 | layer_conv1, weights_conv1 = new_conv_layer(input=x_image,
131 | num_input_channels=num_channels,
132 | filter_size=filterSize1,
133 | num_filters=numFilters1,
134 | stride_x=stride1_x,
135 | stride_y=stride1_y )
136 |
137 | print(layer_conv1)
138 |
139 | layer_conv2, weights_conv2 = new_conv_layer(input=layer_conv1,
140 | num_input_channels=numFilters1,
141 | filter_size=filterSize2,
142 | num_filters=numFilters2,
143 | stride_x=stride2_x,
144 | stride_y=stride2_y )
145 |
146 | print(layer_conv2)
147 |
148 | layer_conv3, weights_conv3 = new_conv_layer(input=layer_conv2,
149 | num_input_channels=numFilters2,
150 | filter_size=filterSize3,
151 | num_filters=numFilters3,
152 | stride_x=stride3_x,
153 | stride_y=stride3_y )
154 |
155 | print(layer_conv3)
156 |
157 |
158 | layer_conv4, weights_conv4 = new_conv_layer(input=layer_conv3,
159 | num_input_channels=numFilters3,
160 | filter_size=filterSize4,
161 | num_filters=numFilters4,
162 | stride_x=stride4_x,
163 | stride_y=stride4_y )
164 |
165 | print(layer_conv4)
166 |
167 | conv_shape = tf.shape(layer_conv4)
168 | layer_flat, num_features = flatten_layer(layer_conv4)
169 |
170 | print(layer_flat, num_features)
171 |
172 |
173 | layer_fc1 = new_fc_layer(input=layer_flat,
174 | num_inputs=num_features,
175 | num_outputs=fc_size)
176 |
177 | print(layer_fc1)
178 |
179 | layer_fc2 = new_fc_layer(input=layer_fc1,
180 | num_inputs=fc_size,
181 | num_outputs=num_classes)
182 |
183 | print(layer_fc2)
184 |
185 | y_pred = tf.nn.softmax(layer_fc2)
186 | y_pred_cls = tf.argmax(y_pred, axis=1)
187 |
188 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=layer_fc2, labels=y_true)
189 | cost = tf.reduce_mean(cross_entropy)
190 | optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost)
191 |
192 | correct_prediction = tf.equal(y_pred_cls, y_true_cls)
193 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
194 |
195 | session = tf.Session()
196 | session.run(tf.global_variables_initializer())
197 |
198 | batch_size = 48
199 | total_iterations = 0
200 |
201 | save_step = 1
202 | saver = tf.train.Saver(max_to_keep=3)
203 |
204 | def optimize(num_iterations):
205 | global total_iterations
206 | start_time = time.time()
207 |
208 | for epoch in range(total_iterations, total_iterations + num_iterations):
209 | num_batch = int(ntrain/batch_size)+1
210 | for i in range(num_batch):
211 | randidx = np.random.randint(ntrain, size=batch_size)
212 | #batch_xs = train_vectorized[randidx, :]
213 | batch_xs = trainimg[randidx, :]
214 | batch_ys = trainlabel[randidx, :]
215 | session.run(optimizer, feed_dict={x: batch_xs, y_true: batch_ys})
216 | print(str(epoch) + ":" + str(i))
217 |
218 | acc = session.run(accuracy, feed_dict={x: batch_xs, y_true: batch_ys})
219 | msg = "Optimization Iteration: {0:>6}, Training Accuracy: {1:>6.1%}"
220 | print(msg.format(epoch + 1, acc))
221 | saver.save(session, './tf-model' + str(epoch))
222 |
223 | total_iterations += num_iterations
224 |
225 | end_time = time.time()
226 | time_dif = end_time - start_time
227 | print("Time usage: " + str(timedelta(seconds=int(round(time_dif)))))
228 |
229 | optimize(10)
230 |
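231 | # Restoring for inference later (a sketch): rebuild the same graph, then
232 | # reload a saved checkpoint, e.g. the one from the last epoch:
233 | #   saver.restore(session, './tf-model9')
234 | #   preds = session.run(y_pred_cls, feed_dict={x: batch_of_flat_images})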
--------------------------------------------------------------------------------