├── .gitignore ├── DeepJDOT.py ├── LICENSE ├── README.md ├── blobs_dataset_example ├── blobs.png ├── latent_space_after_adaptation.png └── latent_space_without_adaptation.png └── testDeepJDOT.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /DeepJDOT.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Saturday 28.05.2021 4 | 5 | Paper and original tensorflow implementation: damodara 6 | 7 | 8 | @author: marc seibel 9 | Deepjdot - class file 10 | 11 | This is a translation from the original tensorflow implementation into pytorch. 12 | """ 13 | import torch 14 | import numpy as np 15 | import ot 16 | import tqdm 17 | 18 | class Deepjdot(object): 19 | def __init__(self, model, batch_size, n_class, optim=None, allign_loss=1.0, tar_cl_loss=1.0, 20 | sloss=0.0,tloss=1.0,int_lr=0.01, ot_method='emd', 21 | jdot_alpha=0.01, lr_decay=True, verbose=1): 22 | 23 | self.net = model # target model 24 | self.batch_size = batch_size 25 | self.n_class = n_class 26 | if optim is not None: 27 | raise ValueError("A custom optimizer is not implemented yet") 28 | self.optimizer = optim 29 | # initialize the gamma (coupling in OT) with zeros 30 | self.gamma = torch.zeros(size=(self.batch_size, self.batch_size)) 31 | # whether to minimize with classification loss 32 | 33 | self.train_cl = torch.tensor(tar_cl_loss) # translated from K.variable 34 | # whether to minimize with the allignment loss 35 | self.train_algn = torch.tensor(allign_loss) 36 | self.sloss = torch.tensor(sloss) # weight for source classification 37 | self.tloss = torch.tensor(tloss) # weight for target classification 38 | 39 | self.verbose = verbose 40 | self.int_lr = int_lr # initial learning rate 41 | self.lr_decay= lr_decay 42 | # 43 | self.ot_method = ot_method 44 | self.jdot_alpha=jdot_alpha # weight for the alpha term 45 | 46 | 47 | # target classification cross ent loss and source cross entropy 48 | def classifier_cat_loss(source_ypred, ypred_t, ys): 49 | ''' 50 | classifier loss based on categorical cross entropy in the target domain 51 | y_true: 52 | y_pred: pytorch tensor which has gradients 53 | 54 | 0:batch_size - is source samples 55 | batch_size:end - is target samples 56 | self.gamma - is the optimal transport plan 57 | ''' 58 | # pytorch has the mean-inbuilt, 59 | source_loss = torch.nn.functional.cross_entropy(source_ypred, 60 | torch.argmax(ys,dim=1)) 61 | 62 | 63 | # categorical cross entropy loss 64 | ypred_t = torch.log(ypred_t) 65 | # loss calculation based on double sum (sum_ij (ys^i, ypred_t^j)) 66 | loss = -torch.matmul(ys, torch.transpose(ypred_t,1,0)) 67 | # returns source loss + target loss 68 | 69 | # todo: check function of tloss train_cl, and sloss 70 | return self.train_cl*(self.tloss*torch.sum(self.gamma * loss) + self.sloss*source_loss) 71 | self.classifier_cat_loss = classifier_cat_loss 72 | 73 | # L2 distance 74 | def L2_dist(x,y): 75 | ''' 76 | compute the squared L2 distance between two matrics 77 | ''' 78 | distx = torch.reshape(torch.sum(torch.square(x),1), (-1,1)) 79 | disty = torch.reshape(torch.sum(torch.square(y),1), (1,-1)) 80 | dist = distx + disty 81 | dist -= 2.0*torch.matmul(x, torch.transpose(y,0,1)) 82 | return dist 83 | 84 | # feature allignment loss 85 | def align_loss(g_source, g_target): 86 | ''' 87 | source and target alignment loss in the intermediate layers of the target model 88 | allignment is performed in the target model (both source and target features are from target model) 89 | y-pred - is the value of intermediate layers in the target model 90 | 1:batch_size - is source samples 91 | batch_size:end - is target samples 92 | ''' 93 | # source domain features 94 | #gs = y_pred[:batch_size,:] # this should not work???? 95 | # target domain features 96 | #gt = y_pred[batch_size:,:] 97 | gdist = L2_dist(g_source,g_target) 98 | return self.jdot_alpha * torch.sum(self.gamma * (gdist)) 99 | self.align_loss= align_loss 100 | 101 | def feature_extraction(model, data, out_layer_num=-2): 102 | ''' 103 | Chop simple sequential model from layer 0 to out_layer_num. 104 | 105 | # https://discuss.pytorch.org/t/is-it-possible-to-slice-a-model-at-an-arbitrary-layer/53766 106 | comment: This method has no internal usage. 107 | 108 | 109 | 110 | extract the features from the pre-trained model 111 | inp_layer_num - input layer 112 | out_layer_num -- from which layer to extract the features 113 | ''' 114 | 115 | intermediate_layer_model = model[:out_layer_num] 116 | intermediate_output = intermediate_layer_model(data) 117 | return intermediate_output 118 | self.feature_extraction = feature_extraction 119 | 120 | 121 | 122 | def fit(self, source_traindata, ys_label, target_traindata, target_label = None, 123 | n_iter=5000, cal_bal=True, sample_size=None): 124 | ''' 125 | source_traindata - source domain training data 126 | ys_label - source data true labels 127 | target_traindata - target domain training data 128 | cal_bal - True: source domain samples are equally represented from 129 | all the classes in the mini-batch (that is, n samples from each class) 130 | - False: source domain samples are randomly sampled 131 | target_label - is not None : compute the target accuracy over the iterations 132 | ''' 133 | 134 | ns = source_traindata.shape[0] 135 | nt = target_traindata.shape[0] 136 | method = self.ot_method # for optimal transport 137 | alpha = self.jdot_alpha 138 | t_acc = [] 139 | g_metric ='deep' # to allign in intermediate layers, when g_metric='original', the 140 | # alignment loss is performed wrt original input features (StochJDOT) 141 | 142 | # function to sample n samples from each class 143 | def mini_batch_class_balanced(label, sample_size=20, shuffle=False): 144 | ''' sample the mini-batch with class balanced 145 | ''' 146 | label = np.argmax(label, axis=1) 147 | if shuffle: 148 | rindex = np.random.permutation(len(label)) 149 | label = label[rindex] 150 | 151 | n_class = len(np.unique(label)) 152 | index = [] 153 | for i in range(n_class): 154 | s_index = np.nonzero(label == i) 155 | s_ind = np.random.permutation(s_index[0]) 156 | index = np.append(index, s_ind[0:sample_size]) 157 | # print(index) 158 | index = np.array(index, dtype=int) 159 | return index 160 | 161 | # target model compliation and optimizer 162 | optimizer = torch.optim.SGD(self.net.parameters(), self.int_lr) 163 | 164 | 165 | cat_losses = [] 166 | align_losses = [] 167 | with tqdm.tqdm(range(n_iter), unit='batch') as tepoch: 168 | for i in tepoch: 169 | 170 | if self.lr_decay and i > 0 and i%5000 ==0: 171 | for g in optimizer.param_groups: 172 | g['lr'] = g['lr']*0.1 173 | 174 | # source domain mini-batch indexes 175 | if cal_bal: 176 | s_ind = mini_batch_class_balanced(ys_label, sample_size=sample_size) 177 | self.sbatch_size = len(s_ind) 178 | else: 179 | s_ind = np.random.choice(ns, self.batch_size) 180 | self.sbatch_size = self.batch_size 181 | # target domain mini-batch indexes 182 | t_ind = np.random.choice(nt, self.batch_size) 183 | 184 | # source and target domain mini-batch samples 185 | xs_batch = torch.tensor(source_traindata[s_ind]).type(torch.float) 186 | ys = torch.tensor(ys_label[s_ind]) 187 | xt_batch = torch.tensor(target_traindata[t_ind]).type(torch.float) 188 | def to_categorical(y, num_classes): 189 | """ 1-hot encodes a tensor """ 190 | return torch.eye(num_classes, dtype=torch.int8)[y] 191 | ys_cat = to_categorical(ys,3) 192 | s = xs_batch.shape 193 | 194 | batch = torch.vstack((xs_batch, xt_batch)) 195 | #batch.to(device) 196 | 197 | 198 | self.net.eval() # sets BatchNorm and Dropout in Test mode 199 | # concat of source and target samples and prediction 200 | with torch.no_grad(): 201 | modelpred = self.net(batch) 202 | 203 | # modelpred[0] - is softmax prob, and modelpred[1] - is intermediate layer 204 | gs_batch = modelpred[1][:self.batch_size, :] 205 | gt_batch = modelpred[1][self.batch_size:, :] 206 | # softmax prediction of target samples 207 | fs_pred = modelpred[0][:self.batch_size,:] 208 | ft_pred = modelpred[0][self.batch_size:,:] 209 | 210 | if g_metric=='orginal': 211 | # compution distance metric in the image space 212 | if len(s) == 3: # when the input is image, convert into 2D matrix 213 | C0 = torch.cdist(xs_batch.reshape(-1, s[1] * s[2]), 214 | xt_batch.reshape(-1, s[1] * s[2]), 215 | p=2.0)**2 216 | 217 | elif len(s) == 4: 218 | C0 = torch.cdist(xs_batch.reshape(-1, s[1] * s[2] * s[3]), 219 | xt_batch.reshape(-1, s[1] * s[2] * s[3]), 220 | p=2.0)**2 221 | else: 222 | # distance computation between source and target in deep layer 223 | C0 = torch.cdist(gs_batch, gt_batch, p=2.0)**2 224 | 225 | ys_cat = ys_cat.type(torch.float) 226 | C1 = torch.cdist(ys_cat, ft_pred, p=2)**2 227 | 228 | # JDOT ground metric 229 | C= alpha*C0+C1 230 | 231 | # JDOT optimal coupling (gamma) 232 | if method == 'emd': 233 | gamma=ot.emd(ot.unif(gs_batch.shape[0]), 234 | ot.unif(gt_batch.shape[0]),C) 235 | 236 | # update the computed gamma 237 | self.gamma = torch.tensor(gamma) 238 | 239 | 240 | self.net.train() # Batchnorm and Dropout for train mode 241 | optimizer.zero_grad() 242 | # concat of source and target samples and prediction 243 | modelpred = self.net(batch) 244 | gs_batch = modelpred[1][:self.batch_size, :] 245 | gt_batch = modelpred[1][self.batch_size:, :] 246 | # softmax prediction of target samples 247 | fs_pred = modelpred[0][:self.batch_size,:] 248 | ft_pred = modelpred[0][self.batch_size:,:] 249 | 250 | 251 | cat_loss = self.classifier_cat_loss(fs_pred, ft_pred, ys_cat) 252 | align_loss = self.align_loss(gs_batch, gt_batch) 253 | 254 | loss = cat_loss + align_loss 255 | #loss = criterion(outputs, batch.y) 256 | loss.backward() 257 | optimizer.step() 258 | 259 | cat_losses += [cat_loss.item()] 260 | align_losses += [align_loss.item()] 261 | if self.verbose: 262 | if i%10==0: 263 | cl = np.mean(cat_losses[-10:]) 264 | al = np.mean(align_losses[-10:]) 265 | #print('tl_loss ={:f}'.format(cl)) 266 | #print('fe_loss ={:f}'.format(al)) 267 | #print('tot_loss={:f}'.format(cl+al)) 268 | if target_label is not None: 269 | tpred = self.net(target_traindata)[0] 270 | t_acc.append(torch.mean(target_label==torch.argmax(tpred,1))) 271 | print('Target acc\n', t_acc[-1]) 272 | tepoch.set_postfix(loss=al+cl) 273 | 274 | return [cat_losses,align_losses], t_acc 275 | 276 | def predict(self, data): 277 | data = torch.tensor(data.astype(np.float32)) 278 | self.net.eval() 279 | with torch.no_grad(): 280 | ypred = self.net(data) 281 | return ypred 282 | 283 | def evaluate(self, data, label): 284 | """ 285 | label as digits (0,1,... , num_classes) 286 | """ 287 | data = torch.tensor(data).type(torch.float) 288 | label = torch.tensor(label) 289 | self.net.eval() 290 | with torch.no_grad(): 291 | ypred = self.net(data) 292 | score = torch.mean((label==torch.argmax(ypred[0],1)).type(torch.float)) 293 | return score 294 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 msseibel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepJDOT-pytorch 2 | Translation of DeepJDOT from tensorflow into pytorch. 3 | Visit https://github.com/bbdamodaran/deepJDOT for the original implementation and paper. 4 | -------------------------------------------------------------------------------- /blobs_dataset_example/blobs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msseibel/DeepJDOT-pytorch/c349c16f88ba4a34b40a174163eba4f13281df38/blobs_dataset_example/blobs.png -------------------------------------------------------------------------------- /blobs_dataset_example/latent_space_after_adaptation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msseibel/DeepJDOT-pytorch/c349c16f88ba4a34b40a174163eba4f13281df38/blobs_dataset_example/latent_space_after_adaptation.png -------------------------------------------------------------------------------- /blobs_dataset_example/latent_space_without_adaptation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msseibel/DeepJDOT-pytorch/c349c16f88ba4a34b40a174163eba4f13281df38/blobs_dataset_example/latent_space_without_adaptation.png -------------------------------------------------------------------------------- /testDeepJDOT.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Jan 31 19:26:41 2018 5 | 6 | Paper and original tensorflow implementation: damodara 7 | 8 | Pytorch implementation 9 | @author: msseibel 10 | 11 | DeepJDOT: with emd for the sample data 12 | """ 13 | 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | from sklearn.datasets import make_blobs 17 | import torch 18 | import torch.nn.functional as F 19 | import tqdm 20 | 21 | #seed=1985 22 | #np.random.seed(seed) 23 | 24 | #%% 25 | source_traindata, source_trainlabel = make_blobs(1200, centers=[[0, -1], [0, 0], [0, 1]], cluster_std=0.2) 26 | target_traindata, target_trainlabel = make_blobs(1200, centers=[[1, 0], [1, 1], [1, 2]], cluster_std=0.2) 27 | plt.figure() 28 | plt.scatter(source_traindata[:,0], source_traindata[:,1], c=source_trainlabel, marker='o', alpha=0.4) 29 | plt.scatter(target_traindata[:,0], target_traindata[:,1], c=target_trainlabel, marker='x', alpha=0.4) 30 | plt.legend(['source train data', 'target train data']) 31 | plt.title("2D blobs visualization (shape=domain, color=class)") 32 | 33 | #%% optimizer 34 | n_class = len(np.unique(source_trainlabel)) 35 | n_dim = np.shape(source_traindata) 36 | 37 | #%% feature extraction and classifier function definition 38 | class BlobModel(torch.nn.Module): 39 | def __init__(self): 40 | super(BlobModel, self).__init__() 41 | self.fc1 = torch.nn.Linear(2,500) 42 | self.fc2 = torch.nn.Linear(500,100) 43 | self.fc3 = torch.nn.Linear(100,100) 44 | self.fc_out = torch.nn.Linear(100,3) 45 | 46 | def forward(self, batch): 47 | x1 = F.relu(self.fc1(batch), True) 48 | code = F.relu(self.fc2(x1), True) 49 | x2 = F.relu(self.fc3(code), True) 50 | clf = F.softmax(self.fc_out(x2),-1) 51 | 52 | return clf, code 53 | 54 | 55 | 56 | batch_size = 64 57 | n_iter = 1200*10 58 | source_model = BlobModel() 59 | criterion = torch.nn.CrossEntropyLoss() 60 | optim = torch.optim.SGD(source_model.parameters(), lr=0.001) 61 | 62 | for i in tqdm.tqdm(range(n_iter), unit=" batches"): 63 | 64 | ind = np.random.choice(len(source_traindata),size=batch_size,replace=False) 65 | xbatch = torch.tensor(source_traindata [ind].astype(np.float32)) 66 | ybatch = torch.tensor(source_trainlabel[ind]) 67 | 68 | optim.zero_grad() 69 | outputs, latent_code = source_model(xbatch) 70 | 71 | loss = criterion(outputs, ybatch) 72 | #loss = criterion(outputs, batch.y) 73 | loss.backward() 74 | optim.step() 75 | 76 | #tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy) 77 | 78 | #source_model.fit(source_traindata, 79 | # source_trainlabel_cat, 80 | # batch_size=128, 81 | # epochs=100, 82 | # validation_data=(target_traindata, target_trainlabel_cat)) 83 | #%% Evaluate Model trained on source data 84 | source_model.eval() 85 | subset = 200 86 | with torch.no_grad(): 87 | preds_train, smodel_source_feat = source_model( 88 | torch.tensor(source_traindata.astype(np.float32))) 89 | 90 | smodel_source_feat = smodel_source_feat[:200] 91 | preds_train = torch.argmax(preds_train,dim=1) 92 | 93 | preds_targettrain, smodel_target_feat = source_model( 94 | torch.tensor(target_traindata.astype(np.float32))) 95 | smodel_target_feat = smodel_target_feat[:subset] 96 | preds_targettrain = torch.argmax(preds_targettrain,dim=1) 97 | 98 | source_acc = torch.mean((preds_train== torch.tensor(source_trainlabel)).type(torch.float)) 99 | target_acc = torch.mean((preds_targettrain == torch.tensor(target_trainlabel)).type(torch.float)) 100 | print("") 101 | print("source acc using source model", source_acc) 102 | print("target acc using source model", target_acc) 103 | 104 | #%% deepjdot model and training 105 | import DeepJDOT 106 | 107 | batch_size=128 108 | sloss = 2.0; tloss=1.0; int_lr=0.002; jdot_alpha=5.0 109 | # DeepJDOT model initalization 110 | al_model = DeepJDOT.Deepjdot(source_model, batch_size, n_class, optim=None,allign_loss=1.0, 111 | sloss=sloss,tloss=tloss,int_lr=int_lr,jdot_alpha=jdot_alpha, 112 | lr_decay=True,verbose=1) 113 | # DeepJDOT model fit 114 | losses, tacc = al_model.fit(source_traindata, source_trainlabel, target_traindata, 115 | n_iter=1500,cal_bal=False) 116 | 117 | 118 | #%% accuracy assesment 119 | tarmodel_sacc = al_model.evaluate(source_traindata, 120 | source_trainlabel) 121 | acc = al_model.evaluate(target_traindata, target_trainlabel) 122 | print("source loss & acc using source+target model", tarmodel_sacc) 123 | print("target loss & acc using source+target model", acc) 124 | 125 | 126 | #%% intermediate layers of source and target domain for TSNE plot of target (DeepJDOT) model 127 | source_model.eval() 128 | subset = 200 129 | with torch.no_grad(): 130 | al_sourcedata = al_model.predict(source_traindata[:subset,])[1] 131 | al_targetdata = al_model.predict(target_traindata[:subset,])[1] 132 | 133 | #%% function for TSNE plot (source and target are combined) 134 | def tsne_plot(xs, xt, xs_label, xt_label, subset=True, title=None, pname=None): 135 | 136 | num_test=100 137 | if subset: 138 | combined_imgs = np.concatenate([xs[0:num_test], xt[0:num_test]]) 139 | combined_labels = np.concatenate([xs_label[0:num_test],xt_label[0:num_test]]) 140 | combined_labels = combined_labels.astype('int').T 141 | print(combined_labels.shape) 142 | print(combined_imgs.shape) 143 | 144 | from sklearn.manifold import TSNE 145 | tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=3000) 146 | source_only_tsne = tsne.fit_transform(combined_imgs) 147 | plt.figure(figsize=(10, 10)) 148 | plt.scatter(source_only_tsne[:num_test,0], source_only_tsne[:num_test,1], 149 | c=combined_labels[:num_test], s=75, marker='o', alpha=0.5, label='source train data') 150 | plt.scatter(source_only_tsne[num_test:,0], source_only_tsne[num_test:,1], 151 | c=combined_labels[num_test:],s=50,marker='x',alpha=0.5,label='target train data') 152 | plt.legend(loc='best') 153 | plt.title(title) 154 | 155 | #%% TSNE plots of source model and target model 156 | title = 'tsne plot of source and target data with source model\n2D blobs visualization (shape=domain, color=class)' 157 | tsne_plot(smodel_source_feat, smodel_target_feat, source_trainlabel, 158 | target_trainlabel, title=title) 159 | 160 | title = 'tsne plot of source and target data with source+target model\n2D blobs visualization (shape=domain, color=class)' 161 | tsne_plot(al_sourcedata, al_targetdata, source_trainlabel, 162 | target_trainlabel, title=title) 163 | --------------------------------------------------------------------------------