├── src ├── test.ipynb ├── 加噪scale=0.005结果.xls ├── 加噪scale=0.006结果.xls ├── 加噪scale=0.008结果.xls ├── 加噪scale=0.015结果.xls ├── 加噪scale=0.01结果.xls ├── 加噪scale=0.02结果.xls ├── 加噪scale=0.03结果.xls ├── 加噪scale=0.05结果.xls ├── 加噪scale=0.1结果.xls ├── 加噪scale=0.0005结果.xls ├── Dataset.py ├── Server.py ├── Model.py ├── Client.py └── VGG.py ├── requirements.txt ├── .idea ├── .gitignore ├── vcs.xml ├── misc.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml └── tf-fed-demo.iml ├── data └── 未加噪结果.xls ├── result ├── 未加噪结果.xls ├── 未加噪结果d=0.6.xls ├── 未加噪结果d=0.7.xls └── 未加噪结果改变每个epoch的训练次数.xls ├── README.md ├── .gitignore └── LICENSE /src/test.ipynb: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==1.14.0 2 | tqdm 3 | xlwt 4 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml -------------------------------------------------------------------------------- /data/未加噪结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/data/未加噪结果.xls -------------------------------------------------------------------------------- /result/未加噪结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/result/未加噪结果.xls -------------------------------------------------------------------------------- /result/未加噪结果d=0.6.xls: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/result/未加噪结果d=0.6.xls 
-------------------------------------------------------------------------------- /result/未加噪结果d=0.7.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/result/未加噪结果d=0.7.xls -------------------------------------------------------------------------------- /src/加噪scale=0.005结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/src/加噪scale=0.005结果.xls -------------------------------------------------------------------------------- /src/加噪scale=0.006结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/src/加噪scale=0.006结果.xls -------------------------------------------------------------------------------- /src/加噪scale=0.008结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/src/加噪scale=0.008结果.xls -------------------------------------------------------------------------------- /src/加噪scale=0.015结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/src/加噪scale=0.015结果.xls -------------------------------------------------------------------------------- /src/加噪scale=0.01结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/src/加噪scale=0.01结果.xls -------------------------------------------------------------------------------- /src/加噪scale=0.02结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/src/加噪scale=0.02结果.xls -------------------------------------------------------------------------------- 
/src/加噪scale=0.03结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/src/加噪scale=0.03结果.xls -------------------------------------------------------------------------------- /src/加噪scale=0.05结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/src/加噪scale=0.05结果.xls -------------------------------------------------------------------------------- /src/加噪scale=0.1结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/src/加噪scale=0.1结果.xls -------------------------------------------------------------------------------- /src/加噪scale=0.0005结果.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/src/加噪scale=0.0005结果.xls -------------------------------------------------------------------------------- /result/未加噪结果改变每个epoch的训练次数.xls: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blyspyder/tf-fed-demo/HEAD/result/未加噪结果改变每个epoch的训练次数.xls -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- 
/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tf-fed-demo 2 | A federated learning demo that trains AlexNet on the CIFAR-10 dataset, based on TensorFlow. 3 | 4 | ## Dependencies 5 | 1. Python 3.7 6 | 2. TensorFlow v1.14.x 7 | 3. tqdm, xlwt 8 | 9 | ## Usage 10 | ```bash 11 | cd ./src 12 | python Server.py 13 | ``` 14 | 15 | ## Blog 16 | My CSDN Blog: [https://blog.csdn.net/Mr_Zing/article/details/101938334](https://blog.csdn.net/Mr_Zing/article/details/101938334) -------------------------------------------------------------------------------- /.idea/tf-fed-demo.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /src/Dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.keras.utils import to_categorical 3 | 4 | 5 | class BatchGenerator: 6 | def __init__(self, x, yy): 7 | self.x = x 8 | self.y = yy 9 | self.size = len(x) 10 | self.random_order = list(range(len(x))) 11 | np.random.shuffle(self.random_order) 12 | self.start = 0 13 | return 14 | 15 | def next_batch(self, batch_size): 16 | if self.start + batch_size >= len(self.random_order): 17 | overflow = (self.start + batch_size) - len(self.random_order) 18 | perm0 = self.random_order[self.start:] +\ 19 | self.random_order[:overflow] 20 | self.start = overflow 21 | else: 22 | perm0 = self.random_order[self.start:self.start + batch_size] 23 | 
self.start += batch_size 24 | 25 | assert len(perm0) == batch_size 26 | 27 | return self.x[perm0], self.y[perm0] 28 | 29 | # support slice 30 | def __getitem__(self, val): 31 | return self.x[val], self.y[val] 32 | 33 | 34 | class Dataset(object): 35 | def __init__(self, load_data_func, one_hot=True, split=0): 36 | (x_train, y_train), (x_test, y_test) = load_data_func() 37 | print("Dataset: train-%d, test-%d" % (len(x_train), len(x_test))) 38 | 39 | if one_hot: 40 | y_train = to_categorical(y_train, 10) 41 | y_test = to_categorical(y_test, 10) 42 | 43 | x_train = x_train.astype('float32') 44 | x_test = x_test.astype('float32') 45 | x_train /= 255 46 | x_test /= 255 47 | 48 | if split == 0: 49 | self.train = BatchGenerator(x_train, y_train) 50 | else: 51 | self.train = self.splited_batch(x_train, y_train, split) 52 | 53 | self.test = BatchGenerator(x_test, y_test) 54 | 55 | def splited_batch(self, x_data, y_data, count): 56 | res = [] 57 | l = len(x_data) 58 | for i in range(0, l, l//count): 59 | res.append( 60 | BatchGenerator(x_data[i:i + l // count], 61 | y_data[i:i + l // count])) 62 | return res 63 | -------------------------------------------------------------------------------- /src/Server.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tqdm import tqdm 3 | from Client import Clients 4 | import xlwt 5 | import numpy as np 6 | 7 | 8 | def gaussian_noise(input,std,client_id,sheet,j): 9 | length = len(input) 10 | sum=0 11 | for i in range(length): 12 | source = input[i].copy() 13 | #noise= tf.random_normal(shape=tf.shape(input[i]),mean=0.0,stddev=std,dtype=tf.float32) 14 | noise = np.random.normal(loc=0.0,scale=std,size=input[i].shape) 15 | input[i] += noise 16 | dist = np.linalg.norm(source-input[i]) 17 | sum+=dist 18 | average = sum/length 19 | sheet.write(j,client_id,average) 20 | return input 21 | 22 | def buildClients(num): 23 | learning_rate = 0.0005#0.0002 24 | num_input = 32 # image 
shape: 32*32 25 | num_input_channel = 3 # image channel: 3 26 | num_classes = 10 # Cifar-10 total classes (0-9 digits) 27 | 28 | #返回一定数量的clients 29 | return Clients(input_shape=[None, num_input, num_input, num_input_channel], 30 | num_classes=num_classes, 31 | learning_rate=learning_rate, 32 | clients_num=num) 33 | 34 | def run_global_test(client, global_vars, test_num, i, save=False,sheet=None): 35 | #测试输出acc和loss 36 | client.set_global_vars(global_vars) 37 | acc, loss = client.run_test(test_num,save) 38 | sheet.write(i,0,float(acc)) 39 | sheet.write(i,1,float(loss)) 40 | print("[epoch {}, {} inst] Testing ACC: {:.4f}, Loss: {:.4f}".format( 41 | ep + 1, test_num, acc, loss)) 42 | 43 | scales = [0.0005,0.05,0.2] 44 | for scale in scales: 45 | CLIENT_NUMBER = 4 #客户端数量 46 | '''可尝试更高比例的客户端''' 47 | CLIENT_RATIO_PER_ROUND = 0.5 #每轮挑选的clients的比例 48 | epoch = 260 49 | 50 | #### CREATE CLIENT AND LOAD DATASET #### 51 | client = buildClients(CLIENT_NUMBER) 52 | 53 | workbook = xlwt.Workbook() 54 | sheet=workbook.add_sheet('0.0002') 55 | #### BEGIN TRAINING #### 56 | sheet2 = workbook.add_sheet('欧式距离') 57 | 58 | global_vars = client.get_client_vars() 59 | for ep in range(epoch): 60 | #收集client端的参数 61 | client_vars_sum = None 62 | 63 | # 随机挑选client训练 64 | random_clients = client.choose_clients(CLIENT_RATIO_PER_ROUND) 65 | 66 | # tqdm显示进度条 67 | for client_id in tqdm(random_clients, ascii=True): 68 | #将sever端模型加载到tqdm上 69 | client.set_global_vars(global_vars) 70 | 71 | # 训练这个下表的client 72 | client.train_epoch(cid=client_id) 73 | 74 | # 获取当前client的变量值 75 | current_client_vars_norm = client.get_client_vars() 76 | 77 | #获得参数后如高斯白噪声 78 | current_client_vars=gaussian_noise(current_client_vars_norm,scale,client_id,sheet2,ep) 79 | 80 | # 叠加各层参数 81 | if client_vars_sum is None: 82 | client_vars_sum = current_client_vars 83 | else: 84 | for cv, ccv in zip(client_vars_sum, current_client_vars): 85 | cv += ccv 86 | 87 | # obtain the avg vars as global vars 88 | global_vars = [] 89 | for 
var in client_vars_sum: 90 | global_vars.append(var / len(random_clients)) 91 | 92 | # 测试集进行测试 93 | run_global_test(client, global_vars, test_num=600,i=ep,sheet=sheet)#将结果写入到excel中 94 | workbook.save('加噪scale={}结果.xls'.format(scale)) 95 | #### FINAL TEST #### 96 | #run_global_test(client, global_vars, test_num=10000) 97 | -------------------------------------------------------------------------------- /src/Model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def AlexNet(input_shape, num_classes, learning_rate, graph): 5 | with graph.as_default(): 6 | X = tf.placeholder(tf.float32, input_shape, name='X') 7 | Y = tf.placeholder(tf.float32, [None, num_classes], name='Y') 8 | DROP_RATE = tf.placeholder(tf.float32, name='drop_rate') 9 | 10 | #定义核函数 11 | conv1_filter = tf.Variable(tf.truncated_normal(shape=[3, 3, 3, 64], mean=0, stddev=0.08)) 12 | conv2_filter = tf.Variable(tf.truncated_normal(shape=[3, 3, 64, 128], mean=0, stddev=0.08)) 13 | conv3_filter = tf.Variable(tf.truncated_normal(shape=[5, 5, 128, 256], mean=0, stddev=0.08)) 14 | conv4_filter = tf.Variable(tf.truncated_normal(shape=[5, 5, 256, 512], mean=0, stddev=0.08)) 15 | 16 | conv1 = tf.nn.conv2d(X, conv1_filter, strides=[1,1,1,1], padding='SAME') 17 | conv1 = tf.nn.relu(conv1) 18 | conv1_pool = tf.nn.max_pool(conv1, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') 19 | conv1_bn = tf.layers.batch_normalization(conv1_pool) 20 | 21 | conv2 = tf.nn.conv2d(conv1_bn, conv2_filter, strides=[1,1,1,1], padding='SAME') 22 | conv2 = tf.nn.relu(conv2) 23 | conv2_pool = tf.nn.max_pool(conv2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') 24 | conv2_bn = tf.layers.batch_normalization(conv2_pool) 25 | 26 | conv3 = tf.nn.conv2d(conv2_bn, conv3_filter, strides=[1,1,1,1], padding='SAME') 27 | conv3 = tf.nn.relu(conv3) 28 | conv3_pool = tf.nn.max_pool(conv3, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') 29 | 
conv3_bn = tf.layers.batch_normalization(conv3_pool) 30 | 31 | conv4 = tf.nn.conv2d(conv3_bn, conv4_filter, strides=[1,1,1,1], padding='SAME') 32 | conv4 = tf.nn.relu(conv4) 33 | conv4_pool = tf.nn.max_pool(conv4, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME') 34 | conv4_bn = tf.layers.batch_normalization(conv4_pool) 35 | 36 | flat = tf.contrib.layers.flatten(conv4_bn) 37 | 38 | full1 = tf.contrib.layers.fully_connected(inputs=flat, num_outputs=128, activation_fn=tf.nn.relu) 39 | full1 = tf.nn.dropout(full1, keep_prob=0.7) 40 | full1 = tf.layers.batch_normalization(full1) 41 | 42 | full2 = tf.contrib.layers.fully_connected(inputs=full1, num_outputs=256, activation_fn=tf.nn.relu) 43 | full2 = tf.nn.dropout(full2, keep_prob=0.7) 44 | full2 = tf.layers.batch_normalization(full2) 45 | 46 | full3 = tf.contrib.layers.fully_connected(inputs=full2, num_outputs=512, activation_fn=tf.nn.relu) 47 | full3 = tf.nn.dropout(full3, keep_prob=0.7) 48 | full3 = tf.layers.batch_normalization(full3) 49 | 50 | full4 = tf.contrib.layers.fully_connected(inputs=full3, num_outputs=1024, activation_fn=tf.nn.relu) 51 | full4 = tf.nn.dropout(full4, keep_prob=0.7) 52 | full4 = tf.layers.batch_normalization(full4) 53 | 54 | logits = tf.contrib.layers.fully_connected(inputs=full4, num_outputs=10, activation_fn=None) 55 | 56 | loss_op = tf.reduce_mean( 57 | tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, 58 | labels=Y)) 59 | 60 | optimizer = tf.train.AdamOptimizer( 61 | learning_rate=learning_rate) 62 | train_op = optimizer.minimize(loss_op) 63 | 64 | #评估模型 65 | prediction = tf.nn.softmax(logits) 66 | pred = tf.argmax(prediction, 1) 67 | 68 | #m模型准确率 69 | correct_pred = tf.equal(pred, tf.argmax(Y, 1)) 70 | accuracy = tf.reduce_mean( 71 | tf.cast(correct_pred, tf.float32)) 72 | 73 | return X, Y, DROP_RATE, train_op, loss_op, accuracy 74 | 75 | 76 | -------------------------------------------------------------------------------- /src/Client.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from collections import namedtuple 4 | import math 5 | from Model import AlexNet 6 | from Dataset import Dataset 7 | import random 8 | from VGG import vgg_net 9 | 10 | # FedModel 定义包含属性 x,y,drop_rate,train_op,loss_op,acc_op等属性 11 | FedModel = namedtuple('FedModel', 'X Y DROP_RATE train_op loss_op acc_op') 12 | 13 | #联邦模型客户端类 14 | class Clients: 15 | def __init__(self, input_shape, num_classes, learning_rate, clients_num): 16 | self.graph = tf.Graph() 17 | self.sess = tf.Session(graph=self.graph) 18 | 19 | # 创建alxnet网络 20 | net = AlexNet(input_shape, num_classes, learning_rate, self.graph) 21 | #net = vgg_net(input_shape, num_classes, learning_rate, self.graph) 22 | self.model = FedModel(*net) 23 | 24 | # 初始化 25 | with self.graph.as_default(): 26 | self.sess.run(tf.global_variables_initializer()) 27 | 28 | # 装载数据 29 | # 根据训练客户端数量划分数据集 30 | self.dataset = Dataset(tf.keras.datasets.cifar10.load_data,split=clients_num) 31 | 32 | #self.dataset = Dataset(tf.keras.datasets.mnist.load_data,split=clients_num) 33 | 34 | 35 | #测试模型准确率 36 | def run_test(self, num, save=False): 37 | with self.graph.as_default(): 38 | batch_x, batch_y = self.dataset.test.next_batch(num) 39 | #替代计算图中的x,y等数据 40 | feed_dict = { 41 | self.model.X: batch_x, 42 | self.model.Y: batch_y, 43 | self.model.DROP_RATE: 0 44 | } 45 | return self.sess.run([self.model.acc_op, self.model.loss_op], 46 | feed_dict=feed_dict) 47 | 48 | def train_epoch(self, cid, batch_size=256, dropout_rate=0.7): 49 | dataset = self.dataset.train[cid] 50 | 51 | with self.graph.as_default(): 52 | for _ in range(math.ceil(dataset.size // batch_size)): 53 | #for _ in range(1): 54 | batch_x, batch_y = dataset.next_batch(batch_size) 55 | batch_x = data_augmentation(batch_x,batch_y) #做数据增强处理 56 | 57 | feed_dict = { 58 | self.model.X: batch_x, 59 | self.model.Y: batch_y, 60 | self.model.DROP_RATE: 
dropout_rate 61 | } 62 | self.sess.run(self.model.train_op, feed_dict=feed_dict) 63 | 64 | #返回计算图中所有可训练的变量值 65 | def get_client_vars(self): 66 | """ Return all of the variables list """ 67 | with self.graph.as_default(): 68 | client_vars = self.sess.run(tf.trainable_variables()) 69 | return client_vars 70 | 71 | def set_global_vars(self, global_vars): 72 | with self.graph.as_default(): 73 | all_vars = tf.trainable_variables()#获取所有可训练变量 74 | for variable, value in zip(all_vars, global_vars): 75 | variable.load(value, self.sess)#加载server端发送的var到模型上 76 | 77 | #随机返回ratio比例的客户端并返回编号 78 | def choose_clients(self, ratio=1.0): 79 | client_num = self.get_clients_num() 80 | choose_num = math.floor(client_num * ratio) 81 | return np.random.permutation(client_num)[:choose_num] 82 | 83 | def get_clients_num(self): 84 | #返回客户端的数量 85 | return len(self.dataset.train) 86 | 87 | #数据增强 88 | def _random_crop(batch, crop_shape, padding=None): 89 | oshape = np.shape(batch[0]) 90 | if padding: 91 | oshape = (oshape[0] + 2*padding, oshape[1] + 2*padding) 92 | new_batch = [] 93 | npad = ((padding, padding), (padding, padding), (0, 0)) 94 | for i in range(len(batch)): 95 | new_batch.append(batch[i]) 96 | if padding: 97 | new_batch[i] = np.lib.pad(batch[i], pad_width=npad, 98 | mode='constant', constant_values=0) 99 | nh = random.randint(0, oshape[0] - crop_shape[0]) 100 | nw = random.randint(0, oshape[1] - crop_shape[1]) 101 | new_batch[i] = new_batch[i][nh:nh + crop_shape[0], 102 | nw:nw + crop_shape[1]] 103 | return new_batch 104 | 105 | def _random_flip_leftright(batch,batch_y): 106 | for i in range(len(batch)): 107 | ''' 108 | filpped_le_re=tf.image.random_flip_left_right(batch_x[i]) #随机左右翻转 109 | print(type(filpped_le_re)) 110 | np.concatenate(batch_x,filpped_le_re) 111 | batch_x.append(filpped_le_re) 112 | batch_y.append(batch_y[i]) 113 | filpped_up_down=tf.image.random_flip_up_down(batch_x[i]) #随机上下翻转 114 | batch_x.append(filpped_up_down) 115 | batch_y.append(batch_y[i]) 116 | 117 | 
# 随机设置图片的对比度 118 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 119 | batch_x.append(image) 120 | batch_y.append(batch_y[i]) 121 | 122 | # 随机设置图片的色度 123 | image2 = tf.image.random_hue(image, max_delta=0.3) 124 | batch_x.append(image2) 125 | batch_y.append(batch_y[i]) 126 | 127 | adjust=tf.image.random_brightness(filpped_up_down,0.4) 128 | batch_x.append(adjust) 129 | batch_y.append(batch_y[i]) 130 | 131 | ''' 132 | if bool(random.getrandbits(1)): 133 | batch[i] = np.fliplr(batch[i]) 134 | 135 | return batch 136 | 137 | def data_augmentation(batch_x,batch_y): 138 | batch= _random_flip_leftright(batch_x,batch_y) 139 | batch = _random_crop(batch, [32, 32], 4) 140 | return batch 141 | -------------------------------------------------------------------------------- /src/VGG.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import time 4 | import random 5 | import pickle 6 | import math 7 | import datetime 8 | from keras.preprocessing.image import ImageDataGenerator 9 | 10 | #预先定义的变量 11 | class_num = 10 12 | image_size = 32 13 | img_channels = 3 14 | iterations = 200 15 | batch_size = 250 16 | weight_decay = 0.0003 17 | dropout_rate = 0.5 18 | momentum_rate = 0.9 19 | 20 | 21 | #初始化权重,采用正则化随机初始,加入少量的噪声来打破对称性以及避免0梯度 22 | def weight_variable(name, sp): 23 | initial = tf.initializers.he_normal() 24 | return tf.get_variable(name = name, shape = sp, initializer = initial) 25 | 26 | 27 | def bias_variable(shape): 28 | initial = tf.constant(0.1, shape=shape) 29 | return tf.Variable(initial) 30 | 31 | def batch_norm(input): 32 | return tf.contrib.layers.batch_norm(input, decay=0.9, center=True, scale=True, epsilon=1e-3, 33 | updates_collections=None) 34 | def conv(name,x,w,b): 35 | #去掉BN 36 | #return tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME'),b),name=name) 37 | return 
tf.nn.relu(batch_norm(tf.nn.bias_add(tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME'),b)),name=name) 38 | 39 | def max_pool(name,x,k): 40 | return tf.nn.max_pool(x,ksize=[1,k,k,1],strides=[1,k,k,1],padding='SAME',name=name) 41 | 42 | def fc(name,x,w,b): 43 | return tf.nn.relu(batch_norm(tf.matmul(x,w)+b),name=name) 44 | 45 | weights={ 46 | 'wc1_1' : weight_variable('wc1_1', [3,3,3,64]), 47 | 'wc1_2' : weight_variable('wc1_2', [3,3,64,64]), 48 | 'wc2_1' : weight_variable('wc2_1', [3,3,64,128]), 49 | 'wc2_2' : weight_variable('wc2_2', [3,3,128,128]), 50 | 'wc3_1' : weight_variable('wc3_1', [3,3,128,256]), 51 | 'wc3_2' : weight_variable('wc3_2', [3,3,256,256]), 52 | 'wc3_3' : weight_variable('wc3_3', [3,3,256,256]), 53 | 'wc4_1' : weight_variable('wc4_1', [3,3,256,512]), 54 | 'wc4_2' : weight_variable('wc4_2', [3,3,512,512]), 55 | 'wc4_3' : weight_variable('wc4_3', [3,3,512,512]), 56 | 'wc5_1' : weight_variable('wc5_1', [3,3,512,512]), 57 | 'wc5_2' : weight_variable('wc5_2', [3,3,512,512]), 58 | 'wc5_3' : weight_variable('wc5_3', [3,3,512,512]), 59 | 'fc1' : weight_variable('fc1', [2*2*512,4096]), 60 | 'fc2' : weight_variable('fc2', [4096,4096]), 61 | 'fc3' : weight_variable('fc3', [4096,10]) 62 | } 63 | 64 | biases={ 65 | 'bc1_1' : bias_variable([64]), 66 | 'bc1_2' : bias_variable([64]), 67 | 'bc2_1' : bias_variable([128]), 68 | 'bc2_2' : bias_variable([128]), 69 | 'bc3_1' : bias_variable([256]), 70 | 'bc3_2' : bias_variable([256]), 71 | 'bc3_3' : bias_variable([256]), 72 | 'bc4_1' : bias_variable([512]), 73 | 'bc4_2' : bias_variable([512]), 74 | 'bc4_3' : bias_variable([512]), 75 | 'bc5_1' : bias_variable([512]), 76 | 'bc5_2' : bias_variable([512]), 77 | 'bc5_3' : bias_variable([512]), 78 | 'fb1' : bias_variable([4096]), 79 | 'fb2' : bias_variable([4096]), 80 | 'fb3' : bias_variable([10]), 81 | } 82 | 83 | #VGG-16网络,因为输入尺寸小,去掉最后两个个max pooling层 84 | def vgg_net(input_shape,num_classes,learning_rate,graph): 85 | with graph.as_default(): 86 | x = 
tf.placeholder(tf.float32,input_shape,name='X') 87 | y_ = tf.placeholder(tf.float32, [None, num_classes],name='Y') 88 | DROP_RATE = tf.placeholder(tf.float32, name='drop_rate') 89 | 90 | conv1_1=conv('conv1_1',x,weights['wc1_1'],biases['bc1_1']) 91 | conv1_2=conv('conv1_2',conv1_1,weights['wc1_2'],biases['bc1_2']) 92 | pool1=max_pool('pool1',conv1_2,k=2) 93 | 94 | conv2_1=conv('conv2_1',pool1,weights['wc2_1'],biases['bc2_1']) 95 | conv2_2=conv('conv2_2',conv2_1,weights['wc2_2'],biases['bc2_2']) 96 | pool2=max_pool('pool2',conv2_2,k=2) 97 | 98 | conv3_1=conv('conv3_1',pool2,weights['wc3_1'],biases['bc3_1']) 99 | conv3_2=conv('conv3_2',conv3_1,weights['wc3_2'],biases['bc3_2']) 100 | conv3_3=conv('conv3_3',conv3_2,weights['wc3_3'],biases['bc3_3']) 101 | pool3=max_pool('pool3',conv3_3,k=2) 102 | 103 | conv4_1=conv('conv4_1',pool3,weights['wc4_1'],biases['bc4_1']) 104 | conv4_2=conv('conv4_2',conv4_1,weights['wc4_2'],biases['bc4_2']) 105 | conv4_3=conv('conv4_3',conv4_2,weights['wc4_3'],biases['bc4_3']) 106 | pool4=max_pool('pool4',conv4_3,k=2) 107 | 108 | conv5_1=conv('conv5_1',pool4,weights['wc5_1'],biases['bc5_1']) 109 | conv5_2=conv('conv5_2',conv5_1,weights['wc5_2'],biases['bc5_2']) 110 | conv5_3=conv('conv5_3',conv5_2,weights['wc5_3'],biases['bc5_3']) 111 | pool5=max_pool('pool5',conv5_3,k=1) 112 | 113 | _shape=pool5.get_shape() 114 | flatten=_shape[1].value*_shape[2].value*_shape[3].value 115 | pool5=tf.reshape(pool5,shape=[-1,flatten]) 116 | fc1=fc('fc1',pool5,weights['fc1'],biases['fb1']) 117 | fc1=tf.nn.dropout(fc1,DROP_RATE) 118 | 119 | fc2=fc('fc2',fc1,weights['fc2'],biases['fb2']) 120 | fc2=tf.nn.dropout(fc2,DROP_RATE) 121 | 122 | output=fc('fc3',fc2,weights['fc3'],biases['fb3']) 123 | 124 | loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=output)) 125 | 126 | optimizer = tf.train.AdamOptimizer( 127 | learning_rate=learning_rate) 128 | train_op = optimizer.minimize(loss_op) 129 | 130 | prediction = tf.nn.softmax(output) 
131 | pred = tf.argmax(prediction,1) 132 | 133 | correct_pred = tf.equal(pred, tf.argmax(y_, 1)) 134 | accuracy = tf.reduce_mean( 135 | tf.cast(correct_pred, tf.float32)) 136 | 137 | return x,y_,DROP_RATE,train_op,loss_op,accuracy 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------