├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── dataset └── README.md ├── evaluate.py ├── models ├── __init__.py ├── residual_block.py └── resnet.py ├── original_dataset └── README.md ├── prepare_data.py ├── saved_model └── README.md ├── split_dataset.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | /dataset/* 4 | !/dataset/README.md 5 | /original_dataset/* 6 | !/original_dataset/README.md -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 calmisential 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow2.0_ResNet 2 | A ResNet (**ResNet18, ResNet34, ResNet50, ResNet101, ResNet152**) implementation using TensorFlow 2.0. 3 | 4 | See https://github.com/calmisential/Basic_CNNs_TensorFlow2.0 for more CNNs. 5 | 6 | ## Train 7 | 1. Requirements: 8 | + Python >= 3.6 9 | + TensorFlow == 2.0.0 10 | 2. To train the ResNet on your own dataset, put the dataset under the folder **original_dataset**, with one sub-folder per class, so that the directory looks like this: 11 | ``` 12 | |——original_dataset 13 | |——class_name_0 14 | |——class_name_1 15 | |——class_name_2 16 | |——class_name_3 17 | ``` 18 | 3. Run the script **split_dataset.py** to split the raw dataset into a train set, a valid set and a test set (by default 60% / 20% / 20% of each class). 19 | 4. Change the corresponding parameters in **config.py** (for example, set `NUM_CLASSES` to the number of class folders). 20 | 5. Run **train.py** to start training. 21 | ## Evaluate 22 | Run **evaluate.py** to evaluate the model's performance on the test dataset. 23 | 24 | ## The networks I have implemented with TensorFlow 2.0: 25 | + [ResNet18, ResNet34, ResNet50, ResNet101, ResNet152](https://github.com/calmisential/TensorFlow2.0_ResNet) 26 | + [InceptionV3](https://github.com/calmisential/TensorFlow2.0_InceptionV3) 27 | 28 | 29 | ## References 30 | 1. The original paper: https://arxiv.org/abs/1512.03385 31 | 2.
The TensorFlow official tutorials: https://tensorflow.google.cn/beta/tutorials/quickstart/advanced -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # some training parameters 2 | EPOCHS = 50 3 | BATCH_SIZE = 8 4 | NUM_CLASSES = 5 5 | image_height = 224 6 | image_width = 224 7 | channels = 3 8 | save_model_dir = "saved_model/model" 9 | dataset_dir = "dataset/" 10 | train_dir = dataset_dir + "train" 11 | valid_dir = dataset_dir + "valid" 12 | test_dir = dataset_dir + "test" 13 | 14 | # choose a network 15 | # model = "resnet18" 16 | # model = "resnet34" 17 | model = "resnet50" 18 | # model = "resnet101" 19 | # model = "resnet152" 20 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | The dataset includes train set, valid set and test set. 
-------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import config 3 | from prepare_data import generate_datasets 4 | from train import get_model 5 | 6 | if __name__ == '__main__': 7 | 8 | # GPU settings 9 | gpus = tf.config.experimental.list_physical_devices('GPU') 10 | if gpus: 11 | for gpu in gpus: 12 | tf.config.experimental.set_memory_growth(gpu, True) 13 | 14 | # get the original_dataset 15 | train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count = generate_datasets() 16 | # print(train_dataset) 17 | # load the model 18 | model = get_model() 19 | model.load_weights(filepath=config.save_model_dir) 20 | 21 | # Get the accuracy on the test set 22 | loss_object = tf.keras.metrics.SparseCategoricalCrossentropy() 23 | test_loss = tf.keras.metrics.Mean() 24 | test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy() 25 | 26 | @tf.function 27 | def test_step(images, labels): 28 | predictions = model(images, training=False) 29 | t_loss = loss_object(labels, predictions) 30 | 31 | test_loss(t_loss) 32 | test_accuracy(labels, predictions) 33 | 34 | for test_images, test_labels in test_dataset: 35 | test_step(test_images, test_labels) 36 | print("loss: {:.5f}, test accuracy: {:.5f}".format(test_loss.result(), 37 | test_accuracy.result())) 38 | 39 | print("The accuracy on test set is: {:.3f}%".format(test_accuracy.result()*100)) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/calmiLovesAI/TensorFlow2.0_ResNet/6136d57c7580d24c9ee3e55ca0b820822565ce17/models/__init__.py -------------------------------------------------------------------------------- /models/residual_block.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class BasicBlock(tf.keras.layers.Layer): 5 | 6 | def __init__(self, filter_num, stride=1): 7 | super(BasicBlock, self).__init__() 8 | self.conv1 = tf.keras.layers.Conv2D(filters=filter_num, 9 | kernel_size=(3, 3), 10 | strides=stride, 11 | padding="same") 12 | self.bn1 = tf.keras.layers.BatchNormalization() 13 | self.conv2 = tf.keras.layers.Conv2D(filters=filter_num, 14 | kernel_size=(3, 3), 15 | strides=1, 16 | padding="same") 17 | self.bn2 = tf.keras.layers.BatchNormalization() 18 | if stride != 1: 19 | self.downsample = tf.keras.Sequential() 20 | self.downsample.add(tf.keras.layers.Conv2D(filters=filter_num, 21 | kernel_size=(1, 1), 22 | strides=stride)) 23 | self.downsample.add(tf.keras.layers.BatchNormalization()) 24 | else: 25 | self.downsample = lambda x: x 26 | 27 | def call(self, inputs, training=None, **kwargs): 28 | residual = self.downsample(inputs) 29 | 30 | x = self.conv1(inputs) 31 | x = self.bn1(x, training=training) 32 | x = tf.nn.relu(x) 33 | x = self.conv2(x) 34 | x = self.bn2(x, training=training) 35 | 36 | output = tf.nn.relu(tf.keras.layers.add([residual, x])) 37 | 38 | return output 39 | 40 | 41 | class BottleNeck(tf.keras.layers.Layer): 42 | def __init__(self, filter_num, stride=1): 43 | super(BottleNeck, self).__init__() 44 | self.conv1 = tf.keras.layers.Conv2D(filters=filter_num, 45 | kernel_size=(1, 1), 46 | strides=1, 47 | padding='same') 48 | self.bn1 = tf.keras.layers.BatchNormalization() 49 | self.conv2 = tf.keras.layers.Conv2D(filters=filter_num, 50 | kernel_size=(3, 3), 51 | strides=stride, 52 | padding='same') 53 | self.bn2 = tf.keras.layers.BatchNormalization() 54 | self.conv3 = tf.keras.layers.Conv2D(filters=filter_num * 4, 55 | kernel_size=(1, 1), 56 | strides=1, 57 | padding='same') 58 | self.bn3 = tf.keras.layers.BatchNormalization() 59 | 60 | self.downsample = tf.keras.Sequential() 61 | 
self.downsample.add(tf.keras.layers.Conv2D(filters=filter_num * 4, 62 | kernel_size=(1, 1), 63 | strides=stride)) 64 | self.downsample.add(tf.keras.layers.BatchNormalization()) 65 | 66 | def call(self, inputs, training=None, **kwargs): 67 | residual = self.downsample(inputs) 68 | 69 | x = self.conv1(inputs) 70 | x = self.bn1(x, training=training) 71 | x = tf.nn.relu(x) 72 | x = self.conv2(x) 73 | x = self.bn2(x, training=training) 74 | x = tf.nn.relu(x) 75 | x = self.conv3(x) 76 | x = self.bn3(x, training=training) 77 | 78 | output = tf.nn.relu(tf.keras.layers.add([residual, x])) 79 | 80 | return output 81 | 82 | 83 | def make_basic_block_layer(filter_num, blocks, stride=1): 84 | res_block = tf.keras.Sequential() 85 | res_block.add(BasicBlock(filter_num, stride=stride)) 86 | 87 | for _ in range(1, blocks): 88 | res_block.add(BasicBlock(filter_num, stride=1)) 89 | 90 | return res_block 91 | 92 | 93 | def make_bottleneck_layer(filter_num, blocks, stride=1): 94 | res_block = tf.keras.Sequential() 95 | res_block.add(BottleNeck(filter_num, stride=stride)) 96 | 97 | for _ in range(1, blocks): 98 | res_block.add(BottleNeck(filter_num, stride=1)) 99 | 100 | return res_block 101 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from config import NUM_CLASSES 3 | from models.residual_block import make_basic_block_layer, make_bottleneck_layer 4 | 5 | 6 | class ResNetTypeI(tf.keras.Model): 7 | def __init__(self, layer_params): 8 | super(ResNetTypeI, self).__init__() 9 | 10 | self.conv1 = tf.keras.layers.Conv2D(filters=64, 11 | kernel_size=(7, 7), 12 | strides=2, 13 | padding="same") 14 | self.bn1 = tf.keras.layers.BatchNormalization() 15 | self.pool1 = tf.keras.layers.MaxPool2D(pool_size=(3, 3), 16 | strides=2, 17 | padding="same") 18 | 19 | self.layer1 = make_basic_block_layer(filter_num=64, 20 | 
blocks=layer_params[0]) 21 | self.layer2 = make_basic_block_layer(filter_num=128, 22 | blocks=layer_params[1], 23 | stride=2) 24 | self.layer3 = make_basic_block_layer(filter_num=256, 25 | blocks=layer_params[2], 26 | stride=2) 27 | self.layer4 = make_basic_block_layer(filter_num=512, 28 | blocks=layer_params[3], 29 | stride=2) 30 | 31 | self.avgpool = tf.keras.layers.GlobalAveragePooling2D() 32 | self.fc = tf.keras.layers.Dense(units=NUM_CLASSES, activation=tf.keras.activations.softmax) 33 | 34 | def call(self, inputs, training=None, mask=None): 35 | x = self.conv1(inputs) 36 | x = self.bn1(x, training=training) 37 | x = tf.nn.relu(x) 38 | x = self.pool1(x) 39 | x = self.layer1(x, training=training) 40 | x = self.layer2(x, training=training) 41 | x = self.layer3(x, training=training) 42 | x = self.layer4(x, training=training) 43 | x = self.avgpool(x) 44 | output = self.fc(x) 45 | 46 | return output 47 | 48 | 49 | class ResNetTypeII(tf.keras.Model): 50 | def __init__(self, layer_params): 51 | super(ResNetTypeII, self).__init__() 52 | self.conv1 = tf.keras.layers.Conv2D(filters=64, 53 | kernel_size=(7, 7), 54 | strides=2, 55 | padding="same") 56 | self.bn1 = tf.keras.layers.BatchNormalization() 57 | self.pool1 = tf.keras.layers.MaxPool2D(pool_size=(3, 3), 58 | strides=2, 59 | padding="same") 60 | 61 | self.layer1 = make_bottleneck_layer(filter_num=64, 62 | blocks=layer_params[0]) 63 | self.layer2 = make_bottleneck_layer(filter_num=128, 64 | blocks=layer_params[1], 65 | stride=2) 66 | self.layer3 = make_bottleneck_layer(filter_num=256, 67 | blocks=layer_params[2], 68 | stride=2) 69 | self.layer4 = make_bottleneck_layer(filter_num=512, 70 | blocks=layer_params[3], 71 | stride=2) 72 | 73 | self.avgpool = tf.keras.layers.GlobalAveragePooling2D() 74 | self.fc = tf.keras.layers.Dense(units=NUM_CLASSES, activation=tf.keras.activations.softmax) 75 | 76 | def call(self, inputs, training=None, mask=None): 77 | x = self.conv1(inputs) 78 | x = self.bn1(x, training=training) 79 
| x = tf.nn.relu(x) 80 | x = self.pool1(x) 81 | x = self.layer1(x, training=training) 82 | x = self.layer2(x, training=training) 83 | x = self.layer3(x, training=training) 84 | x = self.layer4(x, training=training) 85 | x = self.avgpool(x) 86 | output = self.fc(x) 87 | 88 | return output 89 | 90 | 91 | def resnet_18(): 92 | return ResNetTypeI(layer_params=[2, 2, 2, 2]) 93 | 94 | 95 | def resnet_34(): 96 | return ResNetTypeI(layer_params=[3, 4, 6, 3]) 97 | 98 | 99 | def resnet_50(): 100 | return ResNetTypeII(layer_params=[3, 4, 6, 3]) 101 | 102 | 103 | def resnet_101(): 104 | return ResNetTypeII(layer_params=[3, 4, 23, 3]) 105 | 106 | 107 | def resnet_152(): 108 | return ResNetTypeII(layer_params=[3, 8, 36, 3]) 109 | -------------------------------------------------------------------------------- /original_dataset/README.md: -------------------------------------------------------------------------------- 1 | Please put your pictures for classification here.
2 | The folder name is the type of the pictures which belong to the folder. -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import config 3 | import pathlib 4 | from config import image_height, image_width, channels 5 | 6 | 7 | def load_and_preprocess_image(img_path): 8 | # read pictures 9 | img_raw = tf.io.read_file(img_path) 10 | # decode pictures 11 | img_tensor = tf.image.decode_jpeg(img_raw, channels=channels) 12 | # resize 13 | img_tensor = tf.image.resize(img_tensor, [image_height, image_width]) 14 | img_tensor = tf.cast(img_tensor, tf.float32) 15 | # normalization 16 | img = img_tensor / 255.0 17 | return img 18 | 19 | def get_images_and_labels(data_root_dir): 20 | # get all images' paths (format: string) 21 | data_root = pathlib.Path(data_root_dir) 22 | all_image_path = [str(path) for path in list(data_root.glob('*/*'))] 23 | # get labels' names 24 | label_names = sorted(item.name for item in data_root.glob('*/')) 25 | # dict: {label : index} 26 | label_to_index = dict((label, index) for index, label in enumerate(label_names)) 27 | # get all images' labels 28 | all_image_label = [label_to_index[pathlib.Path(single_image_path).parent.name] for single_image_path in all_image_path] 29 | 30 | return all_image_path, all_image_label 31 | 32 | 33 | def get_dataset(dataset_root_dir): 34 | all_image_path, all_image_label = get_images_and_labels(data_root_dir=dataset_root_dir) 35 | # print("image_path: {}".format(all_image_path[:])) 36 | # print("image_label: {}".format(all_image_label[:])) 37 | # load the dataset and preprocess images 38 | image_dataset = tf.data.Dataset.from_tensor_slices(all_image_path).map(load_and_preprocess_image) 39 | label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label) 40 | dataset = tf.data.Dataset.zip((image_dataset, label_dataset)) 41 | image_count = 
len(all_image_path) 42 | 43 | return dataset, image_count 44 | 45 | 46 | def generate_datasets(): 47 | train_dataset, train_count = get_dataset(dataset_root_dir=config.train_dir) 48 | valid_dataset, valid_count = get_dataset(dataset_root_dir=config.valid_dir) 49 | test_dataset, test_count = get_dataset(dataset_root_dir=config.test_dir) 50 | 51 | 52 | # read the original_dataset in the form of batch 53 | train_dataset = train_dataset.shuffle(buffer_size=train_count).batch(batch_size=config.BATCH_SIZE) 54 | valid_dataset = valid_dataset.batch(batch_size=config.BATCH_SIZE) 55 | test_dataset = test_dataset.batch(batch_size=config.BATCH_SIZE) 56 | 57 | return train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count 58 | -------------------------------------------------------------------------------- /saved_model/README.md: -------------------------------------------------------------------------------- 1 | The trained model will be saved here. -------------------------------------------------------------------------------- /split_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | 5 | 6 | class SplitDataset(): 7 | def __init__(self, dataset_dir, saved_dataset_dir, train_ratio=0.6, test_ratio=0.2, show_progress=False): 8 | self.dataset_dir = dataset_dir 9 | self.saved_dataset_dir = saved_dataset_dir 10 | self.saved_train_dir = saved_dataset_dir + "/train/" 11 | self.saved_valid_dir = saved_dataset_dir + "/valid/" 12 | self.saved_test_dir = saved_dataset_dir + "/test/" 13 | 14 | 15 | self.train_ratio = train_ratio 16 | self.test_radio = test_ratio 17 | self.valid_ratio = 1 - train_ratio - test_ratio 18 | 19 | self.train_file_path = [] 20 | self.valid_file_path = [] 21 | self.test_file_path = [] 22 | 23 | self.index_label_dict = {} 24 | 25 | self.show_progress = show_progress 26 | 27 | if not os.path.exists(self.saved_train_dir): 28 | 
os.mkdir(self.saved_train_dir) 29 | if not os.path.exists(self.saved_test_dir): 30 | os.mkdir(self.saved_test_dir) 31 | if not os.path.exists(self.saved_valid_dir): 32 | os.mkdir(self.saved_valid_dir) 33 | 34 | 35 | def __get_label_names(self): 36 | label_names = [] 37 | for item in os.listdir(self.dataset_dir): 38 | item_path = os.path.join(self.dataset_dir, item) 39 | if os.path.isdir(item_path): 40 | label_names.append(item) 41 | return label_names 42 | 43 | def __get_all_file_path(self): 44 | all_file_path = [] 45 | index = 0 46 | for file_type in self.__get_label_names(): 47 | self.index_label_dict[index] = file_type 48 | index += 1 49 | type_file_path = os.path.join(self.dataset_dir, file_type) 50 | file_path = [] 51 | for file in os.listdir(type_file_path): 52 | single_file_path = os.path.join(type_file_path, file) 53 | file_path.append(single_file_path) 54 | all_file_path.append(file_path) 55 | return all_file_path 56 | 57 | def __copy_files(self, type_path, type_saved_dir): 58 | for item in type_path: 59 | src_path_list = item[1] 60 | dst_path = type_saved_dir + "%s/" % (item[0]) 61 | if not os.path.exists(dst_path): 62 | os.mkdir(dst_path) 63 | for src_path in src_path_list: 64 | shutil.copy(src_path, dst_path) 65 | if self.show_progress: 66 | print("Copying file "+src_path+" to "+dst_path) 67 | 68 | def __split_dataset(self): 69 | all_file_paths = self.__get_all_file_path() 70 | for index in range(len(all_file_paths)): 71 | file_path_list = all_file_paths[index] 72 | file_path_list_length = len(file_path_list) 73 | random.shuffle(file_path_list) 74 | 75 | train_num = int(file_path_list_length * self.train_ratio) 76 | test_num = int(file_path_list_length * self.test_radio) 77 | 78 | self.train_file_path.append([self.index_label_dict[index], file_path_list[: train_num]]) 79 | self.test_file_path.append([self.index_label_dict[index], file_path_list[train_num:train_num + test_num]]) 80 | self.valid_file_path.append([self.index_label_dict[index], 
file_path_list[train_num + test_num:]]) 81 | 82 | def start_splitting(self): 83 | self.__split_dataset() 84 | self.__copy_files(type_path=self.train_file_path, type_saved_dir=self.saved_train_dir) 85 | self.__copy_files(type_path=self.valid_file_path, type_saved_dir=self.saved_valid_dir) 86 | self.__copy_files(type_path=self.test_file_path, type_saved_dir=self.saved_test_dir) 87 | 88 | 89 | if __name__ == '__main__': 90 | split_dataset = SplitDataset(dataset_dir="original_dataset", 91 | saved_dataset_dir="dataset", 92 | show_progress=True) 93 | split_dataset.start_splitting() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import tensorflow as tf 3 | from models.resnet import resnet_18, resnet_34, resnet_50, resnet_101, resnet_152 4 | import config 5 | from prepare_data import generate_datasets 6 | import math 7 | 8 | 9 | def get_model(): 10 | model = resnet_50() 11 | if config.model == "resnet18": 12 | model = resnet_18() 13 | if config.model == "resnet34": 14 | model = resnet_34() 15 | if config.model == "resnet101": 16 | model = resnet_101() 17 | if config.model == "resnet152": 18 | model = resnet_152() 19 | model.build(input_shape=(None, config.image_height, config.image_width, config.channels)) 20 | model.summary() 21 | return model 22 | 23 | 24 | if __name__ == '__main__': 25 | # GPU settings 26 | gpus = tf.config.experimental.list_physical_devices('GPU') 27 | if gpus: 28 | for gpu in gpus: 29 | tf.config.experimental.set_memory_growth(gpu, True) 30 | 31 | 32 | # get the original_dataset 33 | train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count = generate_datasets() 34 | 35 | 36 | # create model 37 | model = get_model() 38 | 39 | # define loss and optimizer 40 | loss_object = tf.keras.losses.SparseCategoricalCrossentropy() 41 | optimizer 
= tf.keras.optimizers.Adadelta() 42 | 43 | train_loss = tf.keras.metrics.Mean(name='train_loss') 44 | train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') 45 | 46 | valid_loss = tf.keras.metrics.Mean(name='valid_loss') 47 | valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy') 48 | 49 | @tf.function 50 | def train_step(images, labels): 51 | with tf.GradientTape() as tape: 52 | predictions = model(images, training=True) 53 | loss = loss_object(y_true=labels, y_pred=predictions) 54 | gradients = tape.gradient(loss, model.trainable_variables) 55 | optimizer.apply_gradients(grads_and_vars=zip(gradients, model.trainable_variables)) 56 | 57 | train_loss(loss) 58 | train_accuracy(labels, predictions) 59 | 60 | @tf.function 61 | def valid_step(images, labels): 62 | predictions = model(images, training=False) 63 | v_loss = loss_object(labels, predictions) 64 | 65 | valid_loss(v_loss) 66 | valid_accuracy(labels, predictions) 67 | 68 | # start training 69 | for epoch in range(config.EPOCHS): 70 | train_loss.reset_states() 71 | train_accuracy.reset_states() 72 | valid_loss.reset_states() 73 | valid_accuracy.reset_states() 74 | step = 0 75 | for images, labels in train_dataset: 76 | step += 1 77 | train_step(images, labels) 78 | print("Epoch: {}/{}, step: {}/{}, loss: {:.5f}, accuracy: {:.5f}".format(epoch + 1, 79 | config.EPOCHS, 80 | step, 81 | math.ceil(train_count / config.BATCH_SIZE), 82 | train_loss.result(), 83 | train_accuracy.result())) 84 | 85 | for valid_images, valid_labels in valid_dataset: 86 | valid_step(valid_images, valid_labels) 87 | 88 | print("Epoch: {}/{}, train loss: {:.5f}, train accuracy: {:.5f}, " 89 | "valid loss: {:.5f}, valid accuracy: {:.5f}".format(epoch + 1, 90 | config.EPOCHS, 91 | train_loss.result(), 92 | train_accuracy.result(), 93 | valid_loss.result(), 94 | valid_accuracy.result())) 95 | 96 | model.save_weights(filepath=config.save_model_dir, save_format='tf') 97 | 
--------------------------------------------------------------------------------