├── .gitignore
├── Dockerfile
├── HOSTS
├── README.org
├── build-mp-spdz.sh
├── common.sh
├── convert.sh
├── download.sh
├── full.py
├── full01.py
├── prepare.py
├── run-local.sh
├── run-remote.sh
├── setup-remote.sh
├── setup-ssh.sh
├── ssh_config
├── test_protocols.sh
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | 
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
 1 | FROM ubuntu:22.04
 2 | 
 3 | RUN apt-get update && apt-get -y install wget tar openssl git make cmake \
 4 |     python3 python3-pip clang libsodium-dev autoconf automake \
 5 |     libtool yasm texinfo libboost-dev libssl-dev libboost-system-dev \
 6 |     libboost-thread-dev libgmp-dev rsync ssh openssh-server procps
 7 | 
 8 | WORKDIR /root
 9 | 
10 | ADD download.sh .
11 | RUN ./download.sh
12 | 
13 | RUN pip3 install numpy
14 | 
15 | ADD prepare.py .
16 | RUN ./prepare.py
17 | 
18 | RUN git clone -b v0.3.6 https://github.com/data61/MP-SPDZ
19 | 
20 | ADD build-mp-spdz.sh .
21 | RUN ./build-mp-spdz.sh
22 | 
23 | ADD ssh_config .ssh/config
24 | ADD setup-ssh.sh .
25 | RUN ./setup-ssh.sh
26 | 
27 | ADD convert.sh *.py ./
28 | RUN ./convert.sh
29 | 
30 | ADD *.sh *.py HOSTS ./
31 | 
32 | #RUN ./test_protocols.sh
33 | #RUN ./run-local.sh emul D prob 2 3 32 adamapprox
34 | #RUN service ssh start; ./run-remote.sh sh3 A near 1 1 16 rate.1
--------------------------------------------------------------------------------
/HOSTS:
--------------------------------------------------------------------------------
1 | 127.0.0.1
2 | 127.0.0.1
3 | 127.0.0.1
4 | 127.0.0.1
--------------------------------------------------------------------------------
/README.org:
--------------------------------------------------------------------------------
  1 | #+TITLE: Deep Learning Training with Multi-Party Computation
  2 | 
  3 | A small set of scripts for training MNIST and CIFAR10 with a number of
  4 | networks using [[https://github.com/data61/MP-SPDZ/][MP-SPDZ]]. This version underlies the figures in
  5 | . For the version underlying
  6 | and
  7 | , see the =v1= branch.
  8 | 
  9 | ** Installation (with Docker)
 10 | 
 11 | 
 12 | Run the following from this directory to build a Docker container:
 13 | 
 14 | : $ docker build .
 15 | 
 16 | =Dockerfile= contains some commented-out tests at the end.
 17 | 
 18 | ** Running locally
 19 | 
 20 | After setting everything up, you can use this script to run the
 21 | computation:
 22 | 
 23 | : $ ./run-local.sh <protocol> <net> <round> <n_threads> <n_epochs> <precision> [<options>]
 24 | 
 25 | The options are as follows:
 26 | 
 27 | - =protocol= is one of =emul= (emulation), =sh2= (semi-honest
 28 |   two-party computation), =sh3= (semi-honest three-party computation),
 29 |   =mal3= (malicious three-party computation), =mal4= (malicious
 30 |   four-party computation), =dm3= (dishonest-majority semi-honest
 31 |   three-party computation), =sh10= (semi-honest ten-party
 32 |   computation), and =dm10= (dishonest-majority semi-honest ten-party
 33 |   computation). All protocols assume an honest majority unless stated
 34 |   otherwise.
 35 | - =net= is the network (A-D for MNIST, =alex= for Falcon AlexNet on
 36 |   CIFAR10, and =new_alex= for a more sophisticated AlexNet-like network on
 37 |   CIFAR10).
 38 | - =round= is the kind of rounding, =prob= for probabilistic and =near= for
 39 |   nearest.
 40 | - =n_threads= is the number of threads per party.
 41 | - =n_epochs= is the number of epochs.
 42 | - =precision= is the precision (in bits) after the decimal point.
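For reference, =common.sh= translates the protocol names into MP-SPDZ
virtual machines roughly as follows (see the =case= statement in
=common.sh= for the exact run-time options passed to each binary):

: emul -> emulate.x
: sh2  -> hemi-party.x             (2 parties)
: sh3  -> replicated-ring-party.x  (3 parties)
: dm3  -> temi-party.x             (3 parties)
: mal3 -> sy-rep-ring-party.x      (3 parties)
: mal4 -> rep4-ring-party.x        (4 parties)
: sh10 -> atlas-party.x            (10 parties)
: dm10 -> temi-party.x             (10 parties)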
 43 | 
 44 | The following options can be given in addition.
 45 | 
 46 | - =rate= sets the learning rate, e.g., =rate.1= for a learning rate of 0.1.
 47 | - =adam=, =adamapprox=, or =amsgrad= to use Adam, Adam with a less
 48 |   precise approximation of the inverse square root, or AMSGrad instead of
 49 |   SGD.
 50 | 
 51 | For example,
 52 | 
 53 | : $ ./run-local.sh emul D prob 2 20 32 adamapprox
 54 | 
 55 | runs 20 epochs of training network D in the emulation with two threads,
 56 | probabilistic rounding, 32-bit precision, and Adam with the less precise
 57 | inverse square root approximation.
 58 | 
 59 | ** Running remotely
 60 | 
 61 | You need to set up hosts that run SSH and have all higher TCP ports
 62 | open between each other. We have used =c5.9xlarge= instances in the
 63 | same AWS zone, which offer 36 threads. The hosts have to run Linux with
 64 | a glibc not older than the one in Ubuntu 18.04 (2.27), which is the case
 65 | for Amazon Linux 2. The number of hosts has to match the number of
 66 | parties in the chosen protocol (e.g., three for =sh3= and two for =sh2=).
 67 | 
 68 | With Docker, you can run the following script to set up host names,
 69 | user name, and SSH RSA key. We do *NOT* recommend running it outside
 70 | Docker because it might overwrite an existing RSA key file.
 71 | 
 72 | : $ ./setup-remote.sh
 73 | 
 74 | Without Docker, familiarise yourself with SSH configuration options
 75 | and SSH keys. You can use =ssh_config= and the above script to find
 76 | out the requirements. =HOSTS= has to contain the hostnames separated
 77 | by whitespace.
 78 | 
 79 | After setting up, you can run the following using the same options as
 80 | above:
 81 | 
 82 | : $ ./run-remote.sh <protocol> <net> <round> <n_threads> <n_epochs> <precision> [<options>]
 83 | 
 84 | For example,
 85 | 
 86 | : $ ./run-remote.sh sh3 A near 1 1 16 rate.1
 87 | 
 88 | runs one epoch of training network A with semi-honest three-party
 89 | computation, one thread, nearest rounding, 16-bit precision,
 90 | and SGD with learning rate 0.1.
 91 | 
 92 | ** Cleartext training
 93 | 
 94 | =train.py= allows running the equivalent cleartext training with TensorFlow
 95 | (using floating-point instead of fixed-point representation).
 96 | Simply run the following:
 97 | 
 98 | : $ ./train.py <network> <n_epochs> [<optimizer>] [<options>]
 99 | 
100 | =<optimizer>= is either a learning rate for SGD, or =adam= or =amsgrad=
101 | followed by a learning rate. =<options>= is any combination of =dropout=
102 | to add a Dropout layer (only with network C) and =fashion= for
103 | Fashion MNIST.
104 | 
105 | ** Data format and scripts
106 | 
107 | =full.py= processes the whole MNIST dataset, and =full01.py= processes
108 | two digits to allow logistic regression and similar methods. By
109 | default the digits are 0 and 1, but you can request any pair. The
110 | following processes the digits 4 and 9, which are harder to
111 | distinguish:
112 | 
113 | : $ ./full01.py 49
114 | 
115 | Both output to stdout in the format that the =mnist_*.mpc= programs expect
116 | in the input file for party 0, that is: training set labels, training set
117 | data (example by example), test set labels, test set data (example by
118 | example), all as whitespace-separated text. Labels are stored as
119 | one-hot vectors for the whole dataset and as 0/1 for logistic regression.
120 | 
121 | The scripts expect the unzipped MNIST dataset in the current
122 | directory. Run =download.sh= to download it.
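For instance, =convert.sh= uses =full.py= to generate the input file for
party 0 in the location where MP-SPDZ expects it:

: $ ./full.py > MP-SPDZ/Player-Data/Input-P0-0

=full.py= also accepts an optional maximum number of examples as its
first argument; =test_protocols.sh= uses =./full.py 10= to create a
small test input.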
123 | 
--------------------------------------------------------------------------------
/build-mp-spdz.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | cd MP-SPDZ || exit 1
4 | echo CXX = clang++ >> CONFIG.mine
5 | echo MY_CFLAGS += -DCHOP_MEMORY >> CONFIG.mine
6 | make -j8 setup
7 | mkdir static
8 | make -j8 {static/,}{{{replicated,sy-rep,rep4}-ring,{t,h}emi,atlas}-party,emulate}.x
--------------------------------------------------------------------------------
/common.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | if test -z "$6"; then
 4 |     echo "Usage: $0 <protocol> <net> <round> <n_threads> <n_epochs> <precision> [<options>]"
 5 |     exit 1
 6 | fi
 7 | 
 8 | protocol=$1
 9 | net=$2
10 | round=$3
11 | n_threads=$4
12 | n_epochs=$5
13 | f=$6
14 | shift 6
15 | 
16 | cd MP-SPDZ || exit 1
17 | 
18 | test -e logs || mkdir logs
19 | 
20 | case $protocol in
21 |     sh2) protocol=hemi; PLAYERS=2; run_opt="-b 100 $run_opt" ;;
22 |     sh3) protocol=ring; PLAYERS=3 ;;
23 |     dm3) protocol=temi; PLAYERS=3; run_opt="-b 100 -lgp 111 -N 3 $run_opt" ;;
24 |     mal3) protocol=sy-rep-ring; PLAYERS=3 ;;
25 |     mal4) protocol=rep4-ring; PLAYERS=4 ;;
26 |     sh10) protocol=atlas; PLAYERS=10; run_opt="-N 10 $run_opt" ;;
27 |     dm10) protocol=temi; PLAYERS=10; run_opt="-b 100 -lgp 111 -N 10 $run_opt" ;;
28 |     emul) protocol=emulate; compile_args="-K ''" ;;
29 | esac
30 | 
31 | export PLAYERS
32 | 
33 | args="$*"
34 | 
35 | if [[ $net = D ]]; then
36 |     args="2dense $args"
37 | fi
38 | 
39 | if [[ $protocol =~ ring || $protocol == emulate ]]; then
40 |     ring=1
41 |     compile_args="-R 64 $compile_args"
42 |     if [[ $protocol == rep4-ring ]]; then
43 |         args="split4 $args"
44 |     elif [[ $protocol != emulate ]]; then
45 |         args="split3 $args"
46 |     fi
47 | else
48 |     ring=0
49 |     args="edabit $args"
50 | fi
51 | 
52 | k=$[2 * f - 1]  # fixed-point parameters: f fractional bits, k total bits
53 | args="f$f k$k $args"
54 | 
55 | if [[ $round = near ]]; then
56 |     args="nearest $args"
57 | elif [[ $protocol != sy-rep-ring && $round = prob &&
58 |     ($ring = 1 || $protocol = hemi) ]]; then
59 |     args="trunc_pr $args"
60 | fi
61 | 
62 | if [[ $net = alex ]]; then
63 |     args="falcon_alex $n_epochs 128 $n_threads $args"
64 |     run_opt="-IF /tmp/cifar10-Input $run_opt"
65 | elif [[ $net = new_alex ]]; then
66 |     args="alex $n_epochs 128 $n_threads $args"
67 |     run_opt="-IF /tmp/cifar10-Input $run_opt"
68 | else
69 |     args="mnist_full_$net $n_epochs 128 $n_threads $args"  # epochs, batch size 128, threads
70 | fi
71 | 
72 | python3 ./compile.py $compile_args -b 100000 -CD $args | grep -v WARNING
73 | 
74 | touch ~/.rnd
75 | Scripts/setup-ssl.sh 10
76 | 
77 | N=$PLAYERS
78 | 
79 | for i in $(seq 0 $[N-1]); do  # print party indices and hosts (set by run-remote.sh)
80 |     echo $i
81 |     echo "${hosts[$i]}"
82 | done
83 | 
84 | args=${args% }
85 | prog=${args// /-}
86 | 
87 | bin=$protocol-party.x
88 | 
89 | if [[ $protocol = ring ]]; then
90 |     bin=replicated-ring-party.x
91 | elif [[ $protocol = emulate ]]; then
92 |     bin=emulate.x
93 | fi
--------------------------------------------------------------------------------
/convert.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | out_dir=MP-SPDZ/Player-Data
4 | test -e $out_dir || mkdir $out_dir
5 | 
6 | ./full.py > $out_dir/Input-P0-0
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | wget -nc http://yann.lecun.com/exdb/mnist/{train,t10k}-{images-idx3,labels-idx1}-ubyte.gz || exit 1
4 | 
5 | for i in *.gz; do
6 |     gunzip $i
7 | done
8 | 
9 | wget -nc https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz || exit 1
10 | 
11 | tar xzvf cifar-10-python.tar.gz
--------------------------------------------------------------------------------
/full.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/python3
 2 | 
 3 | import struct, sys
 4 | 
 5 | w = lambda x: struct.unpack('>i', x.read(4))[0]  # big-endian int32 from IDX header
 6 | b = lambda x: struct.unpack('B', x.read(1))[0]  # single label/pixel byte
 7 | 
 8 | try:
 9 |     max_n = int(sys.argv[1])
10 | except (IndexError, ValueError):
11 |     max_n = None
12 | 
13 | try:
14 |     scale = int(sys.argv[2])  # pass 0 to keep raw 0-255 pixel values
15 | except (IndexError, ValueError):
16 |     scale = True
17 | 
18 | for s in 'train', 't10k':
19 |     labels = open('%s-labels-idx1-ubyte' % s, 'rb')
20 |     images = open('%s-images-idx3-ubyte' % s, 'rb')
21 | 
22 |     assert w(labels) == 2049  # IDX magic numbers
23 |     n_labels = w(labels)
24 | 
25 |     assert w(images) == 2051
26 |     n_images = w(images)
27 |     assert n_labels == n_images
28 |     assert w(images) == 28
29 |     assert w(images) == 28
30 | 
31 |     print('%d total examples' % n_images, file=sys.stderr)
32 | 
33 |     data = []
34 |     n = [0] * 10
35 | 
36 |     for i in range(n_images if max_n is None else min(max_n, n_images)):
37 |         label = b(labels)
38 |         image = [b(images) / 256 if scale else b(images)
39 |                  for j in range(28 ** 2)]
40 |         data.append(image)
41 |         n[label] += 1
42 |         l = [0] * 10
43 |         l[label] = 1
44 |         print(' '.join(str(x) for x in l), end=' ')
45 |     print()
46 | 
47 |     print('%d used examples %s' % (len(data), n), file=sys.stderr)
48 | 
49 |     for x in data:
50 |         for y in x:
51 |             print(y, end=' ')
52 |         print()
--------------------------------------------------------------------------------
/full01.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/python3
 2 | 
 3 | import struct, sys
 4 | 
 5 | try:
 6 |     digits = [int(x) for x in sys.argv[1]]
 7 | except (IndexError, ValueError):
 8 |     digits = 0, 1
 9 | 
10 | print('digits: %s' % str(digits), file=sys.stderr)
11 | 
12 | w = lambda x: struct.unpack('>i', x.read(4))[0]  # big-endian int32 from IDX header
13 | b = lambda x: struct.unpack('B', x.read(1))[0]  # single label/pixel byte
14 | 
15 | for s in 'train', 't10k':
16 |     labels = open('%s-labels-idx1-ubyte' % s, 'rb')
17 |     images = open('%s-images-idx3-ubyte' % s, 'rb')
18 | 
19 |     assert w(labels) == 2049  # IDX magic numbers
20 |     n_labels = w(labels)
21 | 
22 |     assert w(images) == 2051
23 |     n_images = w(images)
24 |     assert n_labels == n_images
25 |     assert w(images) == 28
26 |     assert w(images) == 28
27 | 
28 |     print('%d total examples' % n_images, file=sys.stderr)
29 | 
30 |     data = []
31 |     n = 0
32 | 
33 |     for i in range(n_images):
34 |         label = b(labels)
35 |         image = [b(images) / 256 for j in range(28 ** 2)]
36 |         if label in digits:
37 |             data.append(image)
38 |             n += label == digits[1]
39 |             print(int(label == digits[1]), end=' ')
40 |     print()
41 | 
42 |     print('%d used examples (%d, %d)' % (len(data), len(data) - n, n),
43 |           file=sys.stderr)
44 | 
45 |     for x in data:
46 |         for y in x:
47 |             print(y, end=' ')
48 |         print()
--------------------------------------------------------------------------------
/prepare.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/python3
 2 | 
 3 | import sys
 4 | import numpy
 5 | 
 6 | binary = 'binary' in sys.argv
 7 | 
 8 | bout = open('/tmp/cifar10-Binary-P0-0', 'wb')
 9 | out = open('/tmp/cifar10-Input-P0-0', 'w')
10 | 
11 | if binary:
12 |     def ff(x):
13 |         bout.write(x.astype(numpy.single).tobytes())
14 | else:
15 |     def ff(x):
16 |         numpy.savetxt(out, x.reshape(x.shape[0], -1), '%.6f')
17 | 
18 | def f(x):
19 |     x = numpy.reshape(x, (x.shape[0], 3, 32, 32))
20 |     x = numpy.moveaxis(x, 1, -1)
21 |     print(x.shape)
22 |     x = (x / 255 * 2 - 1)  # scale pixels to [-1, 1]
23 |     ff(x)
24 |     print(x.max(), x.min(), x.sum())
25 | 
26 | def g(x):
27 |     for a in x:
28 |         for i in range(10):
29 |             out.write(str(int(i == a)))
30 |             out.write(' ')
31 |         out.write('\n')
32 | 
33 | def unpickle(file):
34 |     import pickle
35 |     with open(file, 'rb') as fo:
36 |         d = pickle.load(fo, encoding='bytes')
37 |     return d
38 | 
39 | labels = []
40 | data = []
41 | 
42 | for i in range(1, 6):
43 |     part = unpickle('cifar-10-batches-py/data_batch_%d' % i)
44 |     labels.extend(part[b'labels'])
45 |     data.append(part[b'data'])
46 | 
47 | g(labels)
48 | 
49 | for x in data:
50 |     f(x)
51 | 
52 | data = unpickle('cifar-10-batches-py/test_batch')
53 | g(data[b'labels'])
54 | f(data[b'data'])
--------------------------------------------------------------------------------
/run-local.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | . common.sh
4 | 
5 | Scripts/$protocol.sh $prog $run_opt
6 | 
7 | # allow debugging with docker
8 | exit 0
--------------------------------------------------------------------------------
/run-remote.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | i=0
 4 | for host in $(<HOSTS); do
 5 |     hosts[$i]=$host
 6 |     i=$[i+1]
 7 | done
 8 | 
 9 | . common.sh
10 | 
11 | # copy the virtual machine, the compiled program, the inputs and the
12 | # certificates to all hosts
13 | for j in $(seq 0 $[N-1]); do
14 |     rsync -aR \
15 |         static/$bin \
16 |         Programs/Bytecode/$prog.bc \
17 |         Programs/Schedules/$prog.sch \
18 |         Player-Data \
19 |         ${hosts[$j]}: &
20 | done 2>&1 | grep -v 'Permanently added' | grep -v 'no process found'
21 | 
22 | 
23 | wait
24 | 
25 | setup_opts="-h ${hosts[0]} -pn $[RANDOM+1024] -p"
26 | log=$(echo $prog-$(basename $bin) | sed 's/ //g')
27 | 
28 | test -e logs || mkdir logs
29 | 
30 | {
31 | for j in $(seq 0 $[N-1]); do
32 |     { while true; do echo; sleep 1; done; } |
33 |     ssh ${hosts[$j]} "
34 |         c_rehash Player-Data
35 |         echo $prefix
36 |         $prefix static/$bin $prog $run_opt $setup_opts $j
37 |     " 2>&1 | {
38 |         logfile=logs/$log-$j
39 |         echo logging to $logfile
40 |         if true || test $j -eq 0; then
41 |             date >> $logfile
42 |             tee -a $logfile
43 |             date >> $logfile
44 |         else
45 |             cat >> /dev/null
46 |         fi
47 |     } & true
48 | done
49 | } 2>&1 | grep -v 'Permanently added' | grep -v 'no process found'
50 | 
51 | # allow debugging with docker
52 | exit 0
--------------------------------------------------------------------------------
/setup-remote.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | echo "enter username (default ec2-user)"
 4 | read user
 5 | sed -i s/root/${user:=ec2-user}/ .ssh/config
 6 | 
 7 | echo paste private RSA key followed by Ctrl-d
 8 | cat > .ssh/id_rsa
 9 | chmod 600 .ssh/id_rsa
10 | 
11 | echo paste hostnames followed by Ctrl-d
12 | cat > HOSTS
--------------------------------------------------------------------------------
/setup-ssh.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | ssh-keygen -f .ssh/id_dsa -N ''
4 | cat .ssh/id_dsa.pub > .ssh/authorized_keys
5 | service ssh start
6 | ssh 127.0.0.1 true
--------------------------------------------------------------------------------
/ssh_config:
--------------------------------------------------------------------------------
1 | Host *
2 | 
3 | # change to remote username
4 | User root
5 | 
6 | StrictHostKeyChecking no
--------------------------------------------------------------------------------
/test_protocols.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | ./full.py 10 > MP-SPDZ/Player-Data/d-P0-0
4 | 
5 | opts="16 1dense rate.1 mini"
6 | 
7 | run_opt="-IF Player-Data/d -M" ./run-local.sh dm10 A prob 1 3 $opts
8 | 
9 | for i in sh2 dm3; do
10 |     run_opt="-IF Player-Data/d -M" ./run-local.sh $i A prob 1 10 $opts
11 | done
12 | 
13 | for i in emul sh3 mal4 mal3 sh10; do
14 |     run_opt="-IF Player-Data/d" ./run-local.sh $i A prob 1 10 $opts
15 | done
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/python3
  2 | 
  3 | # -*- coding: utf-8 -*-
  4 | """tensorflow/datasets
  5 | 
  6 | Automatically generated by Colaboratory.
  7 | 
  8 | Original file is located at
  9 |     https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/keras_example.ipynb
 10 | 
 11 | # Training a neural network on MNIST with Keras
 12 | 
 13 | This simple example demonstrates how to plug TFDS into a Keras model.
 14 | 
 15 | Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0
 16 | 
 17 | View on TensorFlow.org - Run in Google Colab - View source on GitHub
 28 | """
 29 | 
 30 | import tensorflow.compat.v2 as tf
 31 | import tensorflow_datasets as tfds
 32 | import sys
 33 | import re
 34 | 
 35 | tfds.disable_progress_bar()
 36 | tf.enable_v2_behavior()
 37 | 
 38 | """## Step 1: Create your input pipeline
 39 | 
 40 | Build an efficient input pipeline using advice from:
 41 | * [TFDS performance guide](https://www.tensorflow.org/datasets/performances)
 42 | * [tf.data performance guide](https://www.tensorflow.org/guide/data_performance#optimize_performance)
 43 | 
 44 | ### Load MNIST
 45 | 
 46 | Load with the following arguments:
 47 | 
 48 | * `shuffle_files`: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.
 49 | * `as_supervised`: Returns a tuple `(img, label)` instead of a dict `{'image': img, 'label': label}`
 50 | """
 51 | 
 52 | dataset = 'fashion_mnist' if 'fashion' in sys.argv else 'mnist'
 53 | 
 54 | (ds_train, ds_test), ds_info = tfds.load(
 55 |     dataset,
 56 |     split=['train', 'test'],
 57 |     shuffle_files=True,
 58 |     as_supervised=True,
 59 |     with_info=True,
 60 | )
 61 | 
 62 | """### Build training pipeline
 63 | 
 64 | Apply the following transformations:
 65 | 
 66 | * `ds.map`: TFDS provides the images as tf.uint8, while the model expects tf.float32, so normalize the images
 67 | * `ds.cache`: As the dataset fits in memory, cache it before shuffling for better performance.
68 | __Note:__ Random transformations should be applied after caching 69 | * `ds.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.
 70 |   __Note:__ For bigger datasets which do not fit in memory, a standard value is 1000 if your system allows it.
 71 | * `ds.batch`: Batch after shuffling to get unique batches at each epoch.
 72 | * `ds.prefetch`: Good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching).
 73 | """
 74 | 
 75 | def normalize_img(image, label):
 76 |     """Normalizes images: `uint8` -> `float32`."""
 77 |     return tf.cast(image, tf.float32) / 255., label
 78 | 
 79 | ds_train = ds_train.map(
 80 |     normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
 81 | ds_train = ds_train.cache()
 82 | ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
 83 | ds_train = ds_train.batch(128)
 84 | ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)
 85 | 
 86 | """### Build evaluation pipeline
 87 | 
 88 | The testing pipeline is similar to the training pipeline, with small differences:
 89 | 
 90 | * No `ds.shuffle()` call
 91 | * Caching is done after batching (as batches can be the same between epochs)
 92 | """
 93 | 
 94 | ds_test = ds_test.map(
 95 |     normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
 96 | ds_test = ds_test.batch(128)
 97 | ds_test = ds_test.cache()
 98 | ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)
 99 | 
100 | """## Step 2: Create and train the model
101 | 
102 | Plug the input pipeline into Keras.
103 | """
104 | 
105 | network = 'A'
106 | n_epochs = 10
107 | lr = 0.1
108 | adam = False
109 | amsgrad = False
110 | 
111 | if len(sys.argv) > 1:
112 |     network = sys.argv[1]
113 | if len(sys.argv) > 2:
114 |     n_epochs = int(sys.argv[2])
115 | if len(sys.argv) > 3:
116 |     adam = re.match('adam(.*)', sys.argv[3])
117 |     amsgrad = re.match('amsgrad(.*)', sys.argv[3])
118 |     if adam or amsgrad:
119 |         try:
120 |             lr = float((adam or amsgrad).group(1))
121 |         except ValueError:
122 |             lr = .001
123 |         if adam:
124 |             print('use Adam with lr', lr)
125 |         else:
126 |             print('use AMSGrad with lr', lr)
127 |     else:
128 |         lr = float(sys.argv[3])
129 | 
130 | if network == 'A':
131 |     layers = [
132 |         tf.keras.layers.Flatten(),
133 |         tf.keras.layers.Dense(128, activation='relu'),
134 |         tf.keras.layers.Dense(128, activation='relu'),
135 |         tf.keras.layers.Dense(10, activation='softmax')
136 |     ]
137 | elif network == 'D':
138 |     layers = [
139 |         tf.keras.layers.Conv2D(5, 5, 2, 'same', activation='relu'),
140 |         tf.keras.layers.Flatten(),
141 |         tf.keras.layers.Dense(100, activation='relu'),
142 |         tf.keras.layers.Dense(10, activation='softmax')
143 |     ]
144 | 
145 | elif network == 'B':
146 |     layers = [
147 |         tf.keras.layers.Conv2D(16, 5, 1, 'valid', activation='relu'),
148 |         tf.keras.layers.MaxPooling2D(2),
149 |         tf.keras.layers.Conv2D(16, 5, 1, 'valid', activation='relu'),
150 |         tf.keras.layers.MaxPooling2D(2),
151 |         tf.keras.layers.Flatten(),
152 |         tf.keras.layers.Dense(100, activation='relu'),
153 |         tf.keras.layers.Dense(10, activation='softmax')
154 |     ]
155 | elif network == 'C':
156 |     layers = [
157 |         tf.keras.layers.Conv2D(20, 5, 1, 'valid', activation='relu'),
158 |         tf.keras.layers.MaxPooling2D(2),
159 |         tf.keras.layers.Conv2D(50, 5, 1, 'valid', activation='relu'),
160 |         tf.keras.layers.MaxPooling2D(2),
161 |         tf.keras.layers.Flatten(),
162 |         tf.keras.layers.Dense(500, activation='relu'),
163 |         tf.keras.layers.Dense(10, activation='softmax')
164 |     ]
165 |     if 'dropout' in sys.argv:
166 |         layers.insert(-2, tf.keras.layers.Dropout(0.5))
167 | else:
168 |     raise Exception('unknown network: ' + network)
169 | 
170 | print(layers)
171 | model = tf.keras.models.Sequential(layers)
172 | 
173 | if adam or amsgrad:
174 |     # re.match returns a match object; Keras expects a boolean here
175 |     optim = tf.keras.optimizers.Adam(lr, amsgrad=bool(amsgrad))
176 | else:
177 |     optim = tf.keras.optimizers.SGD(momentum=0.9, learning_rate=lr)
178 | 
179 | model.compile(
180 |     # the networks end in a softmax layer, so the loss is computed from
181 |     # probabilities rather than logits
182 |     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
183 |     metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
184 |     optimizer=optim
185 | )
186 | 
187 | history = model.fit(
188 |     ds_train,
189 |     epochs=n_epochs,
190 |     validation_data=ds_test,
191 | )
192 | 
193 | out = open('log-' + '-'.join(sys.argv[1:]), 'w')
194 | for i in range(n_epochs):
195 |     out.write(str(i) + ' ')
196 |     for x in history.history:
197 |         out.write('%s: %.4f ' % (x, history.history[x][i]))
198 |     out.write('\n')
--------------------------------------------------------------------------------