├── README.md ├── counter.txt ├── data.rar ├── getdata.py ├── getkeys.py ├── keras ├── collect_sample.py └── model_keras.py ├── pytorch └── readme.md. └── tensorflow ├── load_data.py ├── loaddata_2.py ├── loaddata_3.py └── model_tf.py /README.md: -------------------------------------------------------------------------------- 1 | # chrome_Trex 2 | 3 | This program automates the Google Chrome Trex game. 4 | You can use this URL for collecting data and testing. 5 | http://wayou.github.io/t-rex-runner/ 6 | 7 | A convolutional neural network is being used to predict the keyboard input. 8 | 9 | ## getdata.py:
10 | Collects training images and stores them in data/
11 | ## getkeys.py:
12 | Contains a helper function for reading keyboard input
13 | (Source: https://github.com/Sentdex/pygta5/blob/master/Versions/v0.02/getkeys.py)
14 | 15 | Different models are implemented in keras and tensorflow. 16 | -------------------------------------------------------------------------------- /counter.txt: -------------------------------------------------------------------------------- 1 | 401 2 | 3 3 | 8024 4 | -------------------------------------------------------------------------------- /data.rar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SouravSharan/chrome_Trex/e5efc3443886f01c82c5272e8ad504fd8a501784/data.rar -------------------------------------------------------------------------------- /getdata.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import ImageGrab 3 | import cv2 4 | import time 5 | from getkeys import key_check 6 | 7 | file = open("G://Works//Chrome T-rex//counter.txt", 'r') 8 | items = file.readlines() 9 | file.close() 10 | counter = [] 11 | counter = list(map(int, items)) 12 | print(counter) 13 | 14 | path = 'G://Works//Chrome T-rex//data//' 15 | 16 | up = 38 17 | down = 40 18 | 19 | for i in list(range(4))[::-1]: 20 | print(i+1) 21 | time.sleep(1) 22 | last_time = time.time() 23 | while True: 24 | screen = np.array(ImageGrab.grab(bbox=(360,100,700,440))) 25 | screen = cv2.cvtColor(screen, cv2.COLOR_RGB2BGR) 26 | keys = key_check() 27 | 28 | if up in keys: 29 | cv2.imwrite(path + 'up/' + str(counter[0]) + ".jpg",screen) 30 | counter[0]+=1 31 | time.sleep(0.5) 32 | elif down in keys: 33 | cv2.imwrite(path + 'down/' + str(counter[1]) + ".jpg",screen) 34 | counter[1]+=1 35 | time.sleep(0.5) 36 | else: 37 | cv2.imwrite(path + 'null/' + str(counter[2]) + ".jpg",screen) 38 | counter[2]+=1 39 | 40 | if ord('E') in keys: 41 | break 42 | if cv2.waitKey(25) & 0xFF == ord('q'): 43 | cv2.destroyAllWindows() 44 | break 45 | 46 | file = open("G://Works//Chrome T-rex//counter.txt", 'w') 47 | for ch in counter: 48 | print(ch) 49 | file.write(str(ch) + 
"\n") 50 | file.close() 51 | -------------------------------------------------------------------------------- /getkeys.py: -------------------------------------------------------------------------------- 1 | # Citation: Box Of Hats (https://github.com/Box-Of-Hats ) 2 | 3 | import win32api as wapi 4 | import time 5 | 6 | keyList = [] 7 | 8 | for char in "ABCDEFGHIJKLMNOPQRSTUVWXYZ 123456789,.'£$/\\": 9 | keyList.append(ord(char)) 10 | 11 | keyList.append(13) #"0x0D") #enter 12 | keyList.append(37 ) #left_arrow 13 | keyList.append(38) #up_arrow 14 | keyList.append(39) #right_arrow 15 | keyList.append(40) #down_arrow 16 | 17 | def key_check(): 18 | keys = [] 19 | for key in keyList: 20 | if wapi.GetAsyncKeyState(int(key)): 21 | keys.append(key) 22 | 23 | return keys 24 | -------------------------------------------------------------------------------- /keras/collect_sample.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import math 4 | from keras.models import load_model 5 | import numpy as np 6 | import pyautogui 7 | import time 8 | from PIL import ImageGrab 9 | 10 | model = load_model('whole_model.h5') 11 | while True: 12 | screen = np.array(ImageGrab.grab(bbox=(360,100,700,440))) 13 | screen = cv2.cvtColor(screen, cv2.COLOR_RGB2BGR) 14 | 15 | cv2.imwrite('./screen.jpg', screen) 16 | screen = cv2.imread('./screen.jpg') 17 | screen = np.expand_dims(screen, axis=0) 18 | key = model.predict(screen, batch_size = 1, verbose = 0) 19 | k = key[0] 20 | print(k[1]) 21 | if k[1] == 1: 22 | pyautogui.press('space') 23 | -------------------------------------------------------------------------------- /keras/model_keras.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras.preprocessing.image import ImageDataGenerator 3 | from keras.models import Sequential 4 | from keras.layers import Conv2D, MaxPooling2D 5 | from keras.layers import 
Activation, Dropout, Flatten, Dense 6 | from keras import backend as K 7 | from keras import utils 8 | #1import h5py 9 | 10 | # dimensions of our images. 11 | img_width, img_height = 340, 340 12 | 13 | train_data_dir = 'G:/Works/Chrome T-rex/data' 14 | #validation_data_dir = 'data/validation' 15 | nb_train_samples = 768 16 | #nb_validation_samples = 800 17 | epochs = 30 18 | batch_size = 48 19 | 20 | if K.image_data_format() == 'channels_first': 21 | input_shape = (3, img_width, img_height) 22 | else: 23 | input_shape = (img_width, img_height, 3) 24 | 25 | model = Sequential() 26 | model.add(Conv2D(32, (3, 3), input_shape=input_shape)) 27 | model.add(Activation('relu')) 28 | model.add(MaxPooling2D(pool_size=(2, 2))) 29 | 30 | model.add(Conv2D(64, (3, 3))) 31 | model.add(Activation('relu')) 32 | model.add(MaxPooling2D(pool_size=(2, 2))) 33 | 34 | model.add(Conv2D(64, (3, 3))) 35 | model.add(Activation('relu')) 36 | model.add(MaxPooling2D(pool_size=(2, 2))) 37 | 38 | model.add(Conv2D(128, (3, 3))) 39 | model.add(Activation('relu')) 40 | model.add(MaxPooling2D(pool_size=(2, 2))) 41 | 42 | model.add(Flatten()) 43 | model.add(Dense(128)) 44 | model.add(Activation('relu')) 45 | model.add(Dropout(0.5)) 46 | model.add(Dense(2)) 47 | model.add(Activation('softmax')) 48 | 49 | model.compile(loss=keras.losses.categorical_crossentropy, 50 | optimizer=keras.optimizers.Adadelta(), 51 | metrics=['accuracy']) 52 | 53 | # this is the augmentation configuration we will use for training 54 | train_datagen = ImageDataGenerator( 55 | rescale=1. / 255, 56 | shear_range=0.2, 57 | zoom_range=0.2, 58 | horizontal_flip=True) 59 | 60 | # this is the augmentation configuration we will use for testing: 61 | # only rescaling 62 | test_datagen = ImageDataGenerator(rescale=1. 
/ 255) 63 | 64 | train_generator = train_datagen.flow_from_directory( 65 | train_data_dir, 66 | target_size=(img_width, img_height), 67 | batch_size=batch_size, 68 | class_mode='categorical') 69 | ''' 70 | validation_generator = test_datagen.flow_from_directory( 71 | validation_data_dir, 72 | target_size=(img_width, img_height), 73 | batch_size=batch_size, 74 | class_mode='categorical') 75 | ''' 76 | model.fit_generator( 77 | train_generator, 78 | steps_per_epoch=nb_train_samples // batch_size, 79 | epochs=epochs) 80 | #validation_data=validation_generator, 81 | #validation_steps=nb_validation_samples // batch_size) 82 | #model.save_weights('./model_weights.h5') 83 | model.save('./whole_model.h5') 84 | -------------------------------------------------------------------------------- /pytorch/readme.md.: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tensorflow/load_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import cv2 5 | import numpy as np 6 | from keras import backend as K 7 | from keras.utils import np_utils 8 | import argparse 9 | from datetime import datetime 10 | import hashlib 11 | import os.path 12 | import random 13 | import re 14 | import sys 15 | import tarfile 16 | from six.moves import urllib 17 | 18 | 19 | import tensorflow as tf 20 | from tensorflow.python.framework import graph_util 21 | from tensorflow.python.framework import tensor_shape 22 | from tensorflow.python.platform import gfile 23 | from tensorflow.python.util import compat 24 | 25 | FLAGS = None 26 | def create_image_lists(): 27 | image_dir='/home/rick/derma/dataset' 28 | testing_percentage=20 29 | 30 | result = {} 31 | counter_for_result_label=0 32 | 33 | sub_dirs = [x[0] for x in 
gfile.Walk(image_dir)] #create sub_dirs 34 | 35 | 36 | # The root directory comes first, so skip it. 37 | 38 | dir_name=[] 39 | 40 | #ignore first element in sub_dir 41 | is_root_dir = True 42 | for sub_dir in sub_dirs: 43 | if is_root_dir: 44 | is_root_dir = False 45 | continue 46 | 47 | 48 | dir_name = os.path.basename(sub_dir) 49 | 50 | extensions = ['jpg', 'jpeg', 'JPG', 'JPEG'] 51 | file_list = [] 52 | #dir_name = os.path.basename(image_dir) 53 | #if dir_name == image_dir: 54 | #continue 55 | tf.logging.info("Looking for images in '" + dir_name + "'") 56 | for extension in extensions: 57 | #for image_dir in sub_dir 58 | file_glob = os.path.join(image_dir, dir_name, '*.' + extension) 59 | file_list.extend(gfile.Glob(file_glob)) #create a list of all files 60 | 61 | #using regex to set label name 62 | label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower()) 63 | 64 | #dividing 65 | training_images = [] 66 | testing_images = [] 67 | for file_name in file_list: 68 | base_name = os.path.basename(file_name) 69 | hash_name = re.sub(r'_nohash_.*$', '', file_name) 70 | 71 | hash_name_hashed = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest() 72 | percentage_hash = ((int(hash_name_hashed, 16) % 73 | (MAX_NUM_IMAGES_PER_CLASS + 1)) * 74 | (100.0 / MAX_NUM_IMAGES_PER_CLASS)) 75 | if percentage_hash < testing_percentage: 76 | #testing_images.append(file_name) 77 | testing_images.append(cv2.imread(file_name)) 78 | #testing_images.append(base_name) 79 | else: 80 | #training_images.append(file_name) 81 | training_images.append(cv2.imread(file_name)) 82 | #training_images.append(base_name) 83 | 84 | 85 | result[counter_for_result_label] = { 86 | 'training_label': [counter_for_result_label]*(len(training_images)), 87 | 'testing_label': [counter_for_result_label]*(len(testing_images)), 88 | 'training': training_images, 89 | 'testing': testing_images, 90 | } 91 | 92 | counter_for_result_label=counter_for_result_label+1 93 | return result 94 | 
-------------------------------------------------------------------------------- /tensorflow/loaddata_2.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | 4 | # Dataset Parameters - CHANGE HERE 5 | MODE = 'folder' # or 'file', if you choose a plain text file (see above). 6 | dataset_path = 'G:/Works/Chrome T-rex/data' # the dataset file or root folder path. 7 | BATCH_SIZE = 48 8 | # Image Parameters 9 | N_CLASSES = 2 # CHANGE HERE, total number of classes 10 | IMG_HEIGHT = 64 # CHANGE HERE, the image height to be resized to 11 | IMG_WIDTH = 64 # CHANGE HERE, the image width to be resized to 12 | CHANNELS = 3 # The 3 color channels, change to 1 if grayscale 13 | 14 | # Reading the dataset 15 | # 2 modes: 'file' or 'folder' 16 | def read_images(mode, batch_size): 17 | global dataset_path 18 | imagepaths, labels = list(), list() 19 | if mode == 'file': 20 | # Read dataset file 21 | data = open(dataset_path, 'r').read().splitlines() 22 | for d in data: 23 | imagepaths.append(d.split(' ')[0]) 24 | labels.append(int(d.split(' ')[1])) 25 | elif mode == 'folder': 26 | # An ID will be affected to each sub-folders by alphabetical order 27 | label = 0 28 | # List the directory 29 | try: # Python 2 30 | classes = sorted(os.walk(dataset_path).next()[1]) 31 | except Exception: # Python 3 32 | classes = sorted(os.walk(dataset_path).__next__()[1]) 33 | # List each sub-directory (the classes) 34 | for c in classes: 35 | c_dir = os.path.join(dataset_path, c) 36 | try: # Python 2 37 | walk = os.walk(c_dir).next() 38 | except Exception: # Python 3 39 | walk = os.walk(c_dir).__next__() 40 | # Add each image to the training set 41 | for sample in walk[2]: 42 | # Only keeps jpeg images 43 | if sample.endswith('.jpg') or sample.endswith('.jpeg'): 44 | imagepaths.append(os.path.join(c_dir, sample)) 45 | labels.append(label) 46 | label += 1 47 | else: 48 | raise Exception("Unknown mode.") 49 | 50 | # Convert to 
Tensor 51 | imagepaths = tf.convert_to_tensor(imagepaths, dtype=tf.string) 52 | labels = tf.convert_to_tensor(labels, dtype=tf.int32) 53 | # Build a TF Queue, shuffle data 54 | image, label = tf.train.slice_input_producer([imagepaths, labels], 55 | shuffle=True) 56 | 57 | # Read images from disk 58 | image = tf.read_file(image) 59 | image = tf.image.decode_jpeg(image, channels=CHANNELS) 60 | 61 | # Resize images to a common size 62 | image = tf.image.resize_images(image, [IMG_HEIGHT, IMG_WIDTH]) 63 | 64 | # Normalize 65 | image = image * 1.0/127.5 - 1.0 66 | 67 | # Create batches 68 | X, Y = tf.train.batch([image, label], batch_size=batch_size, 69 | capacity=batch_size * 8, 70 | num_threads=4) 71 | 72 | return X, Y 73 | 74 | #print(read_images('folder', 48)) 75 | -------------------------------------------------------------------------------- /tensorflow/loaddata_3.py: -------------------------------------------------------------------------------- 1 | # reference : https://github.com/sjchoi86/tensorflow-101/blob/master/notebooks/basic_gendataset.ipynb 2 | import numpy as np 3 | import os 4 | from scipy.misc import imread, imresize 5 | 6 | cwd = os.getcwd() 7 | print ("Current folder is %s" % (cwd) ) 8 | 9 | # Training set folder 10 | paths = {"G:/Works/Chrome T-rex/data/null", "G:/Works/Chrome T-rex/data/up"} 11 | 12 | # The reshape size 13 | imgsize = [340, 340] 14 | 15 | # Grayscale 16 | use_gray = 1 17 | 18 | # Save name 19 | data_name = "custom_data" 20 | 21 | def rgb2gray(rgb): 22 | if len(rgb.shape) is 3: 23 | return np.dot(rgb[...,:3], [0.299, 0.587, 0.114]) 24 | else: 25 | return rgb 26 | 27 | nclass = len(paths) 28 | valid_exts = [".jpg",".gif",".png",".tga", ".jpeg"] 29 | imgcnt = 0 30 | for i, relpath in zip(range(nclass), paths): 31 | path = relpath 32 | flist = os.listdir(path) 33 | for f in flist: 34 | if os.path.splitext(f)[1].lower() not in valid_exts: 35 | continue 36 | fullpath = os.path.join(path, f) 37 | print(fullpath) 38 | currimg = 
imread(fullpath) 39 | # Convert to grayscale 40 | if use_gray: 41 | grayimg = rgb2gray(currimg) 42 | else: 43 | grayimg = currimg 44 | # Reshape 45 | graysmall = imresize(grayimg, [imgsize[0], imgsize[1]])/255. 46 | grayvec = np.reshape(graysmall, (1, -1)) 47 | # Save 48 | curr_label = np.eye(nclass, nclass)[i:i+1, :] 49 | if imgcnt is 0: 50 | totalimg = grayvec 51 | totallabel = curr_label 52 | else: 53 | totalimg = np.concatenate((totalimg, grayvec), axis=0) 54 | totallabel = np.concatenate((totallabel, curr_label), axis=0) 55 | imgcnt = imgcnt + 1 56 | 57 | print ("Total %d images loaded." % (imgcnt)) 58 | 59 | savepath = "G:/Works/Chrome T-rex/tensorflow/" + data_name + ".npz" 60 | 61 | np.savez(savepath, trainimg=totalimg, trainlabel=totallabel , imgsize=imgsize, use_gray=use_gray) 62 | -------------------------------------------------------------------------------- /tensorflow/model_tf.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import numpy as np 4 | import time 5 | 6 | ########Load data######### 7 | 8 | loadpath = "G:/Works/Chrome T-rex/tensorflow/custom_data.npz" 9 | l = np.load(loadpath) 10 | 11 | l.files 12 | 13 | #Parse data 14 | trainimg = l["trainimg"] 15 | trainlabel = l["trainlabel"] 16 | ntrain = trainimg.shape[0] 17 | nclass = trainlabel.shape[1] 18 | dim = trainimg.shape[1] 19 | 20 | print ("%d train images loaded" % (ntrain)) 21 | print ("%d dimensional input" % (dim)) 22 | print ("%d classes" % (nclass)) 23 | print ("shape of 'trainimg' is %s" % (trainimg.shape,)) 24 | 25 | ''' 26 | trainimg_tensor = np.ndarray((ntrain, 340, 340, 1)) 27 | for i in range(ntrain): 28 | currimg = trainimg[i, :] 29 | currimg = np.reshape(currimg, [340, 340, 1]) 30 | trainimg_tensor[i, :, :, :] = currimg 31 | 32 | print ("shape of trainimg_tensor is %s" % (trainimg_tensor.shape,)) 33 | ''' 34 | ########################## 35 | 36 | ########## CNN ########## 37 | 38 | # Convolutional Layer 1 
# Layer 1 hyper-parameters: filter size, filter count and strides.
filterSize1, numFilters1 = 5, 16
stride1_x = stride1_y = 1

# Layer 2 hyper-parameters.
filterSize2, numFilters2 = 5, 16
stride2_x = stride2_y = 2

# Layer 3 hyper-parameters.
filterSize3, numFilters3 = 5, 32
stride3_x = stride3_y = 2

# Layer 4 hyper-parameters.
filterSize4, numFilters4 = 3, 64
stride4_x = stride4_y = 2

# Width of the first fully-connected layer.
fc_size = 128

# Input image geometry; images are fed as flat vectors of img_size_flat values.
img_w = img_h = 340
img_size_flat = img_w * img_h
num_channels = 1

# Number of output classes.
num_classes = 2


def new_weights(shape):
    """Return a new TF variable of `shape`, initialised from a truncated normal (stddev 0.05)."""
    initial_values = tf.truncated_normal(shape, stddev=0.05)
    return tf.Variable(initial_values)


def new_biases(length):
    """Return a new 1-D TF variable with `length` entries, each initialised to 0.05."""
    initial_values = tf.constant(0.05, shape=[length])
    return tf.Variable(initial_values)


def new_conv_layer(input,
                   num_input_channels,
                   filter_size,
                   num_filters,
                   stride_x,
                   stride_y):
    """Build a 2-D convolution followed by a bias add and a ReLU.

    Args:
        input: Output tensor of the previous layer.
        num_input_channels: Number of channels in `input`.
        filter_size: Width and height of each (square) filter.
        num_filters: Number of filters, i.e. output channels.
        stride_x: Stride placed in the third position of the strides list.
        stride_y: Stride placed in the second position of the strides list.

    Returns:
        A tuple `(layer, weights)`: the activated output tensor and the
        filter-weight variable created for this layer.
    """
    # Filter bank layout expected by tf.nn.conv2d:
    # [filter_height, filter_width, in_channels, out_channels].
    kernel_shape = [filter_size, filter_size, num_input_channels, num_filters]

    # Variables are created weights-first, then biases (one bias per filter).
    weights = new_weights(shape=kernel_shape)
    biases = new_biases(length=num_filters)

    convolved = tf.nn.conv2d(input=input,
                             filter=weights,
                             strides=[1, stride_y, stride_x, 1],
                             padding='SAME')

    # Add the per-filter bias, then apply the rectifier.
    activated = tf.nn.relu(convolved + biases)
    return activated, weights


def flatten_layer(layer):
    """Reshape `layer` to `[-1, num_features]`.

    Returns:
        A tuple `(layer_flat, num_features)` where `num_features` is the
        number of elements in dimensions 1..3 of the static shape of `layer`.
    """
    num_features = layer.get_shape()[1:4].num_elements()
    return tf.reshape(layer, [-1, num_features]), num_features


def new_fc_layer(input, num_inputs, num_outputs):
    """Build a fully-connected layer: `relu(input @ weights + biases)`.

    The ReLU is applied unconditionally.

    Args:
        input: 2-D tensor from the previous layer.
        num_inputs: Size of the second dimension of `input`.
        num_outputs: Number of output units.

    Returns:
        The activated output tensor.
    """
    weights = new_weights(shape=[num_inputs, num_outputs])
    biases = new_biases(length=num_outputs)
    return tf.nn.relu(tf.matmul(input, weights) + biases)


# Graph inputs: flat images, reshaped to 4-D for the convolutions, and one-hot labels.
x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')
x_image = tf.reshape(x, [-1, img_h, img_w, num_channels])
y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')
y_true_cls = tf.argmax(y_true, dimension=1)

# Four stacked convolutional layers; each one is printed after it is built.
layer_conv1, weights_conv1 = new_conv_layer(
    x_image, num_channels, filterSize1, numFilters1, stride1_x, stride1_y)
print(layer_conv1)

layer_conv2, weights_conv2 = new_conv_layer(
    layer_conv1, numFilters1, filterSize2, numFilters2, stride2_x, stride2_y)
print(layer_conv2)

layer_conv3, weights_conv3 = new_conv_layer(
    layer_conv2, numFilters2, filterSize3, numFilters3, stride3_x, stride3_y)
print(layer_conv3)

layer_conv4, weights_conv4 = new_conv_layer(
    layer_conv3, numFilters3, filterSize4, numFilters4, stride4_x, stride4_y)
print(layer_conv4)

# Flatten the last convolutional output for the fully-connected head.
conv_shape = tf.shape(layer_conv4)
layer_flat, num_features = flatten_layer(layer_conv4)

print(layer_flat, num_features)

# Hidden fully-connected layer (ReLU-activated by new_fc_layer).
layer_fc1 = new_fc_layer(input=layer_flat,
                         num_inputs=num_features,
                         num_outputs=fc_size)

print(layer_fc1)

# Output layer. It is built directly instead of through new_fc_layer because
# that helper always applies ReLU, which would clip the logits at zero before
# they reach the softmax / cross-entropy below. Variables are created in the
# same order (weights, then biases) as the helper would create them.
weights_fc2 = new_weights(shape=[fc_size, num_classes])
biases_fc2 = new_biases(length=num_classes)
layer_fc2 = tf.matmul(layer_fc1, weights_fc2) + biases_fc2

print(layer_fc2)

# Class probabilities and predicted class index.
y_pred = tf.nn.softmax(layer_fc2)
y_pred_cls = tf.argmax(y_pred, dimension=1)

# Loss is computed from the raw logits, averaged over the batch.
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=layer_fc2, labels=y_true)
cost = tf.reduce_mean(cross_entropy)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(cost)

# Fraction of samples whose predicted class equals the true class.
correct_prediction = tf.equal(y_pred_cls, y_true_cls)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

session = tf.Session()
session.run(tf.global_variables_initializer())

batch_size = 48
total_iterations = 0

save_step = 1
saver = tf.train.Saver(max_to_keep=3)


def optimize(num_iterations):
    """Train the network for `num_iterations` further epochs.

    Each epoch runs `int(ntrain / batch_size) + 1` optimizer steps on
    mini-batches whose indices are drawn at random with `np.random.randint`.
    After each epoch the accuracy on the last mini-batch is printed and a
    checkpoint is written to './tf-model<epoch>'. The module-level
    `total_iterations` counter is advanced by `num_iterations` and the elapsed
    wall-clock time is printed at the end.

    Args:
        num_iterations: Number of epochs to run in this call.
    """
    global total_iterations
    # Imported here because the module's import block does not bring
    # `timedelta` into scope; without it the final print raises NameError.
    from datetime import timedelta

    start_time = time.time()

    # Loop-invariant: number of mini-batches per epoch.
    num_batch = int(ntrain / batch_size) + 1

    for epoch in range(total_iterations, total_iterations + num_iterations):
        for i in range(num_batch):
            randidx = np.random.randint(ntrain, size=batch_size)
            batch_xs = trainimg[randidx, :]
            batch_ys = trainlabel[randidx, :]
            session.run(optimizer, feed_dict={x: batch_xs, y_true: batch_ys})
            print(str(epoch) + ":" + str(i))

        # Accuracy is measured on the last mini-batch of the epoch only.
        acc = session.run(accuracy, feed_dict={x: batch_xs, y_true: batch_ys})
        msg = "Optimization Iteration: {0:>6}, Training Accuracy: {1:>6.1%}"
        print(msg.format(epoch + 1, acc))
        saver.save(session, './tf-model' + str(epoch))

    total_iterations += num_iterations

    end_time = time.time()
    time_dif = end_time - start_time
    print("Time usage: " + str(timedelta(seconds=int(round(time_dif)))))

229 | optimize(10) 230 | --------------------------------------------------------------------------------