├── .gitignore ├── README.md ├── dataset ├── __init__.py ├── __init__.pyc ├── mnist.py ├── mnist.pyc ├── svhn.py ├── svhn.pyc ├── usps.py ├── usps.pyc ├── util.py └── util.pyc ├── model.py ├── model.pyc ├── nns ├── __init__.py ├── __init__.pyc ├── large.py ├── large.pyc ├── small.py └── small.pyc ├── osda_train.py ├── preprocessing ├── __init__.py ├── __init__.pyc ├── preprocessing.py └── preprocessing.pyc ├── results ├── ladv.png ├── test_accuracy.png └── um.png └── utils ├── __init__.py ├── __init__.pyc ├── metrics.py ├── metrics.pyc ├── utils.py └── utils.pyc /.gitignore: -------------------------------------------------------------------------------- 1 | /log/* 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Open_set_domain_adaptation 2 | Last update: Mar 3, 2019. The code is now complete (UNK accuracy is still low; a fix is in progress). 3 | 4 | Unofficial TensorFlow implementation of "Open Set Domain Adaptation by Backpropagation". 5 | 6 | Evaluated on SVHN->MNIST, MNIST->USPS, and USPS->MNIST. 7 | 8 | ## Usage: 9 | 10 | python osda_train.py 11 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/dataset/__init__.pyc -------------------------------------------------------------------------------- /dataset/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | import sys 4 | import util 5 | from urlparse import urljoin 6 | import gzip 7 | import struct 8 | import operator 9 | import numpy as np 10 | #from preprocessing import preprocessing 11 | def get_one_hot(targets, nb_classes): 12 | return np.eye(nb_classes)[np.array(targets).reshape(-1)] 13 | class MNIST: 14 | base_url = 'http://yann.lecun.com/exdb/mnist/' 15 | 16 | data_files = { 17 | 'train_images': 'train-images-idx3-ubyte.gz', 18 | 'train_labels': 'train-labels-idx1-ubyte.gz', 19 | 'test_images': 't10k-images-idx3-ubyte.gz', 20 | 'test_labels': 't10k-labels-idx1-ubyte.gz', 21 | } 22 | def __init__(self,path='mnist',shuffle=True,output_size=[28,28],output_channel=1,frange=[-1.,1.0],known_class=[0,1,2,3,4],unknown_class=[5,6,7,8,9],unk=True,split='train',select=[]): 23 | self.image_shape=[28,28,1] 24 | self.label_shape=() 25 | self.path=path 26 | self.shuffle=shuffle 27 | self.output_size=output_size 28 | self.output_channel=output_channel 29 | self.frange=frange 30 | self.split=split 31 | self.select=select 32 | #--------------key for open set domain adaptation---------------------------------- 33 | self.known_class=known_class 34 | self.unknown_class=unknown_class 35 | self.unk=unk 36 | #---------------------------------------------------------------------------------- 37 | self.download() 38 | self.pointer=0 39 | self.load_dataset() 40 | def download(self): 41 | data_dir = self.path 42 | if not os.path.exists(data_dir): 43 | os.mkdir(data_dir) 44 | for filename in self.data_files.values():
45 | path = self.path+'/'+filename 46 | if not os.path.exists(path): 47 | url = urljoin(self.base_url, filename) 48 | util.maybe_download(url, path) 49 | def _read_datafile(self, path, expected_dims): 50 | base_magic_num = 2048 51 | with gzip.GzipFile(path) as f: 52 | magic_num = struct.unpack('>I', f.read(4))[0] 53 | expected_magic_num = base_magic_num + expected_dims 54 | if magic_num != expected_magic_num: 55 | raise ValueError('Incorrect MNIST magic number (expected ' 56 | '{}, got {})' 57 | .format(expected_magic_num, magic_num)) 58 | dims = struct.unpack('>' + 'I' * expected_dims, 59 | f.read(4 * expected_dims)) 60 | buf = f.read(reduce(operator.mul, dims)) 61 | data = np.frombuffer(buf, dtype=np.uint8) 62 | data = data.reshape(*dims) 63 | return data 64 | def shuffle_data(self): 65 | images = self.images[:] 66 | labels = self.labels[:] 67 | self.images = [] 68 | self.labels = [] 69 | 70 | idx = np.random.permutation(len(labels)) 71 | for i in idx: 72 | self.images.append(images[i]) 73 | self.labels.append(labels[i]) 74 | def load_dataset(self): 75 | abspaths = {name: self.path+'/'+path 76 | for name, path in self.data_files.items()} 77 | if self.split=='train': 78 | self.images = self._read_images(abspaths['train_images']) 79 | self.labels = self._read_labels(abspaths['train_labels']) 80 | elif self.split=='test': 81 | self.images = self._read_images(abspaths['test_images']) 82 | self.labels = self._read_labels(abspaths['test_labels']) 83 | if len(self.select)!=0: 84 | self.images=self.images[self.select] 85 | self.labels=self.labels[self.select] 86 | if len(self.known_class)!=0: 87 | images=[] 88 | labels=[] 89 | known_indices=[i for i in xrange(len(self.images)) if self.labels[i] in self.known_class] 90 | known_images=self.images[known_indices] 91 | known_labels=self.labels[known_indices] 92 | if self.unk==True: 93 | labels=[] 94 | for i in xrange(len(self.labels)): 95 | if self.labels[i] in self.unknown_class: 96 | labels.append(self.unknown_class[0]) 97 | else: 98 | labels.append(self.labels[i]) 99 | self.labels=labels 100 | else: 101 | self.images=known_images 102 | self.labels=known_labels 103 | if self.frange==[-1.,1.]: 104 | self.images=self.images*2.0-1.0 105 | 106 | def reset_pointer(self): 107 | self.pointer=0 108 | if self.shuffle: 109 | self.shuffle_data() 110 | 111 | def class_next_batch(self,num_per_class): 112 | batch_size=10*num_per_class 113 | classpaths=[] 114 | ids=[] 115 | for i in xrange(10): 116 | classpaths.append([]) 117 | for j in xrange(len(self.labels)): 118 | label=self.labels[j] 119 | classpaths[label].append(j) 120 | for i in xrange(10): 121 | ids+=np.random.choice(classpaths[i],size=num_per_class,replace=False).tolist() 122 | selfimages=np.array(self.images) 123 | selflabels=np.array(self.labels) 124 | return np.array(selfimages[ids]),get_one_hot(selflabels[ids],10) 125 | 126 | def next_batch(self,batch_size): 127 | if self.pointer+batch_size>=len(self.labels): 128 | self.reset_pointer() 129 | images=self.images[self.pointer:(self.pointer+batch_size)] 130 | labels=self.labels[self.pointer:(self.pointer+batch_size)] 131 | self.pointer+=batch_size 132 | return np.array(images),get_one_hot(labels,6) 133 | def _read_images(self, path): 134 | return (self._read_datafile(path, 3) 135 | .astype(np.float32) 136 | .reshape(-1, 28, 28, 1) 137 | /255.0) 138 | 139 | def _read_labels(self, path): 140 | return self._read_datafile(path, 1) 141 | 142 | def main(): 143 | mnist=MNIST(path='data/mnist') 144 | a,b=mnist.next_batch(10) 145 | print a[0] 146 | #print b 147 
| 148 | if __name__=='__main__': 149 | main() 150 | -------------------------------------------------------------------------------- /dataset/mnist.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/dataset/mnist.pyc -------------------------------------------------------------------------------- /dataset/svhn.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | import sys 4 | import util 5 | from urlparse import urljoin 6 | import gzip 7 | import struct 8 | import operator 9 | import numpy as np 10 | from scipy.io import loadmat 11 | def get_one_hot(targets, nb_classes): 12 | return np.eye(nb_classes)[np.array(targets).reshape(-1)] 13 | class SVHN: 14 | base_url = 'http://ufldl.stanford.edu/housenumbers/' 15 | 16 | data_files = { 17 | 'train': 'train_32x32.mat', 18 | 'test': 'test_32x32.mat', 19 | #'extra': 'extra_32x32.mat', 20 | } 21 | 22 | def __init__(self,path=None,frange=[-1.,1.],select=[],shuffle=True,output_size=[28,28],output_channel=1,known_class=[0,1,2,3,4],unknown_class=[5,6,7,8,9],unk=True,split='train'): 23 | self.image_shape=[32,32,3] 24 | self.label_shape=() 25 | self.path=path 26 | self.shuffle=shuffle 27 | self.output_size=output_size 28 | self.output_channel=output_channel 29 | self.split=split 30 | self.frange=frange 31 | self.select=select 32 | self.download() 33 | self.pointer=0 34 | #-----------key for open set domain adaptation data preprocessing----------------------------- 35 | self.unk=unk 36 | self.known_class=known_class 37 | self.unknown_class=unknown_class 38 | #---------------------------------------------------------------------------------------------- 39 | self.load_dataset() 40 | self.classpaths=[] 41 | self.class_pointer=10*[0] 42 | self.unk=unk 43 | for i in xrange(10): 44 | self.classpaths.append([]) 45 | for j in xrange(len(self.labels)): 46 | label=self.labels[j] 47 | self.classpaths[label].append(j) 48 | def download(self): 49 | data_dir = self.path 50 | if not os.path.exists(data_dir): 51 | os.mkdir(data_dir) 52 | for filename in self.data_files.values(): 53 | path = self.path+'/'+filename 54 | if not os.path.exists(path): 55 | url = urljoin(self.base_url, filename) 56 | util.maybe_download(url, path) 57 | def shuffle_data(self): 58 | images = self.images[:] 59 | labels = self.labels[:] 60 | self.images = [] 61 | self.labels = [] 62 | 63 | idx = np.random.permutation(len(labels)) 64 | for i in idx: 65 | self.images.append(images[i]) 66 | self.labels.append(labels[i]) 67 | def load_dataset(self): 68 | abspaths = {name: self.path+'/'+path 69 | for name, path in self.data_files.items()} 70 | 71 | if self.split=='train': 72 | train_mat = loadmat(abspaths['train']) 73 | train_images = train_mat['X'].transpose((3, 0, 1, 2)) 74 | train_labels = train_mat['y'].squeeze() 75 | train_labels[train_labels == 10] = 0 76 | train_images = train_images.astype(np.float32)/255. 77 | self.images = train_images 78 | self.labels = train_labels 79 | elif self.split=='test': 80 | test_mat = loadmat(abspaths['test']) 81 | test_images = test_mat['X'].transpose((3, 0, 1, 2)) 82 | test_images = test_images.astype(np.float32)/255. 
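        # As in the train branch above, SVHN's .mat files encode the digit
        # zero as class 10, so the test labels read below are remapped to 0.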
83 | test_labels = test_mat['y'].squeeze() 84 | test_labels[test_labels == 10] = 0 85 | self.images=test_images 86 | self.labels=test_labels 87 | if len(self.select)!=0: 88 | self.images=self.images[self.select] 89 | self.labels=self.labels[self.select] 90 | if len(self.known_class)!=0: 91 | images=[] 92 | labels=[] 93 | known_indices=[i for i in xrange(len(self.images)) if self.labels[i] in self.known_class] 94 | known_images=self.images[known_indices] 95 | known_labels=self.labels[known_indices] 96 | if self.unk==True: 97 | labels=[] 98 | for i in xrange(len(self.labels)): 99 | if self.labels[i] in self.unknown_class: 100 | self.labels[i]=self.unknown_class[0] 101 | else: 102 | self.images=known_images 103 | self.labels=known_labels 104 | 105 | if self.frange==[-1.,1.]: 106 | self.images=self.images*2.0-1.0 107 | def reset_pointer(self): 108 | self.pointer=0 109 | if self.shuffle: 110 | self.shuffle_data() 111 | def reset_class_pointer(self,i): 112 | self.class_pointer[i]=0 113 | if self.shuffle: 114 | self.classpaths[i]=np.random.permutation(self.classpaths[i]) 115 | 116 | def class_next_batch(self,num_per_class): 117 | batch_size=10*num_per_class 118 | selfimages=np.zeros((0,32,32,3)) 119 | selflabels=[] 120 | for i in xrange(10): 121 | selfimages=np.concatenate((selfimages,self.images[self.classpaths[i][self.class_pointer[i]:self.class_pointer[i]+num_per_class]]),0) 122 | selflabels+=self.labels[self.classpaths[i][self.class_pointer[i]:self.class_pointer[i]+num_per_class]] 123 | self.class_pointer[i]+=num_per_class 124 | if self.class_pointer[i]+num_per_class>=len(self.classpaths[i]): 125 | self.reset_class_pointer(i) 126 | return np.array(selfimages),get_one_hot(selflabels,10) 127 | 128 | def next_batch(self,batch_size): 129 | images=self.images[self.pointer:(self.pointer+batch_size)] 130 | labels=self.labels[self.pointer:(self.pointer+batch_size)] 131 | self.pointer+=batch_size 132 | if self.pointer+batch_size>=len(self.labels): 133 | self.reset_pointer() 134 | return np.array(images),get_one_hot(labels,6) 135 | 136 | def main(): 137 | svhn=SVHN(path='data/svhn',unk=True) 138 | a,b=svhn.next_batch(10) 139 | print a[0] 140 | print b 141 | 142 | if __name__=='__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /dataset/svhn.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/dataset/svhn.pyc -------------------------------------------------------------------------------- /dataset/usps.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | import sys 4 | import util 5 | from urlparse import urljoin 6 | import gzip 7 | import struct 8 | import operator 9 | import numpy as np 10 | from scipy.io import loadmat 11 | def get_one_hot(targets, nb_classes): 12 | return np.eye(nb_classes)[np.array(targets).reshape(-1)] 13 | class USPS: 14 | base_url = 'http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/' 15 | data_files = { 16 | 'train': 'zip.train.gz', 17 | 'test': 'zip.test.gz' 18 | } 19 | 20 | def __init__(self,path=None,shuffle=True,frange=[-1.,1.],known_class=[0,1,2,3,4],unknown_class=[5,6,7,8,9],unk=True,output_size=[28,28],output_channel=1,split='train',select=[]): 21 | self.image_shape=[16,16,1] 22 | self.label_shape=() 23 | self.path=path 24 | self.shuffle=shuffle 25 | self.output_size=output_size 26 | 
self.output_channel=output_channel 27 | self.frange=frange 28 | #-----------key for open set domain adaptation data preprocessing----------------------------- 29 | self.unk=unk 30 | self.known_class=known_class 31 | self.unknown_class=unknown_class 32 | #---------------------------------------------------------------------------------------------- 33 | 34 | 35 | self.split=split 36 | self.select=select 37 | self.download() 38 | self.pointer=0 39 | self.load_dataset(self.select) 40 | def download(self): 41 | data_dir = self.path 42 | if not os.path.exists(data_dir): 43 | os.mkdir(data_dir) 44 | for filename in self.data_files.values(): 45 | path = self.path+'/'+filename 46 | if not os.path.exists(path): 47 | url = urljoin(self.base_url, filename) 48 | util.maybe_download(url, path) 49 | def shuffle_data(self): 50 | images = self.images[:] 51 | labels = self.labels[:] 52 | self.images = [] 53 | self.labels = [] 54 | 55 | idx = np.random.permutation(len(labels)) 56 | for i in idx: 57 | self.images.append(images[i]) 58 | self.labels.append(labels[i]) 59 | def _read_datafile(self,path): 60 | """Read the proprietary USPS digits data file.""" 61 | labels, images = [], [] 62 | with gzip.GzipFile(path) as f: 63 | for line in f: 64 | vals = line.strip().split() 65 | labels.append(float(vals[0])) 66 | images.append([float(val) for val in vals[1:]]) 67 | labels = np.array(labels, dtype=np.int32) 68 | labels[labels == 10] = 0 # fix weird 0 labels 69 | images = np.array(images, dtype=np.float32).reshape(-1, 16, 16, 1) 70 | images = (images + 1) / 2 71 | return images, labels 72 | def load_dataset(self,select): 73 | abspaths = {name: self.path+'/'+path 74 | for name, path in self.data_files.items()} 75 | 76 | if self.split=='train': 77 | train_images,train_labels = self._read_datafile(abspaths['train']) 78 | self.images = train_images[select] 79 | self.labels = train_labels[select] 80 | if len(select)==0: 81 | self.images=train_images 82 | self.labels=train_labels 83 | elif self.split=='test': 84 | test_images,test_labels = self._read_datafile(abspaths['test']) 85 | self.images=test_images[select] 86 | self.labels=test_labels[select] 87 | if len(select)==0: 88 | self.images=test_images 89 | self.labels=test_labels 90 | if len(self.known_class)!=0: 91 | images=[] 92 | labels=[] 93 | known_indices=[i for i in xrange(len(self.images)) if self.labels[i] in self.known_class] 94 | known_images=self.images[known_indices] 95 | known_labels=self.labels[known_indices] 96 | if self.unk==True: 97 | labels=[] 98 | for i in xrange(len(self.labels)): 99 | if self.labels[i] in self.unknown_class: 100 | self.labels[i]=self.unknown_class[0] 101 | else: 102 | self.images=known_images 103 | self.labels=known_labels 104 | if self.frange==[-1.,1.]: 105 | self.images=self.images*2.0-1. 
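        # Note on the open-set relabeling above: with known_class=[0..4] and
        # unknown_class=[5..9], every unknown digit collapses onto
        # unknown_class[0] == 5, so labels take values 0..5 and line up with
        # the (len(known_class)+1)-way one-hot encoding used in next_batch.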
106 | def reset_pointer(self): 107 | self.pointer=0 108 | if self.shuffle: 109 | self.shuffle_data() 110 | 111 | def class_next_batch(self,num_per_class): 112 | batch_size=10*num_per_class 113 | classpaths=[] 114 | ids=[] 115 | for i in xrange(10): 116 | classpaths.append([]) 117 | for j in xrange(len(self.labels)): 118 | label=self.labels[j] 119 | classpaths[label].append(j) 120 | for i in xrange(10): 121 | ids+=np.random.choice(classpaths[i],size=num_per_class,replace=False).tolist() 122 | selfimages=np.array(self.images) 123 | selflabels=np.array(self.labels) 124 | return np.array(selfimages[ids]),get_one_hot(selflabels[ids],10) 125 | 126 | def next_batch(self,batch_size): 127 | if self.pointer+batch_size>=len(self.labels): 128 | self.reset_pointer() 129 | images=self.images[self.pointer:(self.pointer+batch_size)] 130 | labels=self.labels[self.pointer:(self.pointer+batch_size)] 131 | self.pointer+=batch_size 132 | if len(self.known_class)!=0: 133 | return np.array(images),get_one_hot(labels,len(self.known_class)+1) 134 | else: 135 | return np.array(images),get_one_hot(labels,10) 136 | 137 | 138 | def main(): 139 | #svhn=USPS(path='data/usps',split='train',select=[2,3,45]) 140 | #print len(svhn.images) 141 | usps=USPS(path='data/usps',frange=[0.,1.],split='test') 142 | a,b=usps.next_batch(10) 143 | print a[0] 144 | #print b 145 | 146 | if __name__=='__main__': 147 | main() 148 | -------------------------------------------------------------------------------- /dataset/usps.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/dataset/usps.pyc -------------------------------------------------------------------------------- /dataset/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path 3 | 4 | import requests 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | def maybe_download(url, dest): 9 | """Download the url to dest if necessary, optionally checking file 10 | integrity. 
11 | """ 12 | if not os.path.exists(dest): 13 | logger.info('Downloading %s to %s', url, dest) 14 | download(url, dest) 15 | 16 | 17 | def download(url, dest): 18 | """Download the url to dest, overwriting dest if it already exists.""" 19 | response = requests.get(url, stream=True) 20 | with open(dest, 'wb') as f: 21 | for chunk in response.iter_content(chunk_size=1024): 22 | if chunk: 23 | f.write(chunk) 24 | -------------------------------------------------------------------------------- /dataset/util.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/dataset/util.pyc -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as tcl 3 | from nns import small,large 4 | class DALearner: 5 | 6 | def __init__(self,name='small',source='mnist',target='usps',num_classes=6): 7 | if name=='small': 8 | self.default_image_size=28 9 | self.num_channels=1 10 | self.model=small.lenet() 11 | else: 12 | self.default_image_size=32 13 | self.num_channels=3 14 | self.model=large.lenet() 15 | self.num_classes=num_classes 16 | self.source=source 17 | self.target=target 18 | self.mean=None 19 | self.bgr=None 20 | self.range=None 21 | 22 | def loss(self,xs,ys,xt,phase=True,keep_prob=0.5,lamb=1.0): 23 | model=self.model 24 | 25 | src_e=model.forward(xs,enc=True,dec=False,reuse=False,phase=phase,keep_prob=keep_prob,nmc=self.num_classes) 26 | print src_e.get_shape() 27 | src_p=model.forward(src_e,enc=False,dec=True,reuse=False,phase=phase,keep_prob=keep_prob,nmc=self.num_classes) 28 | 29 | trg_e=model.forward(xt,enc=True,dec=False,reuse=True,phase=phase,keep_prob=keep_prob,nmc=self.num_classes) 30 | trg_p=model.forward(trg_e,enc=False,dec=True,reuse=True,phase=phase,keep_prob=keep_prob,nmc=self.num_classes) 31 | 32 | #----source classification loss------ 33 | loss_src=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=src_p,labels=ys)) 34 | tf.summary.scalar('class/loss_src',loss_src) 35 | 36 | #-------------------------construct L_adv(xt) using "open set domain adaptation by backpropagation"--------------------------------------------- 37 | p_kone=tf.gather(tf.nn.softmax(trg_p),indices=[self.num_classes-1],axis=1) 38 | loss_adv=-0.5*tf.reduce_mean(tf.log(p_kone+1e-8))-0.5*tf.reduce_mean((tf.log(1.0-(p_kone)+1e-8))) 39 | tf.summary.scalar('adv/loss_trg_adv',loss_adv) 40 | loss_class=( 41 | loss_src 42 | +loss_adv 43 | ) 44 | loss_gen=( 45 | loss_src 46 | -loss_adv 47 | ) 48 | return loss_class,loss_gen,src_p,trg_p 49 | 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/model.pyc -------------------------------------------------------------------------------- /nns/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/nns/__init__.py -------------------------------------------------------------------------------- /nns/__init__.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/nns/__init__.pyc -------------------------------------------------------------------------------- /nns/large.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as tcl 3 | 4 | class lenet: 5 | def __init__(self): 6 | pass 7 | def forward(self,x,enc=True,dec=False,phase=True,keep_prob=0.5,reuse=False,nmc=6): 8 | net=tf.identity(x) 9 | if enc: 10 | with tf.variable_scope('gen',reuse=reuse): 11 | for i in xrange(2): 12 | net=tcl.conv2d(net,64,5,1,'VALID',activation_fn=None) 13 | net=tcl.batch_norm(net,is_training=phase) 14 | net=tf.nn.leaky_relu(net) 15 | for i in xrange(2): 16 | net=tcl.conv2d(net,128,3,2,'VALID',activation_fn=None) 17 | net=tcl.batch_norm(net,is_training=phase) 18 | net=tf.nn.leaky_relu(net) 19 | for i in xrange(2): 20 | net=tcl.fully_connected(net,100,activation_fn=None) 21 | net=tcl.batch_norm(net,is_training=phase) 22 | net=tf.nn.leaky_relu(net) 23 | net=tcl.flatten(net) 24 | 25 | if dec: 26 | with tf.variable_scope('class',reuse=reuse): 27 | net=tcl.fully_connected(net,nmc,activation_fn=None) 28 | 29 | return net 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /nns/large.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/nns/large.pyc -------------------------------------------------------------------------------- /nns/small.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.layers as tcl 3 | 4 | class lenet: 5 | def __init__(self): 6 | pass 7 | def forward(self,x,enc=True,dec=False,phase=True,keep_prob=0.5,reuse=False,nmc=6): 8 | net=tf.identity(x) 9 | reg=tcl.l2_regularizer(5e-4) 10 | conv_init=tcl.xavier_initializer() 11 | if enc: 12 | with tf.variable_scope('gen',reuse=reuse): 13 | net=tcl.conv2d(net,20,5,1,'VALID',activation_fn=None,weights_initializer=conv_init,weights_regularizer=reg) 14 | net=tcl.batch_norm(net,is_training=phase,scale=True) 15 | net=tf.nn.leaky_relu(net,0.01) 16 | net=tcl.max_pool2d(net,2,2,'VALID') 17 | 18 | net=tcl.conv2d(net,50,5,1,'VALID',activation_fn=None,weights_initializer=conv_init,weights_regularizer=reg) 19 | net=tcl.batch_norm(net,is_training=phase,scale=True) 20 | net=tf.nn.leaky_relu(net,0.01) 21 | net=tcl.max_pool2d(net,2,2,'VALID') 22 | 23 | net=tf.nn.dropout(net,keep_prob=keep_prob) 24 | net=tcl.flatten(net) 25 | 26 | net=tcl.fully_connected(net,500,activation_fn=None,weights_regularizer=reg) 27 | net=tcl.batch_norm(net,is_training=phase,scale=True) 28 | net=tf.nn.leaky_relu(net,0.01) 29 | if dec: 30 | with tf.variable_scope('class',reuse=reuse): 31 | net=tcl.fully_connected(net,nmc,activation_fn=None,weights_regularizer=reg) 32 | 33 | return net 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /nns/small.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/nns/small.pyc 
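The two network files above only define the shared feature encoder (the 'gen' variable scope) and the (len(known_class)+1)-way classifier head (the 'class' scope); the open-set adversarial game itself is wired up in model.py and osda_train.py below. As a minimal NumPy sketch of that objective (an illustration of the loss algebra only, not code from this repo; p_unk stands for the softmax probability the classifier assigns to the unknown class on target samples):

    import numpy as np

    def loss_adv(p_unk, t=0.5, eps=1e-8):
        # Binary cross-entropy between p(unknown|x_t) and the fixed target
        # t=0.5, matching loss_adv in model.py.
        p_unk = np.asarray(p_unk, dtype=np.float64)
        return (-t * np.mean(np.log(p_unk + eps))
                - (1.0 - t) * np.mean(np.log(1.0 - p_unk + eps)))

    # The classifier head minimizes loss_src + loss_adv, pulling p_unk toward
    # 0.5, while the encoder minimizes loss_src - loss_adv, pushing p_unk away
    # from 0.5; each target sample is thus driven to be either confidently
    # known or confidently rejected as unknown.
    print(loss_adv([0.5, 0.5]))   # ~0.693 (= log 2), the classifier's optimum
    print(loss_adv([0.1, 0.9]))   # ~1.20, the regime the encoder steers toward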
-------------------------------------------------------------------------------- /osda_train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import tensorflow as tf 4 | import datetime 5 | from model import DALearner 6 | from utils import utils 7 | from utils import metrics 8 | from preprocessing.preprocessing import preprocessing 9 | 10 | import math 11 | 12 | tf.app.flags.DEFINE_float('lr', 1e-3, 'Learning rate for the optimizer') 13 | tf.app.flags.DEFINE_float('dropout_keep_prob', 0.5, 'Dropout keep probability') 14 | tf.app.flags.DEFINE_integer('num_epochs', 200, 'Number of epochs for training') 15 | tf.app.flags.DEFINE_integer('batch_size', 128, 'Batch size') 16 | tf.app.flags.DEFINE_string('net','small', '[small,large]') 17 | tf.app.flags.DEFINE_string('opt','mom', '[adam,mom]') 18 | tf.app.flags.DEFINE_string('train','mnist', '[mnist,usps,svhn]') 19 | tf.app.flags.DEFINE_string('test','usps', '[mnist,usps,svhn]') 20 | tf.app.flags.DEFINE_string('train_root_dir', '../training', 'Root directory to put the training data') 21 | tf.app.flags.DEFINE_integer('log_step', 10000, 'Logging period in terms of iteration') 22 | 23 | #-------------------------open set domain adaptation---------------------------------------- 24 | NUM_CLASSES = 6 25 | FLAGS = tf.app.flags.FLAGS 26 | 27 | TRAIN_FILE=FLAGS.train 28 | TEST_FILE=FLAGS.test 29 | 30 | print TRAIN_FILE+' ---------------------------------------> '+TEST_FILE 31 | print TRAIN_FILE+' ---------------------------------------> '+TEST_FILE 32 | print TRAIN_FILE+' ---------------------------------------> '+TEST_FILE 33 | 34 | TRAIN=utils.get_data(FLAGS.train,split='train',unk=False,shuffle=True,frange=[0.,1.]) 35 | VALID=utils.get_data(FLAGS.test,split='train',unk=True,shuffle=True,frange=[0.,1.]) 36 | TEST=utils.get_data(FLAGS.test,split='train',unk=True,shuffle=False,frange=[0.,1.]) 37 | 38 | def adaptation_factor(x): 39 | den=1.0+math.exp(-10*x) 40 | lamb=2.0/den-1.0 41 | return min(lamb,1.0) 42 | 43 | def main(_): 44 | # Create training directories 45 | now = datetime.datetime.now() 46 | train_dir_name = now.strftime('alexnet_%Y%m%d_%H%M%S') 47 | train_dir = os.path.join(FLAGS.train_root_dir, train_dir_name) 48 | checkpoint_dir = os.path.join(train_dir, 'checkpoint') 49 | tensorboard_dir = os.path.join(train_dir, 'tensorboard') 50 | tensorboard_train_dir = os.path.join(tensorboard_dir, 'train') 51 | tensorboard_val_dir = os.path.join(tensorboard_dir, 'val') 52 | 53 | if not os.path.isdir(FLAGS.train_root_dir): os.mkdir(FLAGS.train_root_dir) 54 | if not os.path.isdir(train_dir): os.mkdir(train_dir) 55 | if not os.path.isdir(checkpoint_dir): os.mkdir(checkpoint_dir) 56 | if not os.path.isdir(tensorboard_dir): os.mkdir(tensorboard_dir) 57 | if not os.path.isdir(tensorboard_train_dir): os.mkdir(tensorboard_train_dir) 58 | if not os.path.isdir(tensorboard_val_dir): os.mkdir(tensorboard_val_dir) 59 | 60 | dropout_keep_prob = tf.placeholder(tf.float32) 61 | revgrad_lamb = tf.placeholder(tf.float32) 62 | is_training=tf.placeholder(tf.bool) 63 | 64 | # Model 65 | model =DALearner(name=FLAGS.net,num_classes=NUM_CLASSES,source=FLAGS.train,target=FLAGS.test) 66 | # Placeholders 67 | x_s = tf.placeholder(tf.float32, [None]+TRAIN.image_shape,name='x') 68 | x_t = tf.placeholder(tf.float32, [None]+TEST.image_shape,name='xt') 69 | x=preprocessing(x_s,model) 70 | xt=preprocessing(x_t,model) 71 | tf.summary.image('Source Images',x) 72 | tf.summary.image('Target Images',xt) 73
| y = tf.placeholder(tf.float32, [None, NUM_CLASSES],name='y') 74 | yt = tf.placeholder(tf.float32, [None, NUM_CLASSES],name='yt') 75 | loss_class,loss_gen,src_p,trg_p= model.loss(x, y, xt, keep_prob=dropout_keep_prob,phase=is_training,lamb=revgrad_lamb) 76 | 77 | #---- Optimizers--------- 78 | main_vars=tf.trainable_variables() 79 | gen_vars=[var for var in main_vars if 'gen' in var.name] 80 | class_vars=[var for var in main_vars if 'class' in var.name] 81 | print gen_vars 82 | print class_vars 83 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 84 | with tf.control_dependencies(update_ops): 85 | gen_op=utils.get_optimizer(FLAGS.opt,FLAGS.lr,loss_gen,gen_vars) 86 | class_op=utils.get_optimizer(FLAGS.opt,FLAGS.lr,loss_class,class_vars) 87 | optimizer=tf.group(gen_op,class_op) 88 | 89 | #------------ A series of metrics for evaluation: OS,OS*,ALL,UNK---------------------------------- 90 | target_predict=trg_p 91 | with tf.variable_scope('metrics') as scope: 92 | os_acc,os_update_op=metrics.OS(hx=target_predict,y=yt,num_classes=NUM_CLASSES) 93 | osstar_acc,osstar_update_op=metrics.OS_star(hx=target_predict,y=yt,num_classes=NUM_CLASSES) 94 | all_acc,all_update_op=metrics.ALL(hx=target_predict,y=yt) 95 | unk_acc,unk_update_op=metrics.UNK(hx=target_predict,y=yt,num_classes=NUM_CLASSES) 96 | metrics_update_op=tf.group(os_update_op,osstar_update_op,all_update_op,unk_update_op) 97 | metrics_variables=[v for v in tf.local_variables() if v.name.startswith('metrics')] 98 | reset_ops=[v.initializer for v in metrics_variables] 99 | print metrics_variables 100 | 101 | 102 | train_writer=tf.summary.FileWriter('./log/tensorboard') 103 | train_writer.add_graph(tf.get_default_graph()) 104 | merged=tf.summary.merge_all() 105 | 106 | 107 | 108 | 109 | print '============================GLOBAL TRAINABLE VARIABLES ============================' 110 | print tf.trainable_variables(),' ',len(tf.trainable_variables()) 111 | #print '============================GLOBAL VARIABLES ======================================' 112 | #print tf.global_variables() 113 | 114 | with tf.Session() as sess: 115 | sess.run(tf.global_variables_initializer()) 116 | sess.run(tf.local_variables_initializer()) 117 | saver=tf.train.Saver() 118 | train_writer.add_graph(sess.graph) 119 | 120 | print("{} Start training...".format(datetime.datetime.now())) 121 | for step in range(200*600): 122 | # Start training 123 | batch_xs, batch_ys = TRAIN.next_batch(FLAGS.batch_size) 124 | Tbatch_xs, Tbatch_ys = VALID.next_batch(FLAGS.batch_size) 125 | MAX_STEP=10000 126 | constant=adaptation_factor(step*1.0/MAX_STEP) 127 | summary,_=sess.run([merged,optimizer], feed_dict={x_s: batch_xs,x_t: Tbatch_xs,is_training:True,y: batch_ys,revgrad_lamb:constant,dropout_keep_prob:0.5,yt:Tbatch_ys}) 128 | train_writer.add_summary(summary,step) 129 | 130 | if step%600==0: 131 | epoch=step/600 132 | print("{} Start validation".format(datetime.datetime.now())) 133 | #print 'Epoch {0:<10} Step {1:<10} C_loss {2:<10} Advloss {3:<10}'.format(epoch,step,closs,advloss) 134 | test_acc = 0. 
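            # OS/OS*/ALL/UNK below are tf streaming metrics: each run of
            # metrics_update_op accumulates counts over one test batch, and
            # sess.run(reset_ops) after the epoch's printout reinitializes the
            # metrics' local variables so the next evaluation starts fresh.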
135 | test_count = 0 136 | bs=500 137 | print constant 138 | print 'test_counts ',len(TEST.labels) 139 | for _ in xrange((len(TEST.labels))/bs): 140 | batch_tx, batch_ty = TEST.next_batch(bs) 141 | sess.run(metrics_update_op, feed_dict={x_t: batch_tx, yt: batch_ty, is_training:False,dropout_keep_prob: 1.}) 142 | osacc,osstaracc,allacc,unkacc = sess.run([os_acc,osstar_acc,all_acc,unk_acc], feed_dict={x_t: batch_tx, yt: batch_ty, is_training:False,dropout_keep_prob: 1.}) 143 | test_count += bs 144 | res=len(TEST.labels)%bs 145 | if res>0: 146 | batch_tx, batch_ty = TEST.next_batch(res) 147 | sess.run(metrics_update_op, feed_dict={x_t: batch_tx, yt: batch_ty, is_training:False,dropout_keep_prob: 1.}) 148 | osacc,osstaracc,allacc,unkacc = sess.run([os_acc,osstar_acc,all_acc,unk_acc], feed_dict={x_t: batch_tx, yt: batch_ty, is_training:False,dropout_keep_prob: 1.}) 149 | 150 | print "Epoch {4:<5} OS {0:<10} OS* {1:<10} ALL {2:<10} UNK {3:<10}".format(osacc,osstaracc,allacc,unkacc,epoch) 151 | sess.run(reset_ops) 152 | if epoch==300: 153 | return 154 | 155 | 156 | if __name__ == '__main__': 157 | tf.app.run() 158 | -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/preprocessing/__init__.py -------------------------------------------------------------------------------- /preprocessing/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/preprocessing/__init__.pyc -------------------------------------------------------------------------------- /preprocessing/preprocessing.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | models = {} 8 | 9 | def register_model_fn(name): 10 | def decorator(fn): 11 | models[name] = fn 12 | # set default parameters 13 | fn.range = None 14 | fn.mean = None 15 | fn.bgr = False 16 | return fn 17 | return decorator 18 | 19 | def get_model_fn(name): 20 | return models[name] 21 | 22 | def preprocessing(inputs, model_fn): 23 | inputs = tf.cast(inputs, tf.float32) 24 | channels = int(inputs.get_shape()[-1]) 25 | if channels == 1 and model_fn.num_channels == 3: 26 | print 'really GREY TO RGB?' 27 | logging.info('Converting grayscale images to RGB') 28 | inputs = tf.image.grayscale_to_rgb(inputs) 29 | elif channels == 3 and model_fn.num_channels == 1: 30 | print 'really RGB TO GREY?' 31 | inputs = tf.image.rgb_to_grayscale((inputs)) 32 | if model_fn.range is not None: 33 | print 'range=[255]?' 
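        # model_fn.range rescales inputs from [0,1] (e.g. to [0,255]) for
        # models that expect it; both lenet variants leave range=None, so this
        # branch is normally skipped here.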
34 | inputs = model_fn.range * inputs 35 | if model_fn.default_image_size is not None: 36 | size = model_fn.default_image_size 37 | logging.info('Resizing images to [{}, {}]'.format(size, size)) 38 | inputs = tf.image.resize_images(inputs, [size, size]) 39 | print 'after resize ',inputs.get_shape() 40 | if model_fn.mean is not None: 41 | logging.info('Performing mean subtraction.') 42 | inputs = inputs - tf.reshape(tf.constant(model_fn.mean),[-1,3]) 43 | print 'after mean ',inputs.get_shape() 44 | if model_fn.bgr: 45 | logging.info('Performing BGR transposition.') 46 | inputs = inputs[..., ::-1] # reverse the channel axis; indexing axis 2 of an NHWC batch would reorder width, not channels 47 | #print 'start normalization (x-mean)/sqrt(var+epsilon)' 48 | #mean,var=tf.nn.moments(inputs,axes=[0]) 49 | #inputs=(inputs-mean)/tf.sqrt(var+1e-8) 50 | return inputs 51 | 52 | RGB2GRAY = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32) 53 | 54 | def rgb2gray(image): 55 | return tf.reduce_sum(tf.multiply(image, tf.constant(RGB2GRAY)), 56 | -1, 57 | keep_dims=True) 58 | 59 | def gray2rgb(image): 60 | return tf.image.grayscale_to_rgb(image) # replicate the gray channel; multiplying by RGB2GRAY produced a tinted image, not RGB 61 | -------------------------------------------------------------------------------- /preprocessing/preprocessing.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/preprocessing/preprocessing.pyc -------------------------------------------------------------------------------- /results/ladv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/results/ladv.png -------------------------------------------------------------------------------- /results/test_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/results/test_accuracy.png -------------------------------------------------------------------------------- /results/um.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/results/um.png -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/utils/__init__.py -------------------------------------------------------------------------------- /utils/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/utils/__init__.pyc -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import math 4 | 5 | def OS_star(hx,y,num_classes=3): 6 | disc_y=tf.argmax(y,1) 7 | disc_hx=tf.argmax(hx,1) 8 | known_indices=tf.where(tf.not_equal(disc_y,num_classes-1)) 9 | known_hx=tf.gather(disc_hx,known_indices) 10 | known_y=tf.gather(disc_y,known_indices) 11 | known_hx=tf.reshape(known_hx,[-1]) 12 | known_y=tf.reshape(known_y,[-1])
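    # tf.metrics.mean_per_class_accuracy averages over all num_classes
    # classes, but known_y never contains the unknown label (num_classes-1),
    # so that class contributes a fixed 0 to the mean; the
    # num_classes/(num_classes-1) rescaling below recovers the average over
    # the known classes only, which is the OS* metric.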
13 | osstar,osstar_update=tf.metrics.mean_per_class_accuracy(labels=known_y,predictions=known_hx,num_classes=num_classes) 14 | return osstar*(num_classes)*1.0/(num_classes-1.0),osstar_update 15 | 16 | def OS(hx,y,num_classes=6): 17 | os,os_update=tf.metrics.mean_per_class_accuracy(labels=tf.argmax(y,1),predictions=tf.argmax(hx,1),num_classes=num_classes) 18 | return os,os_update 19 | 20 | def ALL(hx,y): 21 | allacc,all_update=tf.metrics.accuracy(labels=tf.argmax(y,1),predictions=tf.argmax(hx,1)) 22 | return allacc,all_update 23 | 24 | def UNK(hx,y,num_classes=6): 25 | 26 | disc_y=tf.argmax(y,1) 27 | disc_hx=tf.argmax(hx,1) 28 | known_indices=tf.where(tf.equal(disc_y,num_classes-1)) 29 | known_hx=tf.gather(disc_hx,known_indices) 30 | known_y=tf.gather(disc_y,known_indices) 31 | known_hx=tf.reshape(known_hx,[-1]) 32 | known_y=tf.reshape(known_y,[-1]) 33 | unk,unk_update=tf.metrics.accuracy(labels=known_y,predictions=known_hx) 34 | return unk,unk_update 35 | -------------------------------------------------------------------------------- /utils/metrics.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/utils/metrics.pyc -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path 3 | 4 | import requests 5 | import tensorflow as tf 6 | logger = logging.getLogger(__name__) 7 | import sys 8 | sys.path.append('../') 9 | 10 | from dataset.mnist import MNIST 11 | from dataset.svhn import SVHN 12 | from dataset.usps import USPS 13 | 14 | 15 | def get_data(name,split,unk,shuffle,frange): 16 | prefix='dataset/data/' 17 | if name=='mnist': 18 | return MNIST(path=prefix+'mnist',split=split,unk=unk,shuffle=shuffle,frange=frange) 19 | if name=='usps': 20 | return USPS(path=prefix+'usps',split=split,unk=unk,shuffle=shuffle,frange=frange) 21 | if name=='svhn': 22 | return SVHN(path=prefix+'svhn',split=split,unk=unk,shuffle=shuffle,frange=frange) 23 | 24 | def get_optimizer(opt,lr,loss,var): 25 | if opt=='adam': 26 | return tf.train.AdamOptimizer(lr,0.9).minimize(loss,var_list=var) 27 | if opt=='mom': 28 | return tf.train.MomentumOptimizer(lr,0.9).minimize(loss,var_list=var) 29 | def maybe_download(url, dest): 30 | """Download the url to dest if necessary, optionally checking file 31 | integrity. 32 | """ 33 | if not os.path.exists(dest): 34 | logger.info('Downloading %s to %s', url, dest) 35 | download(url, dest) 36 | 37 | 38 | def download(url, dest): 39 | """Download the url to dest, overwriting dest if it already exists.""" 40 | response = requests.get(url, stream=True) 41 | with open(dest, 'wb') as f: 42 | for chunk in response.iter_content(chunk_size=1024): 43 | if chunk: 44 | f.write(chunk) 45 | -------------------------------------------------------------------------------- /utils/utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mid-Push/Open_set_domain_adaptation/72508c9008019074ee863504e57d81ff2c7a1db1/utils/utils.pyc --------------------------------------------------------------------------------
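For readers tracing the label pipeline end to end, the sketch below is a self-contained sanity check of the relabeling plus one-hot logic shared by the three dataset classes. get_one_hot is copied from the repo; relabel_unknown is a hypothetical helper that mirrors the unk=True branch of load_dataset in mnist.py/svhn.py/usps.py, not a function in this codebase:

    import numpy as np

    def get_one_hot(targets, nb_classes):
        # Same helper as in dataset/mnist.py, dataset/svhn.py and dataset/usps.py.
        return np.eye(nb_classes)[np.array(targets).reshape(-1)]

    def relabel_unknown(labels, unknown_class=(5, 6, 7, 8, 9)):
        # Hypothetical helper: every unknown digit collapses onto
        # unknown_class[0], so labels end up in {0, ..., 5}.
        labels = np.asarray(labels).copy()
        labels[np.isin(labels, unknown_class)] = unknown_class[0]
        return labels

    labels = relabel_unknown([0, 3, 7, 9, 4])   # -> [0, 3, 5, 5, 4]
    onehot = get_one_hot(labels, 6)             # shape (5, 6); column 5 means "unknown"
    print(onehot.argmax(1))                     # [0 3 5 5 4]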