├── LICENSE
├── README.md
├── active
│   ├── LASsampler.py
│   ├── MHPsampler.py
│   ├── __init__.py
│   ├── budget.py
│   ├── sampler.py
│   └── utils.py
├── config
│   ├── __init__.py
│   └── defaults.py
├── configs
│   ├── office31.yaml
│   ├── officehome.yaml
│   ├── officehome_RSUT.yaml
│   └── visda.yaml
├── data
│   └── image_list
│       ├── domainnet
│       │   ├── clipart_test.txt
│       │   ├── clipart_train.txt
│       │   ├── infograph_test.txt
│       │   ├── infograph_train.txt
│       │   ├── painting_test.txt
│       │   ├── painting_train.txt
│       │   ├── quickdraw_test.txt
│       │   ├── quickdraw_train.txt
│       │   ├── real_test.txt
│       │   ├── real_train.txt
│       │   ├── sketch_test.txt
│       │   └── sketch_train.txt
│       ├── office31
│       │   ├── amazon.txt
│       │   ├── dslr.txt
│       │   └── webcam.txt
│       ├── officehome
│       │   ├── Art.txt
│       │   ├── Clipart.txt
│       │   ├── Product.txt
│       │   └── RealWorld.txt
│       ├── officehome_RSUT
│       │   ├── Clipart_RS.txt
│       │   ├── Clipart_UT.txt
│       │   ├── Product_RS.txt
│       │   ├── Product_UT.txt
│       │   ├── RealWorld_RS.txt
│       │   └── RealWorld_UT.txt
│       └── visda
│           ├── train.txt
│           └── validation.txt
├── dataset
│   ├── ASDADataset.py
│   ├── image_list.py
│   ├── randaugment.py
│   ├── randaugmentMC.py
│   └── transform.py
├── fig
│   └── framework.png
├── main.py
├── model
│   ├── __init__.py
│   ├── adaptor.py
│   ├── grl.py
│   ├── models.py
│   └── network.py
├── run.sh
├── solver
│   ├── CDACsolver.py
│   ├── MCCsolver.py
│   ├── MMEsolver.py
│   ├── PAAsolver.py
│   ├── __init__.py
│   ├── solver.py
│   └── utils.py
└── utils
    ├── logger.py
    ├── lr_schedule.py
    └── utils.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and
this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Local Context-Aware Active Domain Adaptation 2 | 3 | PyTorch implementation of LADA. 4 | > [Local Context-Aware Active Domain Adaptation](https://arxiv.org/abs/2208.12856) 5 | > Tao Sun, Cheng Lu, and Haibin Ling 6 | > *ICCV 2023* 7 | > 8 | ## Abstract 9 | Active Domain Adaptation (ADA) queries the labels of a small number of selected target samples to help adapt a model from a source domain to a target domain. The local context of the queried data is important, especially when the domain gap is large. However, existing ADA works have not fully explored it. 10 | 11 | In this paper, we propose a Local context-aware ADA framework, named LADA, to address this issue. To select informative target samples, we devise a novel criterion based on the local inconsistency of model predictions. Since the labeling budget is usually small, fine-tuning the model on the queried data alone can be inefficient, so we progressively augment the labeled target data with confident neighbors in a class-balanced manner. 12 | 13 | Experiments validate that the proposed criterion chooses more informative target samples than existing active selection strategies. 
Furthermore, our full method surpasses recent ADA methods on various benchmarks. 14 |
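The local-inconsistency criterion can be illustrated with a small standalone sketch. It mirrors the scoring steps in `active/LASsampler.py`: cosine k-NN over L2-normalized embeddings, a similarity-weighted inner product between a sample's softmax output and its neighbors' outputs, and optional score propagation. The function name `local_inconsistency_scores` and its signature are hypothetical, and the sketch omits the oversampling and K-means step that LAS uses to pick diverse queries.

```python
import torch
import torch.nn.functional as F


def local_inconsistency_scores(probs, embs, k=5, prop_iters=0, prop_coef=0.5):
    """Hypothetical helper: a higher score means the sample's prediction
    disagrees more with the predictions of its k nearest neighbors.

    probs: (N, C) softmax outputs; embs: (N, D) feature embeddings.
    """
    embs = F.normalize(embs, dim=-1)
    sim = embs @ embs.t()  # pairwise cosine similarity
    sim_topk, topk = torch.topk(sim, k=k + 1, dim=1)
    sim_topk, topk = sim_topk[:, 1:], topk[:, 1:]  # drop the sample itself
    wgt = sim_topk / sim_topk.sum(dim=1, keepdim=True)  # normalized neighbor weights
    # Inner product between each sample's prediction and each neighbor's prediction.
    agreement = (probs[topk] * probs.unsqueeze(1)).sum(-1)  # (N, k)
    score = -(agreement * wgt).sum(-1)  # low agreement -> high score
    for _ in range(prop_iters):
        # Each sample inherits part of its neighbors' scores.
        score = score + prop_coef * (wgt * score[topk]).sum(-1)
    return score


# Toy example: samples 2 and 3 are neighbors with conflicting predictions.
probs = torch.tensor([[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.1, 0.9]])
embs = torch.tensor([[1.0, 0.0], [0.99, 0.1], [0.0, 1.0], [0.1, 0.99]])
scores = local_inconsistency_scores(probs, embs, k=1)
```

Like the repository code, the weights assume the top-k similarities are positive; with raw cosine similarity this usually holds for close neighbors but is not guaranteed.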


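The class-balanced augmentation idea from the abstract can also be sketched, with strong caveats. The snippet below is **not** the RAA/LAA solver in `solver/PAAsolver.py`, whose code is not shown here. It is one simple interpretation that assumes three things: a neighbor qualifies when it is among an anchor's `k` nearest neighbors; the model predicts the anchor's label for it with probability at least `conf_thresh`; and "class-balanced" means keeping the same number of pseudo-labeled neighbors per class. `balanced_neighbor_pseudo_labels` and its parameters are hypothetical.

```python
import torch
import torch.nn.functional as F


def balanced_neighbor_pseudo_labels(probs, embs, anchor_idx, anchor_labels, k=5, conf_thresh=0.9):
    """Hypothetical sketch: pseudo-label confident neighbors of labeled anchors,
    keeping an equal number per class.

    probs: (N, C) softmax outputs; embs: (N, D) embeddings;
    anchor_idx: (A,) long tensor of labeled indices; anchor_labels: (A,) their labels.
    Returns (indices, labels) of the selected neighbors.
    """
    embs = F.normalize(embs, dim=-1)
    sim = embs[anchor_idx] @ embs.t()                               # (A, N) cosine similarity
    sim[torch.arange(len(anchor_idx)), anchor_idx] = -float("inf")  # exclude the anchor itself
    nbrs = torch.topk(sim, k=k, dim=1).indices                      # (A, k)
    conf, pred = probs.max(dim=1)

    labeled = set(anchor_idx.tolist())
    candidates = {}  # class -> set of neighbor indices
    for a, y in enumerate(anchor_labels.tolist()):
        for j in nbrs[a].tolist():
            # Keep a neighbor only if it is unlabeled, agrees with the anchor's label,
            # and the model is confident about that prediction.
            if j not in labeled and pred[j].item() == y and conf[j].item() >= conf_thresh:
                candidates.setdefault(y, set()).add(j)

    if not candidates:
        empty = torch.empty(0, dtype=torch.long)
        return empty, empty

    # Class balance: every class contributes as many neighbors as the smallest one.
    per_class = min(len(s) for s in candidates.values())
    idx, lab = [], []
    for y in sorted(candidates):
        ranked = sorted(candidates[y], key=lambda j: (-conf[j].item(), j))  # most confident first
        idx.extend(ranked[:per_class])
        lab.extend([y] * per_class)
    return torch.tensor(idx, dtype=torch.long), torch.tensor(lab, dtype=torch.long)
```

To make the augmentation progressive, one could rerun such a step after each training round as the model's confidence changes. That schedule is also an assumption.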
17 | 18 | 19 | ## Usage 20 | ### Prerequisites 21 | We experimented with python==3.8, pytorch==1.8.0, and cudatoolkit==11.1. 22 | 23 | To start, download the [Office-31](https://faculty.cc.gatech.edu/~judy/domainadapt/), [Office-Home](https://www.hemanthdv.org/officeHomeDataset.html), and [VisDA](https://ai.bu.edu/visda-2017/) datasets and set up their paths in the ./data folder. 24 | 25 | ### Supported methods 26 | | Active Criteria | Paper | Implementation | 27 | |-----------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:----------------------------------------:| 28 | | Random | - | [random](active/sampler.py) | 29 | | Entropy | - | [entropy](active/sampler.py) | 30 | | Margin | - | [margin](active/sampler.py) | 31 | | LeastConfidence | - | [leastConfidence](active/sampler.py) | 32 | | CoreSet | [ICLR 2018](https://openreview.net/pdf?id=H1aIuk-RW) | [coreset](active/sampler.py) | 33 | | AADA | [WACV 2020](https://openaccess.thecvf.com/content_WACV_2020/papers/Su_Active_Adversarial_Domain_Adaptation_WACV_2020_paper.pdf) | [AADA](active/sampler.py) | 34 | | BADGE | [ICLR 2020](https://openreview.net/pdf?id=ryghZJBKPS) | [BADGE](active/sampler.py) | 35 | | CLUE | [ICCV 2021](https://openaccess.thecvf.com/content/ICCV2021/papers/Prabhu_Active_Domain_Adaptation_via_Clustering_Uncertainty-Weighted_Embeddings_ICCV_2021_paper.pdf) | [CLUE](active/sampler.py) | 36 | | MHP | [CVPR 2023](https://openaccess.thecvf.com/content/CVPR2023/papers/Wang_MHPL_Minimum_Happy_Points_Learning_for_Active_Source_Free_Domain_CVPR_2023_paper.pdf) | [MHP](active/MHPsampler.py) | 37 | | LAS (ours) | [ICCV 2023](https://arxiv.org/abs/2208.12856) | [LAS](active/LASsampler.py) | 38 | 39 | 40 | | Domain Adaptation | Paper | Implementation | 41 |
|-------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------:|:----------------------------:| 42 | | Fine-tuning (joint label set) | - | [ft_joint](solver/solver.py) | 43 | | Fine-tuning | - | [ft](solver/solver.py) | 44 | | DANN | [JMLR 2016](https://jmlr.org/papers/volume17/15-239/15-239.pdf) | [dann](solver/solver.py) | 45 | | MME | [ICCV 2019](https://openaccess.thecvf.com/content_ICCV_2019/papers/Saito_Semi-Supervised_Domain_Adaptation_via_Minimax_Entropy_ICCV_2019_paper.pdf) | [mme](solver/MMEsolver.py) | 46 | | MCC | [ECCV 2020](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123660460.pdf) | [MCC](solver/MCCsolver.py) | 47 | | CDAC | [CVPR 2021](https://openaccess.thecvf.com/content/CVPR2021/papers/Li_Cross-Domain_Adaptive_Clustering_for_Semi-Supervised_Domain_Adaptation_CVPR_2021_paper.pdf) | [CDAC](solver/CDACsolver.py) | 48 | | RAA (ours) | [ICCV 2023](https://arxiv.org/abs/2208.12856) | [RAA](solver/PAAsolver.py) | 49 | | LAA (ours) | [ICCV 2023](https://arxiv.org/abs/2208.12856) | [LAA](solver/PAAsolver.py) | 50 | 51 | 52 | 53 | 54 | ### Training 55 | To obtain results of the baseline active selection criteria on Office-Home with a 5% labeling budget: 56 | ```shell 57 | for ADA_DA in 'ft' 'mme'; do 58 | for ADA_AL in 'random' 'entropy' 'margin' 'coreset' 'leastConfidence' 'BADGE' 'AADA' 'CLUE' 'MHP'; do 59 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/baseline ADA.AL $ADA_AL ADA.DA $ADA_DA 60 | done 61 | done 62 | ``` 63 | 64 | To reproduce the results of LADA on Office-Home with a 5% labeling budget: 65 | ```shell 66 | # LAS + fine-tuning with CE loss 67 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA ft 68 | # LAS + MME model adaptation 69 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA mme 70 | # LAS + Random
Anchor set Augmentation (RAA) 71 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA RAA 72 | # LAS + Local context-aware Anchor set Augmentation (LAA) 73 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA LAA 74 | ``` 75 | 76 | More commands can be found in *run.sh*. 77 | 78 | ## Acknowledgements 79 | The pipeline and the implementations of the baseline methods are adapted from [CLUE](https://github.com/virajprabhu/CLUE) and [deep-active-learning](https://github.com/ej0cl6/deep-active-learning). Our configuration files follow those of [EADA](https://github.com/BIT-DA/EADA). 80 | 81 | 82 | ## Citation 83 | If you find our paper and code useful for your research, please consider citing 84 | ```bibtex 85 | @article{sun2022local, 86 | author = {Sun, Tao and Lu, Cheng and Ling, Haibin}, 87 | title = {Local Context-Aware Active Domain Adaptation}, 88 | journal = {IEEE/CVF International Conference on Computer Vision}, 89 | year = {2023} 90 | } 91 | ``` -------------------------------------------------------------------------------- /active/LASsampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from sklearn.cluster import KMeans 5 | from sklearn.metrics.pairwise import euclidean_distances 6 | 7 | from .utils import ActualSequentialSampler 8 | from .sampler import register_strategy, SamplingStrategy 9 | 10 | @register_strategy('LAS') 11 | class LASSampling(SamplingStrategy): 12 | ''' 13 | Implements Local context-aware sampling (LAS) 14 | ''' 15 | 16 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg): 17 | super(LASSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg) 18 | 19 | def query(self, n, epoch): 20 | idxs_unlabeled = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb] 21 | train_sampler = 
ActualSequentialSampler(self.tgt_dset.train_idx[idxs_unlabeled]) 22 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler, 23 | num_workers=self.cfg.DATALOADER.NUM_WORKERS, 24 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False) 25 | # build nearest neighbors 26 | self.model.eval() 27 | all_probs = [] 28 | all_embs = [] 29 | with torch.no_grad(): 30 | for batch_idx, (data, target, _, *_) in enumerate(data_loader): 31 | data, target = data.to(self.device), target.to(self.device) 32 | scores, embs = self.model(data, with_emb=True) 33 | all_embs.append(embs.cpu()) 34 | probs = F.softmax(scores, dim=-1) 35 | all_probs.append(probs.cpu()) 36 | 37 | all_probs = torch.cat(all_probs) 38 | all_embs = F.normalize(torch.cat(all_embs), dim=-1) 39 | 40 | # get Q_score 41 | sim = all_embs.cpu().mm(all_embs.transpose(1, 0)) 42 | K = self.cfg.LADA.S_K 43 | sim_topk, topk = torch.topk(sim, k=K + 1, dim=1) 44 | sim_topk, topk = sim_topk[:, 1:], topk[:, 1:] 45 | wgt_topk = (sim_topk / sim_topk.sum(dim=1, keepdim=True)) 46 | 47 | Q_score = -((all_probs[topk] * all_probs.unsqueeze(1)).sum(-1) * wgt_topk).sum(-1) 48 | 49 | # propagate Q_score 50 | for i in range(self.cfg.LADA.S_PROP_ITER): 51 | Q_score += (wgt_topk * Q_score[topk]).sum(-1) * self.cfg.LADA.S_PROP_COEF 52 | 53 | m_idxs = Q_score.sort(descending=True)[1] 54 | 55 | # oversample and find centroids 56 | M = self.cfg.LADA.S_M 57 | m_topk = m_idxs[:n * (1 + M)] 58 | km = KMeans(n_clusters=n) 59 | km.fit(all_embs[m_topk]) 60 | dists = euclidean_distances(km.cluster_centers_, all_embs[m_topk]) 61 | sort_idxs = dists.argsort(axis=1) 62 | q_idxs = [] 63 | ax, rem = 0, n 64 | while rem > 0: 65 | q_idxs.extend(list(sort_idxs[:, ax][:rem])) 66 | q_idxs = list(set(q_idxs)) 67 | rem = n - len(q_idxs) 68 | ax += 1 69 | 70 | q_idxs = m_idxs[q_idxs].cpu().numpy() 71 | self.query_dset.rand_transform = None 72 | 73 | return idxs_unlabeled[q_idxs] 74 | 75 | 
-------------------------------------------------------------------------------- /active/MHPsampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from sklearn.cluster import KMeans 6 | from sklearn.metrics.pairwise import euclidean_distances 7 | 8 | from .utils import ActualSequentialSampler 9 | from .sampler import register_strategy, SamplingStrategy 10 | 11 | 12 | @register_strategy('MHP') 13 | class MHPSampling(SamplingStrategy): 14 | ''' 15 | Implements MHPL: Minimum Happy Points Learning for Active Source Free Domain Adaptation (CVPR'23) 16 | ''' 17 | 18 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg): 19 | super(MHPSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg) 20 | 21 | def query(self, n, epoch): 22 | idxs_unlabeled = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb] 23 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx[idxs_unlabeled]) 24 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler, 25 | num_workers=self.cfg.DATALOADER.NUM_WORKERS, 26 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False) 27 | self.model.eval() 28 | all_probs = [] 29 | all_embs = [] 30 | with torch.no_grad(): 31 | for batch_idx, (data, target, _) in enumerate(data_loader): 32 | data, target = data.to(self.device), target.to(self.device) 33 | scores, embs = self.model(data, with_emb=True) 34 | all_embs.append(embs.cpu()) 35 | probs = F.softmax(scores, dim=-1) 36 | all_probs.append(probs.cpu()) # keep on CPU to match the CPU kNN indices and torch.eye below 37 | 38 | all_probs = torch.cat(all_probs) 39 | all_embs = F.normalize(torch.cat(all_embs), dim=-1) 40 | 41 | # find KNN 42 | sim = all_embs.cpu().mm(all_embs.transpose(1, 0)) 43 | K = self.cfg.LADA.S_K 44 | sim_topk, topk = torch.topk(sim, k=K + 1, dim=1) 45 | sim_topk, topk = 
sim_topk[:, 1:], topk[:, 1:] 46 | 47 | # get NP scores 48 |
all_preds = all_probs.argmax(-1) 49 | Sp = (torch.eye(self.num_classes)[all_preds[topk]]).sum(1) 50 | Sp = Sp / Sp.sum(-1, keepdim=True) 51 | NP = -(torch.log(Sp+1e-9)*Sp).sum(-1) 52 | 53 | # get NA scores 54 | NA = sim_topk.sum(-1) / K 55 | NAU = NP*NA 56 | sort_idxs = NAU.argsort(descending=True) 57 | 58 | q_idxs = [] 59 | ax, rem = 0, n 60 | while rem > 0: 61 | if topk[sort_idxs[ax]][0] not in q_idxs: 62 | q_idxs.append(sort_idxs[ax]) 63 | rem = n - len(q_idxs) 64 | ax += 1 65 | 66 | q_idxs = np.array(q_idxs) 67 | 68 | return idxs_unlabeled[q_idxs] 69 | 70 | -------------------------------------------------------------------------------- /active/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import * 2 | from .LASsampler import * 3 | from .MHPsampler import * 4 | -------------------------------------------------------------------------------- /active/budget.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class BudgetAllocator(): 4 | def __init__(self, budget, cfg): 5 | self.budget = budget 6 | self.max_epochs = cfg.TRAINER.MAX_EPOCHS 7 | self.cfg = cfg 8 | self.build_budgets() 9 | 10 | def build_budgets(self): 11 | self.budgets = np.zeros(self.cfg.TRAINER.MAX_EPOCHS, dtype=np.int32) 12 | rounds = self.cfg.ADA.ROUNDS or np.arange(0, self.cfg.TRAINER.MAX_EPOCHS, self.cfg.TRAINER.MAX_EPOCHS // self.cfg.ADA.ROUND) 13 | 14 | for r in rounds: 15 | self.budgets[r] = self.budget // len(rounds) 16 | 17 | self.budgets[rounds[-1]] += self.budget - self.budgets.sum() 18 | 19 | def get_budget(self, epoch): 20 | curr_budget = self.budgets[epoch] 21 | used_budget = self.budgets[:epoch].sum() 22 | return curr_budget, used_budget 23 | -------------------------------------------------------------------------------- /active/sampler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Implements 
active
learning sampling strategies 4 | Adapted from https://github.com/ej0cl6/deep-active-learning 5 | """ 6 | 7 | import os 8 | import copy 9 | import random 10 | import numpy as np 11 | 12 | from sklearn.cluster import KMeans 13 | from sklearn.metrics.pairwise import euclidean_distances 14 | from sklearn.metrics import pairwise_distances 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.optim as optim 20 | from torch.utils.data.sampler import Sampler, SubsetRandomSampler 21 | from torch.utils.data import DataLoader 22 | import logging 23 | 24 | import utils.utils as utils 25 | from solver import get_solver 26 | from model import get_model 27 | from dataset.image_list import ImageList 28 | from .utils import row_norms, kmeans_plus_plus_opt, get_embedding, ActualSequentialSampler 29 | 30 | torch.manual_seed(1234) 31 | torch.cuda.manual_seed(1234) 32 | random.seed(1234) 33 | np.random.seed(1234) 34 | 35 | al_dict = {} 36 | 37 | def register_strategy(name): 38 | def decorator(cls): 39 | al_dict[name] = cls 40 | return cls 41 | 42 | return decorator 43 | 44 | 45 | def get_strategy(sample, *args): 46 | if sample not in al_dict: raise NotImplementedError 47 | return al_dict[sample](*args) 48 | 49 | 50 | class SamplingStrategy: 51 | """ 52 | Sampling Strategy wrapper class 53 | """ 54 | 55 | def __init__(self, src_dset, tgt_dset, source_model, device, num_classes, cfg): 56 | self.src_dset = src_dset 57 | self.tgt_dset = tgt_dset 58 | self.num_classes = num_classes 59 | self.model = copy.deepcopy(source_model) # initialized with source model 60 | self.device = device 61 | self.cfg = cfg 62 | self.discrim = nn.Sequential( 63 | nn.Linear(self.cfg.DATASET.NUM_CLASS, 500), 64 | nn.ReLU(), 65 | nn.Linear(500, 500), 66 | nn.ReLU(), 67 | nn.Linear(500, 2)).to(self.device) # should be initialized by running train_uda 68 | self.idxs_lb = np.zeros(len(self.tgt_dset.train_idx), dtype=bool) 69 | self.solver = None 70 | 
self.lr_scheduler = None 71 | self.opt_discrim = None 72 | self.opt_net_tgt = None 73 | self.query_dset = tgt_dset.get_dsets()[1] # change to query dataset 74 | 75 | def query(self, n, epoch): 76 | pass 77 | 78 | def update(self, idxs_lb): 79 | self.idxs_lb = idxs_lb 80 | 81 | def pred(self, idxs=None, with_emb=False): 82 | if idxs is None: 83 | idxs = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb] 84 | 85 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx[idxs]) 86 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler, num_workers=self.cfg.DATALOADER.NUM_WORKERS, 87 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False) 88 | self.model.eval() 89 | all_log_probs = [] 90 | all_scores = [] 91 | all_embs = [] 92 | with torch.no_grad(): 93 | for batch_idx, (data, target, _) in enumerate(data_loader): 94 | data, target = data.to(self.device), target.to(self.device) 95 | if with_emb: 96 | scores, embs = self.model(data, with_emb=True) 97 | all_embs.append(embs.cpu()) 98 | else: 99 | scores = self.model(data, with_emb=False) 100 | log_probs = nn.LogSoftmax(dim=1)(scores) 101 | all_log_probs.append(log_probs) 102 | all_scores.append(scores) 103 | 104 | all_log_probs = torch.cat(all_log_probs) 105 | all_probs = torch.exp(all_log_probs) 106 | all_scores = torch.cat(all_scores) 107 | if with_emb: 108 | all_embs = torch.cat(all_embs) 109 | return idxs, all_probs, all_log_probs, all_scores, all_embs 110 | else: 111 | return idxs, all_probs, all_log_probs, all_scores 112 | 113 | def train_uda(self, epochs=1): 114 | """ 115 | Unsupervised adaptation of source model to target at round 0 116 | Returns: 117 | Model post adaptation 118 | """ 119 | source = self.cfg.DATASET.SOURCE_DOMAIN 120 | target = self.cfg.DATASET.TARGET_DOMAIN 121 | uda_strat = self.cfg.ADA.UDA 122 | 123 | adapt_dir = os.path.join('checkpoints', 'adapt') 124 | adapt_net_file = os.path.join(adapt_dir, '{}_{}_{}_{}_{}.pth'.format(uda_strat, source, target, 
125 | self.cfg.MODEL.BACKBONE.NAME, self.cfg.TRAINER.MAX_UDA_EPOCHS)) 126 | 127 | if not os.path.exists(adapt_dir): 128 | os.makedirs(adapt_dir) 129 | 130 | if self.cfg.TRAINER.LOAD_FROM_CHECKPOINT and os.path.exists(adapt_net_file): 131 | logging.info('Found pretrained uda checkpoint, loading...') 132 | adapt_model = get_model('AdaptNet', num_cls=self.num_classes, weights_init=adapt_net_file, 133 | model=self.cfg.MODEL.BACKBONE.NAME) 134 | else: 135 | src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader = self.build_loaders() 136 | 137 | src_train_loader = self.src_dset.get_loaders()[0] 138 | target_train_dset = self.tgt_dset.get_dsets()[0] 139 | train_sampler = SubsetRandomSampler(self.tgt_dset.train_idx[self.idxs_lb]) 140 | tgt_sup_loader = torch.utils.data.DataLoader(target_train_dset, sampler=train_sampler, 141 | num_workers=self.cfg.DATALOADER.NUM_WORKERS, \ 142 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False) 143 | tgt_unsup_loader = torch.utils.data.DataLoader(target_train_dset, shuffle=True, 144 | num_workers=self.cfg.DATALOADER.NUM_WORKERS, \ 145 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False) 146 | 147 | logging.info('No pretrained checkpoint found, training...') 148 | adapt_model = get_model('AdaptNet', num_cls=self.num_classes, src_weights_init=self.model, 149 | model=self.cfg.MODEL.BACKBONE.NAME, normalize=self.cfg.MODEL.NORMALIZE, temp=self.cfg.MODEL.TEMP) 150 | opt_net_tgt = utils.get_optim(self.cfg.OPTIM.UDA_NAME, adapt_model.tgt_net.parameters(self.cfg.OPTIM.UDA_LR, self.cfg.OPTIM.BASE_LR_MULT), 151 | lr=self.cfg.OPTIM.UDA_LR) 152 | uda_solver = get_solver(uda_strat, adapt_model.tgt_net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, 153 | joint_sup_loader, opt_net_tgt, False, self.device, self.cfg) 154 | 155 | for epoch in range(epochs): 156 | print("Running uda epoch {}/{}".format(epoch, epochs)) 157 | if uda_strat in ['dann']: 158 | opt_discrim = 
optim.Adadelta(adapt_model.discrim.parameters(), lr=self.cfg.OPTIM.UDA_LR) 159 | uda_solver.solve(epoch, adapt_model.discrim, opt_discrim) 160 | elif uda_strat in ['mme']: 161 | uda_solver.solve(epoch) 162 | else: 163 | logging.info('Warning: no uda training with {}, skipped'.format(uda_strat)) 164 | return self.model 165 | 166 | adapt_model.save(adapt_net_file) 167 | 168 | self.model = adapt_model.tgt_net 169 | return self.model 170 | 171 | def build_loaders(self): 172 | src_loader = self.src_dset.get_loaders()[0] 173 | tgt_loader = self.tgt_dset.get_loaders()[0] 174 | 175 | target_train_dset = self.tgt_dset.get_dsets()[0] 176 | train_sampler = SubsetRandomSampler(self.tgt_dset.train_idx[self.idxs_lb]) 177 | tgt_sup_loader = DataLoader(target_train_dset, sampler=train_sampler, 178 | num_workers=self.cfg.DATALOADER.NUM_WORKERS, \ 179 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False) 180 | train_sampler = SubsetRandomSampler(self.tgt_dset.train_idx[~self.idxs_lb]) 181 | tgt_unsup_loader = DataLoader(target_train_dset, sampler=train_sampler, 182 | num_workers=self.cfg.DATALOADER.NUM_WORKERS, \ 183 | batch_size=self.cfg.DATALOADER.BATCH_SIZE*self.cfg.DATALOADER.TGT_UNSUP_BS_MUL, 184 | drop_last=False) 185 | 186 | # create joint src_tgt_sup loader as commonly used 187 | joint_list = [self.src_dset.train_dataset.samples[_] for _ in self.src_dset.train_idx] + \ 188 | [self.tgt_dset.train_dataset.samples[_] for _ in self.tgt_dset.train_idx[self.idxs_lb]] 189 | 190 | # use source train transform 191 | join_transform = self.src_dset.get_dsets()[0].transform 192 | joint_train_ds = ImageList(joint_list, root=self.cfg.DATASET.ROOT, transform=join_transform) 193 | joint_sup_loader = DataLoader(joint_train_ds, batch_size=self.cfg.DATALOADER.BATCH_SIZE, shuffle=True, 194 | drop_last=False, num_workers=self.cfg.DATALOADER.NUM_WORKERS) 195 | 196 | return src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader 197 | 198 | def train(self, epoch): 199 | 
src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader = self.build_loaders() 200 | 201 | if self.opt_net_tgt is None: 202 | self.opt_net_tgt = utils.get_optim(self.cfg.OPTIM.NAME, self.model.parameters(self.cfg.OPTIM.ADAPT_LR, 203 | self.cfg.OPTIM.BASE_LR_MULT), lr=self.cfg.OPTIM.ADAPT_LR, weight_decay=0.00001) 204 | 205 | if self.opt_discrim is None: 206 | self.opt_discrim = utils.get_optim(self.cfg.OPTIM.NAME, self.discrim.parameters(), lr=self.cfg.OPTIM.ADAPT_LR, weight_decay=0) 207 | 208 | solver = get_solver(self.cfg.ADA.DA, self.model, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, 209 | joint_sup_loader, self.opt_net_tgt, True, self.device, self.cfg) 210 | 211 | if self.cfg.ADA.DA in ['dann']: 212 | solver.solve(epoch, self.discrim, self.opt_discrim) 213 | elif self.cfg.ADA.DA in ['LAA', 'RAA']: 214 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx) 215 | seq_query_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler, 216 | num_workers=self.cfg.DATALOADER.NUM_WORKERS, 217 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False) 218 | solver.solve(epoch, seq_query_loader) 219 | else: 220 | solver.solve(epoch) 221 | 222 | 223 | return self.model 224 | 225 | 226 | @register_strategy('random') 227 | class RandomSampling(SamplingStrategy): 228 | """ 229 | Uniform sampling 230 | """ 231 | 232 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg): 233 | super(RandomSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg) 234 | 235 | def query(self, n, epoch): 236 | return np.random.choice(np.where(self.idxs_lb == 0)[0], n, replace=False) 237 | 238 | 239 | 240 | @register_strategy('entropy') 241 | class EntropySampling(SamplingStrategy): 242 | """ 243 | Implements entropy based sampling 244 | """ 245 | 246 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg): 247 | super(EntropySampling, self).__init__(src_dset, tgt_dset, model, 
device, num_classes, cfg) 248 | 249 | def query(self, n, epoch): 250 | idxs_unlabeled, all_probs, all_log_probs, _ = self.pred() 251 | # Compute entropy 252 | entropy = -(all_probs * all_log_probs).sum(1) 253 | q_idxs = (entropy).sort(descending=True)[1][:n] 254 | q_idxs = q_idxs.cpu().numpy() 255 | return idxs_unlabeled[q_idxs] 256 | 257 | @register_strategy('margin') 258 | class MarginSampling(SamplingStrategy): 259 | """ 260 | Implements margin based sampling 261 | """ 262 | 263 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg): 264 | super(MarginSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg) 265 | 266 | def query(self, n, epoch): 267 | idxs_unlabeled, all_probs, _, _ = self.pred() 268 | # Compute BvSB margin 269 | top2 = torch.topk(all_probs, 2).values 270 | BvSB_scores = 1-(top2[:,0] - top2[:,1]) # use minus for descending sorting 271 | q_idxs = (BvSB_scores).sort(descending=True)[1][:n] 272 | q_idxs = q_idxs.cpu().numpy() 273 | return idxs_unlabeled[q_idxs] 274 | 275 | 276 | @register_strategy('leastConfidence') 277 | class LeastConfidenceSampling(SamplingStrategy): 278 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg): 279 | super(LeastConfidenceSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg) 280 | 281 | def query(self, n, epoch): 282 | idxs_unlabeled, all_probs, _, _ = self.pred() 283 | confidences = -all_probs.max(1)[0] # use minus for descending sorting 284 | q_idxs = (confidences).sort(descending=True)[1][:n] 285 | q_idxs = q_idxs.cpu().numpy() 286 | return idxs_unlabeled[q_idxs] 287 | 288 | 289 | @register_strategy('coreset') 290 | class CoreSetSampling(SamplingStrategy): 291 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg): 292 | super(CoreSetSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg) 293 | 294 | def furthest_first(self, X, X_lb, n): 295 | m = np.shape(X)[0] 296 | if np.shape(X_lb)[0] 
== 0: 297 | min_dist = np.tile(float("inf"), m) 298 | else: 299 | dist_ctr = pairwise_distances(X, X_lb) 300 | min_dist = np.amin(dist_ctr, axis=1) 301 | 302 | idxs = [] 303 | 304 | for i in range(n): 305 | idx = min_dist.argmax() 306 | idxs.append(idx) 307 | dist_new_ctr = pairwise_distances(X, X[[idx], :]) 308 | for j in range(m): 309 | min_dist[j] = min(min_dist[j], dist_new_ctr[j, 0]) 310 | 311 | return idxs 312 | 313 | def query(self, n, epoch): 314 | idxs = np.arange(len(self.tgt_dset.train_idx)) 315 | _, _, _, _, all_embs = self.pred(idxs=idxs, with_emb=True) # embeddings of ALL target samples 316 | all_embs = all_embs.numpy() 317 | q_idxs = self.furthest_first(all_embs[~self.idxs_lb, :], all_embs[self.idxs_lb, :], n) 318 | return idxs[~self.idxs_lb][q_idxs] # q_idxs index the unlabeled subset; map back to target train indices 319 | 320 | 321 | @register_strategy('AADA') 322 | class AADASampling(SamplingStrategy): 323 | """ 324 | Implements Active Adversarial Domain Adaptation (https://arxiv.org/abs/1904.07848) 325 | """ 326 | 327 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg): 328 | super(AADASampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg) 329 | 330 | def query(self, n, epoch): 331 | """ 332 | s(x) = [(1 - G_d*(G_f(x))) / G_d*(G_f(x))] [Diversity] * H(G_y(G_f(x))) [Uncertainty] 333 | """ 334 | self.model.eval() 335 | idxs_unlabeled = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb] 336 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx[idxs_unlabeled]) 337 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler, num_workers=4, batch_size=64, 338 | drop_last=False) 339 | 340 | # Get diversity and entropy 341 | all_log_probs, all_scores = [], [] 342 | with torch.no_grad(): 343 | for batch_idx, (data, target, _) in enumerate(data_loader): 344 | data, target = data.to(self.device), target.to(self.device) 345 | scores = self.model(data) 346 | log_probs 
= nn.LogSoftmax(dim=1)(scores) 347 | all_scores.append(scores) 348 | all_log_probs.append(log_probs)
349 | 350 | all_scores = torch.cat(all_scores) 351 | all_log_probs = torch.cat(all_log_probs) 352 | 353 | all_probs = torch.exp(all_log_probs) 354 | disc_scores = nn.Softmax(dim=1)(self.discrim(all_scores)) 355 | # Compute diversity 356 | self.D = torch.div(disc_scores[:, 0], disc_scores[:, 1]) 357 | # Compute entropy 358 | self.E = -(all_probs * all_log_probs).sum(1) 359 | scores = (self.D * self.E).sort(descending=True)[1] 360 | # Sample from top-2 % instances, as recommended by authors 361 | top_N = max(int(len(scores) * 0.02), n) 362 | q_idxs = np.random.choice(scores[:top_N].cpu().numpy(), n, replace=False) 363 | 364 | return idxs_unlabeled[q_idxs] 365 | 366 | 367 | @register_strategy('BADGE') 368 | class BADGESampling(SamplingStrategy): 369 | """ 370 | Implements BADGE: Batch Active Learning by Diverse, Uncertain Gradient Lower Bounds (https://arxiv.org/abs/1906.03671) 371 | """ 372 | 373 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg): 374 | super(BADGESampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg) 375 | 376 | def query(self, n, epoch): 377 | idxs_unlabeled = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb] 378 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx[idxs_unlabeled]) 379 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler, num_workers=self.cfg.DATALOADER.NUM_WORKERS, 380 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False) 381 | self.model.eval() 382 | 383 | if 'LeNet' in self.cfg.MODEL.BACKBONE.NAME: 384 | emb_dim = 500 385 | elif 'ResNet34' in self.cfg.MODEL.BACKBONE.NAME: 386 | emb_dim = 512 387 | elif 'ResNet50' in self.cfg.MODEL.BACKBONE.NAME: 388 | emb_dim = 256 389 | 390 | tgt_emb = torch.zeros([len(data_loader.sampler), self.num_classes]) 391 | tgt_pen_emb = torch.zeros([len(data_loader.sampler), emb_dim]) 392 | tgt_lab = torch.zeros(len(data_loader.sampler)) 393 | tgt_preds = torch.zeros(len(data_loader.sampler)) 394 | 
batch_sz = self.cfg.DATALOADER.BATCH_SIZE 395 | 396 | with torch.no_grad(): 397 | for batch_idx, (data, target, _) in enumerate(data_loader): 398 | data, target = data.to(self.device), target.to(self.device) 399 | e1, e2 = self.model(data, with_emb=True) 400 | tgt_pen_emb[batch_idx * batch_sz:batch_idx * batch_sz + min(batch_sz, e2.shape[0]), :] = e2.cpu() 401 | tgt_emb[batch_idx * batch_sz:batch_idx * batch_sz + min(batch_sz, e1.shape[0]), :] = e1.cpu() 402 | tgt_lab[batch_idx * batch_sz:batch_idx * batch_sz + min(batch_sz, e1.shape[0])] = target 403 | tgt_preds[batch_idx * batch_sz:batch_idx * batch_sz + min(batch_sz, e1.shape[0])] = e1.argmax(dim=1, 404 | keepdim=True).squeeze() 405 | 406 | # Compute uncertainty gradient 407 | tgt_scores = nn.Softmax(dim=1)(tgt_emb) 408 | tgt_scores_delta = torch.zeros_like(tgt_scores) 409 | tgt_scores_delta[torch.arange(len(tgt_scores_delta)), tgt_preds.long()] = 1 410 | 411 | # Uncertainty embedding 412 | badge_uncertainty = (tgt_scores - tgt_scores_delta) 413 | 414 | # Seed with maximum uncertainty example 415 | max_norm = row_norms(badge_uncertainty.cpu().numpy()).argmax() 416 | 417 | _, q_idxs = kmeans_plus_plus_opt(badge_uncertainty.cpu().numpy(), tgt_pen_emb.cpu().numpy(), n, 418 | init=[max_norm]) 419 | 420 | return idxs_unlabeled[q_idxs] 421 | 422 | 423 | @register_strategy('kmeans') 424 | class KmeansSampling(SamplingStrategy): 425 | """ 426 | Implements plain K-means sampling over embeddings (CLUE without uncertainty weighting) 427 | """ 428 | 429 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg): 430 | super(KmeansSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg) 431 | 432 | def query(self, n, epoch): 433 | idxs_unlabeled, _, _, _, all_embs = self.pred(with_emb=True) 434 | all_embs = all_embs.numpy() 435 | 436 | # Run unweighted K-means over embeddings 437 | km = KMeans(n_clusters=n) 438 | km.fit(all_embs) 439 | 440 | # use below code to match CLUE 
implementation 441 | # Find nearest
neighbors to inferred centroids 442 | dists = euclidean_distances(km.cluster_centers_, all_embs) 443 | sort_idxs = dists.argsort(axis=1) 444 | q_idxs = [] 445 | ax, rem = 0, n 446 | while rem > 0: 447 | q_idxs.extend(list(sort_idxs[:, ax][:rem])) 448 | q_idxs = list(set(q_idxs)) 449 | rem = n - len(q_idxs) 450 | ax += 1 451 | 452 | return idxs_unlabeled[q_idxs] 453 | 454 | 455 | @register_strategy('CLUE') 456 | class CLUESampling(SamplingStrategy): 457 | """ 458 | Implements CLUE: CLustering via Uncertainty-weighted Embeddings 459 | """ 460 | 461 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg): 462 | super(CLUESampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg) 463 | self.random_state = np.random.RandomState(1234) 464 | self.T = 0.1 465 | 466 | def query(self, n, epoch): 467 | idxs_unlabeled = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb] 468 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx[idxs_unlabeled]) 469 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler, num_workers=self.cfg.DATALOADER.NUM_WORKERS, \ 470 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False) 471 | self.model.eval() 472 | 473 | if 'LeNet' in self.cfg.MODEL.BACKBONE.NAME: 474 | emb_dim = 500 475 | elif 'ResNet34' in self.cfg.MODEL.BACKBONE.NAME: 476 | emb_dim = 512 477 | elif 'ResNet50' in self.cfg.MODEL.BACKBONE.NAME: 478 | emb_dim = 256 479 | 480 | # Get embedding of target instances 481 | tgt_emb, tgt_lab, tgt_preds, tgt_pen_emb = get_embedding(self.model, data_loader, self.device, 482 | self.num_classes, \ 483 | self.cfg, with_emb=True, emb_dim=emb_dim) 484 | tgt_pen_emb = tgt_pen_emb.cpu().numpy() 485 | tgt_scores = torch.softmax(tgt_emb / self.T, dim=-1) 486 | tgt_scores += 1e-8 487 | sample_weights = -(tgt_scores * torch.log(tgt_scores)).sum(1).cpu().numpy() 488 | 489 | # Run weighted K-means over embeddings 490 | km = KMeans(n) 491 | km.fit(tgt_pen_emb, 
sample_weight=sample_weights) 492 | 493 | # Find nearest neighbors to inferred centroids 494 | dists = euclidean_distances(km.cluster_centers_, tgt_pen_emb) 495 | sort_idxs = dists.argsort(axis=1) 496 | q_idxs = [] 497 | ax, rem = 0, n 498 | while rem > 0: 499 | q_idxs.extend(list(sort_idxs[:, ax][:rem])) 500 | q_idxs = list(set(q_idxs)) 501 | rem = n - len(q_idxs) 502 | ax += 1 503 | 504 | return idxs_unlabeled[q_idxs] 505 | 506 | -------------------------------------------------------------------------------- /active/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | import copy 5 | 6 | def ActualSequentialLoader(subsetRandomLoader, indices=None, transform=None, batch_size=None): 7 | indices = indices if indices is not None else subsetRandomLoader.sampler.indices 8 | train_sampler = ActualSequentialSampler(indices) 9 | dataset = copy.deepcopy(subsetRandomLoader.dataset) 10 | if transform is not None: 11 | dataset.transform = transform 12 | 13 | batch_size = batch_size if batch_size is not None else subsetRandomLoader.batch_size 14 | actualSequentialLoader = torch.utils.data.DataLoader(dataset, sampler=train_sampler, 15 | num_workers=subsetRandomLoader.num_workers, 16 | batch_size=batch_size, drop_last=False) 17 | return actualSequentialLoader 18 | 19 | class ActualSequentialSampler(Sampler): 20 | r"""Samples elements sequentially, always in the same order. 21 | 22 | Arguments: 23 | data_source (Dataset): dataset to sample from 24 | """ 25 | 26 | def __init__(self, data_source): 27 | self.data_source = data_source 28 | 29 | def __iter__(self): 30 | return iter(self.data_source) 31 | 32 | def __len__(self): 33 | return len(self.data_source) 34 | 35 | def row_norms(X, squared=False): 36 | """Row-wise (squared) Euclidean norm of X. 
37 | Equivalent to np.sqrt((X * X).sum(axis=1)), but does not create an 38 | X.shape-sized temporary. Dense arrays only (sparse input is not supported). 39 | Performs no input validation. 40 | Parameters 41 | ---------- 42 | X : array_like 43 | The input array 44 | squared : bool, optional (default = False) 45 | If True, return squared norms. 46 | Returns 47 | ------- 48 | array_like 49 | The row-wise (squared) Euclidean norm of X. 50 | """ 51 | norms = np.einsum('ij,ij->i', X, X) 52 | 53 | if not squared: 54 | np.sqrt(norms, norms) 55 | return norms 56 | 57 | def outer_product_opt(c1, d1, c2, d2): 58 | """Computes squared Euclidean (Frobenius) distances between the outer products c1[i] x d1[i] and c2[j] x d2[j] without evaluating / storing the outer products 59 | """ 60 | B1, B2 = c1.shape[0], c2.shape[0] 61 | t1 = np.matmul(np.matmul(c1[:, None, :], c1[:, None, :].swapaxes(2, 1)), np.matmul(d1[:, None, :], d1[:, None, :].swapaxes(2, 1))) 62 | t2 = np.matmul(np.matmul(c2[:, None, :], c2[:, None, :].swapaxes(2, 1)), np.matmul(d2[:, None, :], d2[:, None, :].swapaxes(2, 1))) 63 | t3 = np.matmul(c1, c2.T) * np.matmul(d1, d2.T) 64 | t1 = t1.reshape(B1, 1).repeat(B2, axis=1) 65 | t2 = t2.reshape(1, B2).repeat(B1, axis=0) 66 | return t1 + t2 - 2*t3 67 | 68 | def kmeans_plus_plus_opt(X1, X2, n_clusters, init=[0], random_state=np.random.RandomState(1234), n_local_trials=None): 69 | """Init n_clusters seeds according to k-means++ (adapted from scikit-learn source code) 70 | Parameters 71 | ---------- 72 | X1, X2 : ndarray 73 | The two factor matrices (same number of rows); the points to pick seeds 74 | for are the row-wise outer products X1[i] x X2[i]. 75 | n_clusters : integer 76 | The number of seeds to choose 77 | init : list 78 | List of indices of points already picked 79 | random_state : RandomState instance 80 | The generator used to sample candidate centers. It must be a 81 | np.random.RandomState, since random_sample() is called on it directly. 82 | See :term:`Glossary <random_state>`.
83 | n_local_trials : integer, optional 84 | The number of seeding trials for each center (except the first), 85 | of which the one reducing inertia the most is greedily chosen. 86 | Set to None to make the number of trials depend logarithmically 87 | on the number of seeds (2+log(k)); this is the default. 88 | Notes 89 | ----- 90 | Selects initial cluster centers for k-means clustering in a smart way 91 | to speed up convergence. See: Arthur, D. and Vassilvitskii, S. 92 | "k-means++: the advantages of careful seeding". ACM-SIAM symposium 93 | on Discrete algorithms. 2007 94 | Version ported from http://www.stanford.edu/~darthur/kMeansppTest.zip, 95 | which is the implementation used in the aforementioned paper. 96 | """ 97 | 98 | n_samples, n_feat1 = X1.shape 99 | _, n_feat2 = X2.shape 100 | # x_squared_norms = row_norms(X, squared=True) 101 | centers1 = np.empty((n_clusters+len(init)-1, n_feat1), dtype=X1.dtype) 102 | centers2 = np.empty((n_clusters+len(init)-1, n_feat2), dtype=X2.dtype) 103 | 104 | idxs = np.empty((n_clusters+len(init)-1,), dtype=np.int64) 105 | 106 | # Set the number of local seeding trials if none is given 107 | if n_local_trials is None: 108 | # This is what Arthur/Vassilvitskii tried, but did not report 109 | # specific results for other than mentioning in the conclusion 110 | # that it helped.
111 | n_local_trials = 2 + int(np.log(n_clusters)) 112 | 113 | # Use the given init indices as the first center(s); no random pick here 114 | center_id = init 115 | 116 | centers1[:len(init)] = X1[center_id] 117 | centers2[:len(init)] = X2[center_id] 118 | idxs[:len(init)] = center_id 119 | 120 | # Initialize list of closest distances and calculate current potential 121 | distance_to_candidates = outer_product_opt(centers1[:len(init)], centers2[:len(init)], X1, X2).reshape(len(init), -1) 122 | 123 | candidates_pot = distance_to_candidates.sum(axis=1) 124 | best_candidate = np.argmin(candidates_pot) 125 | current_pot = candidates_pot[best_candidate] 126 | closest_dist_sq = distance_to_candidates[best_candidate] 127 | 128 | # Pick the remaining n_clusters-1 points 129 | for c in range(len(init), len(init)+n_clusters-1): 130 | # Choose center candidates by sampling with probability proportional 131 | # to the squared distance to the closest existing center 132 | rand_vals = random_state.random_sample(n_local_trials) * current_pot 133 | candidate_ids = np.searchsorted(closest_dist_sq.cumsum(), 134 | rand_vals) 135 | # XXX: numerical imprecision can result in a candidate_id out of range 136 | np.clip(candidate_ids, None, closest_dist_sq.size - 1, 137 | out=candidate_ids) 138 | 139 | # Compute distances to center candidates 140 | distance_to_candidates = outer_product_opt(X1[candidate_ids], X2[candidate_ids], X1, X2).reshape(len(candidate_ids), -1) 141 | 142 | # Update closest distances squared and potential for each candidate 143 | np.minimum(closest_dist_sq, distance_to_candidates, 144 | out=distance_to_candidates) 145 | candidates_pot = distance_to_candidates.sum(axis=1) 146 | 147 | # Decide which candidate is the best 148 | best_candidate = np.argmin(candidates_pot) 149 | current_pot = candidates_pot[best_candidate] 150 | closest_dist_sq = distance_to_candidates[best_candidate] 151 | best_candidate = candidate_ids[best_candidate] 152 | 153 | idxs[c] = best_candidate 154 | 155 | return None, idxs[len(init)-1:]
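`outer_product_opt` lets the BADGE seeding in `kmeans_plus_plus_opt` compare gradient embeddings without materialising them. Here a gradient embedding is taken to be the outer product of the uncertainty vector (`tgt_scores - tgt_scores_delta`, i.e. softmax minus one-hot of the argmax) and the penultimate embedding. The helper relies on the identity `||a (x) b - c (x) d||^2 = ||a||^2 ||b||^2 + ||c||^2 ||d||^2 - 2 (a.c)(b.d)`, which is what `t1 + t2 - 2*t3` computes. The sketch below checks that identity numerically against an explicit construction. The function names, sizes and random inputs are hypothetical illustrations, not part of this repository.

```python
import numpy as np


def softmax(logits):
    # Numerically stable row-wise softmax
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def badge_uncertainty(logits):
    # p - onehot(argmax p), mirroring tgt_scores - tgt_scores_delta in BADGESampling
    p = softmax(logits)
    onehot = np.zeros_like(p)
    onehot[np.arange(len(p)), p.argmax(axis=1)] = 1.0
    return p - onehot


def outer_product_sq_dists(c1, d1, c2, d2):
    # Squared Frobenius distance between c1[i] (x) d1[i] and c2[j] (x) d2[j],
    # computed from inner products only (no outer product is built).
    n1 = (c1 * c1).sum(axis=1) * (d1 * d1).sum(axis=1)  # ||c1[i]||^2 * ||d1[i]||^2
    n2 = (c2 * c2).sum(axis=1) * (d2 * d2).sum(axis=1)  # ||c2[j]||^2 * ||d2[j]||^2
    cross = (c1 @ c2.T) * (d1 @ d2.T)                   # (c1[i].c2[j]) * (d1[i].d2[j])
    return n1[:, None] + n2[None, :] - 2.0 * cross


def explicit_sq_dists(c1, d1, c2, d2):
    # Reference: flatten each outer product and take plain squared distances
    g1 = np.einsum('bi,bj->bij', c1, d1).reshape(len(c1), -1)
    g2 = np.einsum('bi,bj->bij', c2, d2).reshape(len(c2), -1)
    diff = g1[:, None, :] - g2[None, :, :]
    return (diff ** 2).sum(axis=-1)


rng = np.random.default_rng(0)
num_classes, emb_dim = 5, 8            # hypothetical sizes
logits = rng.normal(size=(6, num_classes))
emb = rng.normal(size=(6, emb_dim))    # stands in for penultimate embeddings
unc = badge_uncertainty(logits)

fast = outer_product_sq_dists(unc[:2], emb[:2], unc, emb)
slow = explicit_sq_dists(unc[:2], emb[:2], unc, emb)
print(fast.shape, np.allclose(fast, slow))  # (2, 6) True
```

The explicit reference allocates a `B1 x B2 x (num_classes * emb_dim)` array. The inner-product form only needs `B1 x B2` intermediates, which is why it scales to the full unlabeled target pool.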
156 | 157 | def get_embedding(model, loader, device, num_classes, cfg, with_emb=False, emb_dim=512): 158 | model.eval() 159 | embedding = torch.zeros([len(loader.sampler), num_classes]) 160 | embedding_pen = torch.zeros([len(loader.sampler), emb_dim]) 161 | labels = torch.zeros(len(loader.sampler)) 162 | preds = torch.zeros(len(loader.sampler)) 163 | batch_sz = cfg.DATALOADER.BATCH_SIZE 164 | with torch.no_grad(): 165 | for batch_idx, (data, target, _) in enumerate(loader): 166 | data, target = data.to(device), target.to(device) 167 | if with_emb: 168 | e1, e2 = model(data, with_emb=True) 169 | embedding_pen[batch_idx*batch_sz:batch_idx*batch_sz + min(batch_sz, e2.shape[0]), :] = e2.cpu() 170 | else: 171 | e1 = model(data, with_emb=False) 172 | 173 | embedding[batch_idx*batch_sz:batch_idx*batch_sz + min(batch_sz, e1.shape[0]), :] = e1.cpu() 174 | labels[batch_idx*batch_sz:batch_idx*batch_sz + min(batch_sz, e1.shape[0])] = target 175 | preds[batch_idx*batch_sz:batch_idx*batch_sz + min(batch_sz, e1.shape[0])] = e1.argmax(dim=1, keepdim=True).squeeze() 176 | 177 | return embedding, labels, preds, embedding_pen 178 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg 2 | 3 | 4 | def get_cfg_default(): 5 | return cfg.clone() 6 | -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | ########################### 4 | # Config definition 5 | ########################### 6 | 7 | _C = CN() 8 | 9 | _C.SEED = 0 10 | _C.NOTE = '' 11 | 12 | ########################### 13 | # Dataset 14 | ########################### 15 | _C.DATASET = CN() 16 | # Directory where datasets are stored 17 | _C.DATASET.ROOT = '' 18 | _C.DATASET.NAME = '' 19 | # 
List of domains 20 | _C.DATASET.SOURCE_DOMAINS = [] 21 | _C.DATASET.TARGET_DOMAINS = [] 22 | _C.DATASET.SOURCE_DOMAIN = '' 23 | _C.DATASET.TARGET_DOMAIN = '' 24 | _C.DATASET.SOURCE_VALID_TYPE = 'val' 25 | _C.DATASET.SOURCE_VALID_RATIO = 1.0 26 | _C.DATASET.SOURCE_TRANSFORMS = ('Resize','RandomCrop','Normalize') 27 | _C.DATASET.TARGET_TRANSFORMS = ('Resize','RandomCrop','Normalize') 28 | _C.DATASET.QUERY_TRANSFORMS = ('Resize','CenterCrop','Normalize') 29 | _C.DATASET.TEST_TRANSFORMS = ('Resize','CenterCrop','Normalize') 30 | _C.DATASET.RAND_TRANSFORMS = 'rand_transform' 31 | _C.DATASET.NUM_CLASS = 12 32 | 33 | ########################### 34 | # Dataloader 35 | ########################### 36 | _C.DATALOADER = CN() 37 | _C.DATALOADER.NUM_WORKERS = 4 38 | _C.DATALOADER.BATCH_SIZE = 32 39 | _C.DATALOADER.TGT_UNSUP_BS_MUL = 1 40 | 41 | ########################### 42 | # Model 43 | ########################### 44 | _C.MODEL = CN() 45 | # Path to model weights (for initialization) 46 | _C.MODEL.INIT_WEIGHTS = '' 47 | _C.MODEL.BACKBONE = CN() 48 | _C.MODEL.BACKBONE.NAME = 'ResNet50Fc' 49 | _C.MODEL.BACKBONE.PRETRAINED = True 50 | _C.MODEL.NORMALIZE = False 51 | _C.MODEL.TEMP = 0.05 52 | 53 | ########################### 54 | # Optimization 55 | ########################### 56 | _C.OPTIM = CN() 57 | _C.OPTIM.NAME = 'Adadelta' 58 | _C.OPTIM.SOURCE_NAME = 'Adadelta' 59 | _C.OPTIM.UDA_NAME = 'Adadelta' 60 | _C.OPTIM.SOURCE_LR = 0.1 61 | _C.OPTIM.UDA_LR = 0.1 62 | _C.OPTIM.ADAPT_LR = 0.1 63 | _C.OPTIM.BASE_LR_MULT = 0.1 64 | 65 | ########################### 66 | # Trainer specifics 67 | ########################### 68 | _C.TRAINER = CN() 69 | _C.TRAINER.LOAD_FROM_CHECKPOINT = True 70 | _C.TRAINER.TRAIN_ON_SOURCE = True 71 | _C.TRAINER.MAX_SOURCE_EPOCHS = 20 72 | _C.TRAINER.MAX_UDA_EPOCHS = 20 73 | _C.TRAINER.MAX_EPOCHS = 20 74 | _C.TRAINER.EVAL_ACC = True 75 | _C.TRAINER.ITER_PER_EPOCH = None 76 | 77 | ########################### 78 | # Active DA 79 | ########################### 80 
| _C.ADA = CN() 81 | _C.ADA.TASKS = None 82 | _C.ADA.BUDGET = 0.05 83 | _C.ADA.ROUND = 5 84 | _C.ADA.ROUNDS = None 85 | _C.ADA.UDA = 'dann' 86 | _C.ADA.DA = 'ft' 87 | _C.ADA.AL = 'LAS' 88 | _C.ADA.SRC_SUP_WT = 1.0 89 | _C.ADA.TGT_SUP_WT = 1.0 90 | _C.ADA.UNSUP_WT = 0.1 91 | _C.ADA.CEN_WT = 0.1 92 | 93 | ########################### 94 | # LADA 95 | ########################### 96 | _C.LADA = CN() 97 | _C.LADA.S_K = 10 98 | _C.LADA.S_M = 10 99 | _C.LADA.S_PROP_ITER = 1 100 | _C.LADA.S_PROP_COEF = 1.0 101 | _C.LADA.A_K = 10 102 | _C.LADA.A_TH = 0.9 103 | _C.LADA.A_RAND_NUM = 1 104 | -------------------------------------------------------------------------------- /configs/office31.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | ROOT: 'data/' 3 | NAME: 'office31' 4 | SOURCE_DOMAINS: ['webcam', 'amazon', 'dslr'] 5 | TARGET_DOMAINS: ['webcam', 'amazon', 'dslr'] 6 | NUM_CLASS: 31 7 | 8 | DATALOADER: 9 | BATCH_SIZE: 32 10 | 11 | OPTIM: 12 | NAME: 'Adadelta' 13 | SOURCE_LR: 0.1 14 | BASE_LR_MULT: 0.1 15 | 16 | TRAINER: 17 | MAX_EPOCHS: 40 18 | TRAIN_ON_SOURCE : False 19 | MAX_UDA_EPOCHS: 0 20 | 21 | ADA: 22 | DA : 'ft' 23 | AL : 'random' 24 | ROUNDS : [10, 12, 14, 16, 18] 25 | 26 | LADA: 27 | S_K : 5 28 | S_M : 10 29 | A_K : 5 30 | 31 | SEED: 0 # 0,1,2,3,4 for five random experiments 32 | 33 | -------------------------------------------------------------------------------- /configs/officehome.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | ROOT: 'data/' 3 | NAME: 'officehome' 4 | SOURCE_DOMAINS: ['Art', 'Clipart', 'Product', 'RealWorld'] 5 | TARGET_DOMAINS: ['Art', 'Clipart', 'Product', 'RealWorld'] 6 | NUM_CLASS: 65 7 | 8 | DATALOADER: 9 | BATCH_SIZE: 32 10 | 11 | OPTIM: 12 | NAME: 'Adadelta' 13 | SOURCE_LR: 0.1 14 | BASE_LR_MULT: 0.1 15 | 16 | TRAINER: 17 | MAX_EPOCHS: 40 18 | TRAIN_ON_SOURCE : False 19 | MAX_UDA_EPOCHS: 0 20 | 21 | ADA: 22 | DA : 'ft' 23 | AL : 
'random' 24 | ROUNDS : [10, 12, 14, 16, 18] 25 | 26 | SEED: 0 # 0,1,2,3,4 for five random experiments 27 | 28 | -------------------------------------------------------------------------------- /configs/officehome_RSUT.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | ROOT: 'data/' 3 | NAME: 'officehome_RSUT' 4 | SOURCE_DOMAINS: ['Clipart_RS', 'Product_RS', 'RealWorld_RS'] 5 | TARGET_DOMAINS: ['Clipart_UT', 'Product_UT', 'RealWorld_UT'] 6 | NUM_CLASS: 65 7 | 8 | DATALOADER: 9 | BATCH_SIZE: 32 10 | 11 | OPTIM: 12 | NAME: 'Adadelta' 13 | SOURCE_LR: 0.1 14 | BASE_LR_MULT: 0.1 15 | 16 | TRAINER: 17 | MAX_EPOCHS: 40 18 | TRAIN_ON_SOURCE : False 19 | MAX_UDA_EPOCHS: 0 20 | 21 | ADA: 22 | TASKS: [['Clipart_RS','Product_UT'], ['Clipart_RS','RealWorld_UT'], ['Product_RS','Clipart_UT'], ['Product_RS','RealWorld_UT'], ['RealWorld_RS','Clipart_UT'], ['RealWorld_RS','Product_UT']] 23 | DA : 'ft' 24 | AL : 'random' 25 | ROUNDS : [10, 12, 14, 16, 18] 26 | 27 | SEED: 0 # 0,1,2,3,4 for five random experiments 28 | 29 | -------------------------------------------------------------------------------- /configs/visda.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | ROOT: 'data/' 3 | NAME: 'visda' 4 | SOURCE_DOMAINS: ['train'] 5 | TARGET_DOMAINS: ['validation'] 6 | NUM_CLASS: 12 7 | 8 | DATALOADER: 9 | BATCH_SIZE: 32 10 | 11 | OPTIM: 12 | NAME: 'Adadelta' 13 | SOURCE_LR: 0.1 14 | BASE_LR_MULT: 0.1 15 | 16 | TRAINER: 17 | MAX_EPOCHS: 10 18 | TRAIN_ON_SOURCE : True 19 | MAX_UDA_EPOCHS: 0 20 | MAX_SOURCE_EPOCHS : 1 21 | 22 | ADA: 23 | DA : 'ft' 24 | AL : 'random' 25 | ROUNDS : [0,2,4,6,8] 26 | 27 | LADA: 28 | S_K : 10 29 | S_M : 5 30 | A_K : 10 31 | 32 | SEED: 0 # 0,1,2,3,4 for five random experiments 33 | 34 | -------------------------------------------------------------------------------- /data/image_list/office31/dslr.txt: 
-------------------------------------------------------------------------------- 1 | office31/dslr/images/calculator/frame_0001.jpg 5 2 | office31/dslr/images/calculator/frame_0002.jpg 5 3 | office31/dslr/images/calculator/frame_0003.jpg 5 4 | office31/dslr/images/calculator/frame_0004.jpg 5 5 | office31/dslr/images/calculator/frame_0005.jpg 5 6 | office31/dslr/images/calculator/frame_0006.jpg 5 7 | office31/dslr/images/calculator/frame_0007.jpg 5 8 | office31/dslr/images/calculator/frame_0008.jpg 5 9 | office31/dslr/images/calculator/frame_0009.jpg 5 10 | office31/dslr/images/calculator/frame_0010.jpg 5 11 | office31/dslr/images/calculator/frame_0011.jpg 5 12 | office31/dslr/images/calculator/frame_0012.jpg 5 13 | office31/dslr/images/ring_binder/frame_0001.jpg 24 14 | office31/dslr/images/ring_binder/frame_0002.jpg 24 15 | office31/dslr/images/ring_binder/frame_0003.jpg 24 16 | office31/dslr/images/ring_binder/frame_0004.jpg 24 17 | office31/dslr/images/ring_binder/frame_0005.jpg 24 18 | office31/dslr/images/ring_binder/frame_0006.jpg 24 19 | office31/dslr/images/ring_binder/frame_0007.jpg 24 20 | office31/dslr/images/ring_binder/frame_0008.jpg 24 21 | office31/dslr/images/ring_binder/frame_0009.jpg 24 22 | office31/dslr/images/ring_binder/frame_0010.jpg 24 23 | office31/dslr/images/printer/frame_0001.jpg 21 24 | office31/dslr/images/printer/frame_0002.jpg 21 25 | office31/dslr/images/printer/frame_0003.jpg 21 26 | office31/dslr/images/printer/frame_0004.jpg 21 27 | office31/dslr/images/printer/frame_0005.jpg 21 28 | office31/dslr/images/printer/frame_0006.jpg 21 29 | office31/dslr/images/printer/frame_0007.jpg 21 30 | office31/dslr/images/printer/frame_0008.jpg 21 31 | office31/dslr/images/printer/frame_0009.jpg 21 32 | office31/dslr/images/printer/frame_0010.jpg 21 33 | office31/dslr/images/printer/frame_0011.jpg 21 34 | office31/dslr/images/printer/frame_0012.jpg 21 35 | office31/dslr/images/printer/frame_0013.jpg 21 36 | 
office31/dslr/images/printer/frame_0014.jpg 21 37 | office31/dslr/images/printer/frame_0015.jpg 21 38 | office31/dslr/images/keyboard/frame_0001.jpg 11 39 | office31/dslr/images/keyboard/frame_0002.jpg 11 40 | office31/dslr/images/keyboard/frame_0003.jpg 11 41 | office31/dslr/images/keyboard/frame_0004.jpg 11 42 | office31/dslr/images/keyboard/frame_0005.jpg 11 43 | office31/dslr/images/keyboard/frame_0006.jpg 11 44 | office31/dslr/images/keyboard/frame_0007.jpg 11 45 | office31/dslr/images/keyboard/frame_0008.jpg 11 46 | office31/dslr/images/keyboard/frame_0009.jpg 11 47 | office31/dslr/images/keyboard/frame_0010.jpg 11 48 | office31/dslr/images/scissors/frame_0001.jpg 26 49 | office31/dslr/images/scissors/frame_0002.jpg 26 50 | office31/dslr/images/scissors/frame_0003.jpg 26 51 | office31/dslr/images/scissors/frame_0004.jpg 26 52 | office31/dslr/images/scissors/frame_0005.jpg 26 53 | office31/dslr/images/scissors/frame_0006.jpg 26 54 | office31/dslr/images/scissors/frame_0007.jpg 26 55 | office31/dslr/images/scissors/frame_0008.jpg 26 56 | office31/dslr/images/scissors/frame_0009.jpg 26 57 | office31/dslr/images/scissors/frame_0010.jpg 26 58 | office31/dslr/images/scissors/frame_0011.jpg 26 59 | office31/dslr/images/scissors/frame_0012.jpg 26 60 | office31/dslr/images/scissors/frame_0013.jpg 26 61 | office31/dslr/images/scissors/frame_0014.jpg 26 62 | office31/dslr/images/scissors/frame_0015.jpg 26 63 | office31/dslr/images/scissors/frame_0016.jpg 26 64 | office31/dslr/images/scissors/frame_0017.jpg 26 65 | office31/dslr/images/scissors/frame_0018.jpg 26 66 | office31/dslr/images/laptop_computer/frame_0001.jpg 12 67 | office31/dslr/images/laptop_computer/frame_0002.jpg 12 68 | office31/dslr/images/laptop_computer/frame_0003.jpg 12 69 | office31/dslr/images/laptop_computer/frame_0004.jpg 12 70 | office31/dslr/images/laptop_computer/frame_0005.jpg 12 71 | office31/dslr/images/laptop_computer/frame_0006.jpg 12 72 | office31/dslr/images/laptop_computer/frame_0007.jpg 
12 73 | office31/dslr/images/laptop_computer/frame_0008.jpg 12 74 | office31/dslr/images/laptop_computer/frame_0009.jpg 12 75 | office31/dslr/images/laptop_computer/frame_0010.jpg 12 76 | office31/dslr/images/laptop_computer/frame_0011.jpg 12 77 | office31/dslr/images/laptop_computer/frame_0012.jpg 12 78 | office31/dslr/images/laptop_computer/frame_0013.jpg 12 79 | office31/dslr/images/laptop_computer/frame_0014.jpg 12 80 | office31/dslr/images/laptop_computer/frame_0015.jpg 12 81 | office31/dslr/images/laptop_computer/frame_0016.jpg 12 82 | office31/dslr/images/laptop_computer/frame_0017.jpg 12 83 | office31/dslr/images/laptop_computer/frame_0018.jpg 12 84 | office31/dslr/images/laptop_computer/frame_0019.jpg 12 85 | office31/dslr/images/laptop_computer/frame_0020.jpg 12 86 | office31/dslr/images/laptop_computer/frame_0021.jpg 12 87 | office31/dslr/images/laptop_computer/frame_0022.jpg 12 88 | office31/dslr/images/laptop_computer/frame_0023.jpg 12 89 | office31/dslr/images/laptop_computer/frame_0024.jpg 12 90 | office31/dslr/images/mouse/frame_0001.jpg 16 91 | office31/dslr/images/mouse/frame_0002.jpg 16 92 | office31/dslr/images/mouse/frame_0003.jpg 16 93 | office31/dslr/images/mouse/frame_0004.jpg 16 94 | office31/dslr/images/mouse/frame_0005.jpg 16 95 | office31/dslr/images/mouse/frame_0006.jpg 16 96 | office31/dslr/images/mouse/frame_0007.jpg 16 97 | office31/dslr/images/mouse/frame_0008.jpg 16 98 | office31/dslr/images/mouse/frame_0009.jpg 16 99 | office31/dslr/images/mouse/frame_0010.jpg 16 100 | office31/dslr/images/mouse/frame_0011.jpg 16 101 | office31/dslr/images/mouse/frame_0012.jpg 16 102 | office31/dslr/images/monitor/frame_0001.jpg 15 103 | office31/dslr/images/monitor/frame_0002.jpg 15 104 | office31/dslr/images/monitor/frame_0003.jpg 15 105 | office31/dslr/images/monitor/frame_0004.jpg 15 106 | office31/dslr/images/monitor/frame_0005.jpg 15 107 | office31/dslr/images/monitor/frame_0006.jpg 15 108 | office31/dslr/images/monitor/frame_0007.jpg 15 109 
| office31/dslr/images/monitor/frame_0008.jpg 15 110 | office31/dslr/images/monitor/frame_0009.jpg 15 111 | office31/dslr/images/monitor/frame_0010.jpg 15 112 | office31/dslr/images/monitor/frame_0011.jpg 15 113 | office31/dslr/images/monitor/frame_0012.jpg 15 114 | office31/dslr/images/monitor/frame_0013.jpg 15 115 | office31/dslr/images/monitor/frame_0014.jpg 15 116 | office31/dslr/images/monitor/frame_0015.jpg 15 117 | office31/dslr/images/monitor/frame_0016.jpg 15 118 | office31/dslr/images/monitor/frame_0017.jpg 15 119 | office31/dslr/images/monitor/frame_0018.jpg 15 120 | office31/dslr/images/monitor/frame_0019.jpg 15 121 | office31/dslr/images/monitor/frame_0020.jpg 15 122 | office31/dslr/images/monitor/frame_0021.jpg 15 123 | office31/dslr/images/monitor/frame_0022.jpg 15 124 | office31/dslr/images/mug/frame_0001.jpg 17 125 | office31/dslr/images/mug/frame_0002.jpg 17 126 | office31/dslr/images/mug/frame_0003.jpg 17 127 | office31/dslr/images/mug/frame_0004.jpg 17 128 | office31/dslr/images/mug/frame_0005.jpg 17 129 | office31/dslr/images/mug/frame_0006.jpg 17 130 | office31/dslr/images/mug/frame_0007.jpg 17 131 | office31/dslr/images/mug/frame_0008.jpg 17 132 | office31/dslr/images/tape_dispenser/frame_0001.jpg 29 133 | office31/dslr/images/tape_dispenser/frame_0002.jpg 29 134 | office31/dslr/images/tape_dispenser/frame_0003.jpg 29 135 | office31/dslr/images/tape_dispenser/frame_0004.jpg 29 136 | office31/dslr/images/tape_dispenser/frame_0005.jpg 29 137 | office31/dslr/images/tape_dispenser/frame_0006.jpg 29 138 | office31/dslr/images/tape_dispenser/frame_0007.jpg 29 139 | office31/dslr/images/tape_dispenser/frame_0008.jpg 29 140 | office31/dslr/images/tape_dispenser/frame_0009.jpg 29 141 | office31/dslr/images/tape_dispenser/frame_0010.jpg 29 142 | office31/dslr/images/tape_dispenser/frame_0011.jpg 29 143 | office31/dslr/images/tape_dispenser/frame_0012.jpg 29 144 | office31/dslr/images/tape_dispenser/frame_0013.jpg 29 145 | 
office31/dslr/images/tape_dispenser/frame_0014.jpg 29 146 | office31/dslr/images/tape_dispenser/frame_0015.jpg 29 147 | office31/dslr/images/tape_dispenser/frame_0016.jpg 29 148 | office31/dslr/images/tape_dispenser/frame_0017.jpg 29 149 | office31/dslr/images/tape_dispenser/frame_0018.jpg 29 150 | office31/dslr/images/tape_dispenser/frame_0019.jpg 29 151 | office31/dslr/images/tape_dispenser/frame_0020.jpg 29 152 | office31/dslr/images/tape_dispenser/frame_0021.jpg 29 153 | office31/dslr/images/tape_dispenser/frame_0022.jpg 29 154 | office31/dslr/images/pen/frame_0001.jpg 19 155 | office31/dslr/images/pen/frame_0002.jpg 19 156 | office31/dslr/images/pen/frame_0003.jpg 19 157 | office31/dslr/images/pen/frame_0004.jpg 19 158 | office31/dslr/images/pen/frame_0005.jpg 19 159 | office31/dslr/images/pen/frame_0006.jpg 19 160 | office31/dslr/images/pen/frame_0007.jpg 19 161 | office31/dslr/images/pen/frame_0008.jpg 19 162 | office31/dslr/images/pen/frame_0009.jpg 19 163 | office31/dslr/images/pen/frame_0010.jpg 19 164 | office31/dslr/images/bike/frame_0001.jpg 1 165 | office31/dslr/images/bike/frame_0002.jpg 1 166 | office31/dslr/images/bike/frame_0003.jpg 1 167 | office31/dslr/images/bike/frame_0004.jpg 1 168 | office31/dslr/images/bike/frame_0005.jpg 1 169 | office31/dslr/images/bike/frame_0006.jpg 1 170 | office31/dslr/images/bike/frame_0007.jpg 1 171 | office31/dslr/images/bike/frame_0008.jpg 1 172 | office31/dslr/images/bike/frame_0009.jpg 1 173 | office31/dslr/images/bike/frame_0010.jpg 1 174 | office31/dslr/images/bike/frame_0011.jpg 1 175 | office31/dslr/images/bike/frame_0012.jpg 1 176 | office31/dslr/images/bike/frame_0013.jpg 1 177 | office31/dslr/images/bike/frame_0014.jpg 1 178 | office31/dslr/images/bike/frame_0015.jpg 1 179 | office31/dslr/images/bike/frame_0016.jpg 1 180 | office31/dslr/images/bike/frame_0017.jpg 1 181 | office31/dslr/images/bike/frame_0018.jpg 1 182 | office31/dslr/images/bike/frame_0019.jpg 1 183 | 
office31/dslr/images/bike/frame_0020.jpg 1 184 | office31/dslr/images/bike/frame_0021.jpg 1 185 | office31/dslr/images/punchers/frame_0001.jpg 23 186 | office31/dslr/images/punchers/frame_0002.jpg 23 187 | office31/dslr/images/punchers/frame_0003.jpg 23 188 | office31/dslr/images/punchers/frame_0004.jpg 23 189 | office31/dslr/images/punchers/frame_0005.jpg 23 190 | office31/dslr/images/punchers/frame_0006.jpg 23 191 | office31/dslr/images/punchers/frame_0007.jpg 23 192 | office31/dslr/images/punchers/frame_0008.jpg 23 193 | office31/dslr/images/punchers/frame_0009.jpg 23 194 | office31/dslr/images/punchers/frame_0010.jpg 23 195 | office31/dslr/images/punchers/frame_0011.jpg 23 196 | office31/dslr/images/punchers/frame_0012.jpg 23 197 | office31/dslr/images/punchers/frame_0013.jpg 23 198 | office31/dslr/images/punchers/frame_0014.jpg 23 199 | office31/dslr/images/punchers/frame_0015.jpg 23 200 | office31/dslr/images/punchers/frame_0016.jpg 23 201 | office31/dslr/images/punchers/frame_0017.jpg 23 202 | office31/dslr/images/punchers/frame_0018.jpg 23 203 | office31/dslr/images/back_pack/frame_0001.jpg 0 204 | office31/dslr/images/back_pack/frame_0002.jpg 0 205 | office31/dslr/images/back_pack/frame_0003.jpg 0 206 | office31/dslr/images/back_pack/frame_0004.jpg 0 207 | office31/dslr/images/back_pack/frame_0005.jpg 0 208 | office31/dslr/images/back_pack/frame_0006.jpg 0 209 | office31/dslr/images/back_pack/frame_0007.jpg 0 210 | office31/dslr/images/back_pack/frame_0008.jpg 0 211 | office31/dslr/images/back_pack/frame_0009.jpg 0 212 | office31/dslr/images/back_pack/frame_0010.jpg 0 213 | office31/dslr/images/back_pack/frame_0011.jpg 0 214 | office31/dslr/images/back_pack/frame_0012.jpg 0 215 | office31/dslr/images/desktop_computer/frame_0001.jpg 8 216 | office31/dslr/images/desktop_computer/frame_0002.jpg 8 217 | office31/dslr/images/desktop_computer/frame_0003.jpg 8 218 | office31/dslr/images/desktop_computer/frame_0004.jpg 8 219 | 
office31/dslr/images/desktop_computer/frame_0005.jpg 8 220 | office31/dslr/images/desktop_computer/frame_0006.jpg 8 221 | office31/dslr/images/desktop_computer/frame_0007.jpg 8 222 | office31/dslr/images/desktop_computer/frame_0008.jpg 8 223 | office31/dslr/images/desktop_computer/frame_0009.jpg 8 224 | office31/dslr/images/desktop_computer/frame_0010.jpg 8 225 | office31/dslr/images/desktop_computer/frame_0011.jpg 8 226 | office31/dslr/images/desktop_computer/frame_0012.jpg 8 227 | office31/dslr/images/desktop_computer/frame_0013.jpg 8 228 | office31/dslr/images/desktop_computer/frame_0014.jpg 8 229 | office31/dslr/images/desktop_computer/frame_0015.jpg 8 230 | office31/dslr/images/speaker/frame_0001.jpg 27 231 | office31/dslr/images/speaker/frame_0002.jpg 27 232 | office31/dslr/images/speaker/frame_0003.jpg 27 233 | office31/dslr/images/speaker/frame_0004.jpg 27 234 | office31/dslr/images/speaker/frame_0005.jpg 27 235 | office31/dslr/images/speaker/frame_0006.jpg 27 236 | office31/dslr/images/speaker/frame_0007.jpg 27 237 | office31/dslr/images/speaker/frame_0008.jpg 27 238 | office31/dslr/images/speaker/frame_0009.jpg 27 239 | office31/dslr/images/speaker/frame_0010.jpg 27 240 | office31/dslr/images/speaker/frame_0011.jpg 27 241 | office31/dslr/images/speaker/frame_0012.jpg 27 242 | office31/dslr/images/speaker/frame_0013.jpg 27 243 | office31/dslr/images/speaker/frame_0014.jpg 27 244 | office31/dslr/images/speaker/frame_0015.jpg 27 245 | office31/dslr/images/speaker/frame_0016.jpg 27 246 | office31/dslr/images/speaker/frame_0017.jpg 27 247 | office31/dslr/images/speaker/frame_0018.jpg 27 248 | office31/dslr/images/speaker/frame_0019.jpg 27 249 | office31/dslr/images/speaker/frame_0020.jpg 27 250 | office31/dslr/images/speaker/frame_0021.jpg 27 251 | office31/dslr/images/speaker/frame_0022.jpg 27 252 | office31/dslr/images/speaker/frame_0023.jpg 27 253 | office31/dslr/images/speaker/frame_0024.jpg 27 254 | office31/dslr/images/speaker/frame_0025.jpg 27 255 | 
office31/dslr/images/speaker/frame_0026.jpg 27 256 | office31/dslr/images/mobile_phone/frame_0001.jpg 14 257 | office31/dslr/images/mobile_phone/frame_0002.jpg 14 258 | office31/dslr/images/mobile_phone/frame_0003.jpg 14 259 | office31/dslr/images/mobile_phone/frame_0004.jpg 14 260 | office31/dslr/images/mobile_phone/frame_0005.jpg 14 261 | office31/dslr/images/mobile_phone/frame_0006.jpg 14 262 | office31/dslr/images/mobile_phone/frame_0007.jpg 14 263 | office31/dslr/images/mobile_phone/frame_0008.jpg 14 264 | office31/dslr/images/mobile_phone/frame_0009.jpg 14 265 | office31/dslr/images/mobile_phone/frame_0010.jpg 14 266 | office31/dslr/images/mobile_phone/frame_0011.jpg 14 267 | office31/dslr/images/mobile_phone/frame_0012.jpg 14 268 | office31/dslr/images/mobile_phone/frame_0013.jpg 14 269 | office31/dslr/images/mobile_phone/frame_0014.jpg 14 270 | office31/dslr/images/mobile_phone/frame_0015.jpg 14 271 | office31/dslr/images/mobile_phone/frame_0016.jpg 14 272 | office31/dslr/images/mobile_phone/frame_0017.jpg 14 273 | office31/dslr/images/mobile_phone/frame_0018.jpg 14 274 | office31/dslr/images/mobile_phone/frame_0019.jpg 14 275 | office31/dslr/images/mobile_phone/frame_0020.jpg 14 276 | office31/dslr/images/mobile_phone/frame_0021.jpg 14 277 | office31/dslr/images/mobile_phone/frame_0022.jpg 14 278 | office31/dslr/images/mobile_phone/frame_0023.jpg 14 279 | office31/dslr/images/mobile_phone/frame_0024.jpg 14 280 | office31/dslr/images/mobile_phone/frame_0025.jpg 14 281 | office31/dslr/images/mobile_phone/frame_0026.jpg 14 282 | office31/dslr/images/mobile_phone/frame_0027.jpg 14 283 | office31/dslr/images/mobile_phone/frame_0028.jpg 14 284 | office31/dslr/images/mobile_phone/frame_0029.jpg 14 285 | office31/dslr/images/mobile_phone/frame_0030.jpg 14 286 | office31/dslr/images/mobile_phone/frame_0031.jpg 14 287 | office31/dslr/images/paper_notebook/frame_0001.jpg 18 288 | office31/dslr/images/paper_notebook/frame_0002.jpg 18 289 | 
office31/dslr/images/paper_notebook/frame_0003.jpg 18 290 | office31/dslr/images/paper_notebook/frame_0004.jpg 18 291 | office31/dslr/images/paper_notebook/frame_0005.jpg 18 292 | office31/dslr/images/paper_notebook/frame_0006.jpg 18 293 | office31/dslr/images/paper_notebook/frame_0007.jpg 18 294 | office31/dslr/images/paper_notebook/frame_0008.jpg 18 295 | office31/dslr/images/paper_notebook/frame_0009.jpg 18 296 | office31/dslr/images/paper_notebook/frame_0010.jpg 18 297 | office31/dslr/images/ruler/frame_0001.jpg 25 298 | office31/dslr/images/ruler/frame_0002.jpg 25 299 | office31/dslr/images/ruler/frame_0003.jpg 25 300 | office31/dslr/images/ruler/frame_0004.jpg 25 301 | office31/dslr/images/ruler/frame_0005.jpg 25 302 | office31/dslr/images/ruler/frame_0006.jpg 25 303 | office31/dslr/images/ruler/frame_0007.jpg 25 304 | office31/dslr/images/letter_tray/frame_0001.jpg 13 305 | office31/dslr/images/letter_tray/frame_0002.jpg 13 306 | office31/dslr/images/letter_tray/frame_0003.jpg 13 307 | office31/dslr/images/letter_tray/frame_0004.jpg 13 308 | office31/dslr/images/letter_tray/frame_0005.jpg 13 309 | office31/dslr/images/letter_tray/frame_0006.jpg 13 310 | office31/dslr/images/letter_tray/frame_0007.jpg 13 311 | office31/dslr/images/letter_tray/frame_0008.jpg 13 312 | office31/dslr/images/letter_tray/frame_0009.jpg 13 313 | office31/dslr/images/letter_tray/frame_0010.jpg 13 314 | office31/dslr/images/letter_tray/frame_0011.jpg 13 315 | office31/dslr/images/letter_tray/frame_0012.jpg 13 316 | office31/dslr/images/letter_tray/frame_0013.jpg 13 317 | office31/dslr/images/letter_tray/frame_0014.jpg 13 318 | office31/dslr/images/letter_tray/frame_0015.jpg 13 319 | office31/dslr/images/letter_tray/frame_0016.jpg 13 320 | office31/dslr/images/file_cabinet/frame_0001.jpg 9 321 | office31/dslr/images/file_cabinet/frame_0002.jpg 9 322 | office31/dslr/images/file_cabinet/frame_0003.jpg 9 323 | office31/dslr/images/file_cabinet/frame_0004.jpg 9 324 | 
office31/dslr/images/file_cabinet/frame_0005.jpg 9 325 | office31/dslr/images/file_cabinet/frame_0006.jpg 9 326 | office31/dslr/images/file_cabinet/frame_0007.jpg 9 327 | office31/dslr/images/file_cabinet/frame_0008.jpg 9 328 | office31/dslr/images/file_cabinet/frame_0009.jpg 9 329 | office31/dslr/images/file_cabinet/frame_0010.jpg 9 330 | office31/dslr/images/file_cabinet/frame_0011.jpg 9 331 | office31/dslr/images/file_cabinet/frame_0012.jpg 9 332 | office31/dslr/images/file_cabinet/frame_0013.jpg 9 333 | office31/dslr/images/file_cabinet/frame_0014.jpg 9 334 | office31/dslr/images/file_cabinet/frame_0015.jpg 9 335 | office31/dslr/images/phone/frame_0001.jpg 20 336 | office31/dslr/images/phone/frame_0002.jpg 20 337 | office31/dslr/images/phone/frame_0003.jpg 20 338 | office31/dslr/images/phone/frame_0004.jpg 20 339 | office31/dslr/images/phone/frame_0005.jpg 20 340 | office31/dslr/images/phone/frame_0006.jpg 20 341 | office31/dslr/images/phone/frame_0007.jpg 20 342 | office31/dslr/images/phone/frame_0008.jpg 20 343 | office31/dslr/images/phone/frame_0009.jpg 20 344 | office31/dslr/images/phone/frame_0010.jpg 20 345 | office31/dslr/images/phone/frame_0011.jpg 20 346 | office31/dslr/images/phone/frame_0012.jpg 20 347 | office31/dslr/images/phone/frame_0013.jpg 20 348 | office31/dslr/images/bookcase/frame_0001.jpg 3 349 | office31/dslr/images/bookcase/frame_0002.jpg 3 350 | office31/dslr/images/bookcase/frame_0003.jpg 3 351 | office31/dslr/images/bookcase/frame_0004.jpg 3 352 | office31/dslr/images/bookcase/frame_0005.jpg 3 353 | office31/dslr/images/bookcase/frame_0006.jpg 3 354 | office31/dslr/images/bookcase/frame_0007.jpg 3 355 | office31/dslr/images/bookcase/frame_0008.jpg 3 356 | office31/dslr/images/bookcase/frame_0009.jpg 3 357 | office31/dslr/images/bookcase/frame_0010.jpg 3 358 | office31/dslr/images/bookcase/frame_0011.jpg 3 359 | office31/dslr/images/bookcase/frame_0012.jpg 3 360 | office31/dslr/images/projector/frame_0001.jpg 22 361 | 
office31/dslr/images/projector/frame_0002.jpg 22 362 | office31/dslr/images/projector/frame_0003.jpg 22 363 | office31/dslr/images/projector/frame_0004.jpg 22 364 | office31/dslr/images/projector/frame_0005.jpg 22 365 | office31/dslr/images/projector/frame_0006.jpg 22 366 | office31/dslr/images/projector/frame_0007.jpg 22 367 | office31/dslr/images/projector/frame_0008.jpg 22 368 | office31/dslr/images/projector/frame_0009.jpg 22 369 | office31/dslr/images/projector/frame_0010.jpg 22 370 | office31/dslr/images/projector/frame_0011.jpg 22 371 | office31/dslr/images/projector/frame_0012.jpg 22 372 | office31/dslr/images/projector/frame_0013.jpg 22 373 | office31/dslr/images/projector/frame_0014.jpg 22 374 | office31/dslr/images/projector/frame_0015.jpg 22 375 | office31/dslr/images/projector/frame_0016.jpg 22 376 | office31/dslr/images/projector/frame_0017.jpg 22 377 | office31/dslr/images/projector/frame_0018.jpg 22 378 | office31/dslr/images/projector/frame_0019.jpg 22 379 | office31/dslr/images/projector/frame_0020.jpg 22 380 | office31/dslr/images/projector/frame_0021.jpg 22 381 | office31/dslr/images/projector/frame_0022.jpg 22 382 | office31/dslr/images/projector/frame_0023.jpg 22 383 | office31/dslr/images/stapler/frame_0001.jpg 28 384 | office31/dslr/images/stapler/frame_0002.jpg 28 385 | office31/dslr/images/stapler/frame_0003.jpg 28 386 | office31/dslr/images/stapler/frame_0004.jpg 28 387 | office31/dslr/images/stapler/frame_0005.jpg 28 388 | office31/dslr/images/stapler/frame_0006.jpg 28 389 | office31/dslr/images/stapler/frame_0007.jpg 28 390 | office31/dslr/images/stapler/frame_0008.jpg 28 391 | office31/dslr/images/stapler/frame_0009.jpg 28 392 | office31/dslr/images/stapler/frame_0010.jpg 28 393 | office31/dslr/images/stapler/frame_0011.jpg 28 394 | office31/dslr/images/stapler/frame_0012.jpg 28 395 | office31/dslr/images/stapler/frame_0013.jpg 28 396 | office31/dslr/images/stapler/frame_0014.jpg 28 397 | office31/dslr/images/stapler/frame_0015.jpg 28 
398 | office31/dslr/images/stapler/frame_0016.jpg 28 399 | office31/dslr/images/stapler/frame_0017.jpg 28 400 | office31/dslr/images/stapler/frame_0018.jpg 28 401 | office31/dslr/images/stapler/frame_0019.jpg 28 402 | office31/dslr/images/stapler/frame_0020.jpg 28 403 | office31/dslr/images/stapler/frame_0021.jpg 28 404 | office31/dslr/images/trash_can/frame_0001.jpg 30 405 | office31/dslr/images/trash_can/frame_0002.jpg 30 406 | office31/dslr/images/trash_can/frame_0003.jpg 30 407 | office31/dslr/images/trash_can/frame_0004.jpg 30 408 | office31/dslr/images/trash_can/frame_0005.jpg 30 409 | office31/dslr/images/trash_can/frame_0006.jpg 30 410 | office31/dslr/images/trash_can/frame_0007.jpg 30 411 | office31/dslr/images/trash_can/frame_0008.jpg 30 412 | office31/dslr/images/trash_can/frame_0009.jpg 30 413 | office31/dslr/images/trash_can/frame_0010.jpg 30 414 | office31/dslr/images/trash_can/frame_0011.jpg 30 415 | office31/dslr/images/trash_can/frame_0012.jpg 30 416 | office31/dslr/images/trash_can/frame_0013.jpg 30 417 | office31/dslr/images/trash_can/frame_0014.jpg 30 418 | office31/dslr/images/trash_can/frame_0015.jpg 30 419 | office31/dslr/images/bike_helmet/frame_0001.jpg 2 420 | office31/dslr/images/bike_helmet/frame_0002.jpg 2 421 | office31/dslr/images/bike_helmet/frame_0003.jpg 2 422 | office31/dslr/images/bike_helmet/frame_0004.jpg 2 423 | office31/dslr/images/bike_helmet/frame_0005.jpg 2 424 | office31/dslr/images/bike_helmet/frame_0006.jpg 2 425 | office31/dslr/images/bike_helmet/frame_0007.jpg 2 426 | office31/dslr/images/bike_helmet/frame_0008.jpg 2 427 | office31/dslr/images/bike_helmet/frame_0009.jpg 2 428 | office31/dslr/images/bike_helmet/frame_0010.jpg 2 429 | office31/dslr/images/bike_helmet/frame_0011.jpg 2 430 | office31/dslr/images/bike_helmet/frame_0012.jpg 2 431 | office31/dslr/images/bike_helmet/frame_0013.jpg 2 432 | office31/dslr/images/bike_helmet/frame_0014.jpg 2 433 | office31/dslr/images/bike_helmet/frame_0015.jpg 2 434 | 
office31/dslr/images/bike_helmet/frame_0016.jpg 2 435 | office31/dslr/images/bike_helmet/frame_0017.jpg 2 436 | office31/dslr/images/bike_helmet/frame_0018.jpg 2 437 | office31/dslr/images/bike_helmet/frame_0019.jpg 2 438 | office31/dslr/images/bike_helmet/frame_0020.jpg 2 439 | office31/dslr/images/bike_helmet/frame_0021.jpg 2 440 | office31/dslr/images/bike_helmet/frame_0022.jpg 2 441 | office31/dslr/images/bike_helmet/frame_0023.jpg 2 442 | office31/dslr/images/bike_helmet/frame_0024.jpg 2 443 | office31/dslr/images/headphones/frame_0001.jpg 10 444 | office31/dslr/images/headphones/frame_0002.jpg 10 445 | office31/dslr/images/headphones/frame_0003.jpg 10 446 | office31/dslr/images/headphones/frame_0004.jpg 10 447 | office31/dslr/images/headphones/frame_0005.jpg 10 448 | office31/dslr/images/headphones/frame_0006.jpg 10 449 | office31/dslr/images/headphones/frame_0007.jpg 10 450 | office31/dslr/images/headphones/frame_0008.jpg 10 451 | office31/dslr/images/headphones/frame_0009.jpg 10 452 | office31/dslr/images/headphones/frame_0010.jpg 10 453 | office31/dslr/images/headphones/frame_0011.jpg 10 454 | office31/dslr/images/headphones/frame_0012.jpg 10 455 | office31/dslr/images/headphones/frame_0013.jpg 10 456 | office31/dslr/images/desk_lamp/frame_0001.jpg 7 457 | office31/dslr/images/desk_lamp/frame_0002.jpg 7 458 | office31/dslr/images/desk_lamp/frame_0003.jpg 7 459 | office31/dslr/images/desk_lamp/frame_0004.jpg 7 460 | office31/dslr/images/desk_lamp/frame_0005.jpg 7 461 | office31/dslr/images/desk_lamp/frame_0006.jpg 7 462 | office31/dslr/images/desk_lamp/frame_0007.jpg 7 463 | office31/dslr/images/desk_lamp/frame_0008.jpg 7 464 | office31/dslr/images/desk_lamp/frame_0009.jpg 7 465 | office31/dslr/images/desk_lamp/frame_0010.jpg 7 466 | office31/dslr/images/desk_lamp/frame_0011.jpg 7 467 | office31/dslr/images/desk_lamp/frame_0012.jpg 7 468 | office31/dslr/images/desk_lamp/frame_0013.jpg 7 469 | office31/dslr/images/desk_lamp/frame_0014.jpg 7 470 | 
office31/dslr/images/desk_chair/frame_0001.jpg 6 471 | office31/dslr/images/desk_chair/frame_0002.jpg 6 472 | office31/dslr/images/desk_chair/frame_0003.jpg 6 473 | office31/dslr/images/desk_chair/frame_0004.jpg 6 474 | office31/dslr/images/desk_chair/frame_0005.jpg 6 475 | office31/dslr/images/desk_chair/frame_0006.jpg 6 476 | office31/dslr/images/desk_chair/frame_0007.jpg 6 477 | office31/dslr/images/desk_chair/frame_0008.jpg 6 478 | office31/dslr/images/desk_chair/frame_0009.jpg 6 479 | office31/dslr/images/desk_chair/frame_0010.jpg 6 480 | office31/dslr/images/desk_chair/frame_0011.jpg 6 481 | office31/dslr/images/desk_chair/frame_0012.jpg 6 482 | office31/dslr/images/desk_chair/frame_0013.jpg 6 483 | office31/dslr/images/bottle/frame_0001.jpg 4 484 | office31/dslr/images/bottle/frame_0002.jpg 4 485 | office31/dslr/images/bottle/frame_0003.jpg 4 486 | office31/dslr/images/bottle/frame_0004.jpg 4 487 | office31/dslr/images/bottle/frame_0005.jpg 4 488 | office31/dslr/images/bottle/frame_0006.jpg 4 489 | office31/dslr/images/bottle/frame_0007.jpg 4 490 | office31/dslr/images/bottle/frame_0008.jpg 4 491 | office31/dslr/images/bottle/frame_0009.jpg 4 492 | office31/dslr/images/bottle/frame_0010.jpg 4 493 | office31/dslr/images/bottle/frame_0011.jpg 4 494 | office31/dslr/images/bottle/frame_0012.jpg 4 495 | office31/dslr/images/bottle/frame_0013.jpg 4 496 | office31/dslr/images/bottle/frame_0014.jpg 4 497 | office31/dslr/images/bottle/frame_0015.jpg 4 498 | office31/dslr/images/bottle/frame_0016.jpg 4 499 | -------------------------------------------------------------------------------- /data/image_list/office31/webcam.txt: -------------------------------------------------------------------------------- 1 | office31//webcam/images/calculator/frame_0001.jpg 5 2 | office31//webcam/images/calculator/frame_0002.jpg 5 3 | office31//webcam/images/calculator/frame_0003.jpg 5 4 | office31//webcam/images/calculator/frame_0004.jpg 5 5 | 
office31//webcam/images/calculator/frame_0005.jpg 5 6 | office31//webcam/images/calculator/frame_0006.jpg 5 7 | office31//webcam/images/calculator/frame_0007.jpg 5 8 | office31//webcam/images/calculator/frame_0008.jpg 5 9 | office31//webcam/images/calculator/frame_0009.jpg 5 10 | office31//webcam/images/calculator/frame_0010.jpg 5 11 | office31//webcam/images/calculator/frame_0011.jpg 5 12 | office31//webcam/images/calculator/frame_0012.jpg 5 13 | office31//webcam/images/calculator/frame_0013.jpg 5 14 | office31//webcam/images/calculator/frame_0014.jpg 5 15 | office31//webcam/images/calculator/frame_0015.jpg 5 16 | office31//webcam/images/calculator/frame_0016.jpg 5 17 | office31//webcam/images/calculator/frame_0017.jpg 5 18 | office31//webcam/images/calculator/frame_0018.jpg 5 19 | office31//webcam/images/calculator/frame_0019.jpg 5 20 | office31//webcam/images/calculator/frame_0020.jpg 5 21 | office31//webcam/images/calculator/frame_0021.jpg 5 22 | office31//webcam/images/calculator/frame_0022.jpg 5 23 | office31//webcam/images/calculator/frame_0023.jpg 5 24 | office31//webcam/images/calculator/frame_0024.jpg 5 25 | office31//webcam/images/calculator/frame_0025.jpg 5 26 | office31//webcam/images/calculator/frame_0026.jpg 5 27 | office31//webcam/images/calculator/frame_0027.jpg 5 28 | office31//webcam/images/calculator/frame_0028.jpg 5 29 | office31//webcam/images/calculator/frame_0029.jpg 5 30 | office31//webcam/images/calculator/frame_0030.jpg 5 31 | office31//webcam/images/calculator/frame_0031.jpg 5 32 | office31//webcam/images/ring_binder/frame_0001.jpg 24 33 | office31//webcam/images/ring_binder/frame_0002.jpg 24 34 | office31//webcam/images/ring_binder/frame_0003.jpg 24 35 | office31//webcam/images/ring_binder/frame_0004.jpg 24 36 | office31//webcam/images/ring_binder/frame_0005.jpg 24 37 | office31//webcam/images/ring_binder/frame_0006.jpg 24 38 | office31//webcam/images/ring_binder/frame_0007.jpg 24 39 | office31//webcam/images/ring_binder/frame_0008.jpg 
24 40 | office31//webcam/images/ring_binder/frame_0009.jpg 24 41 | office31//webcam/images/ring_binder/frame_0010.jpg 24 42 | office31//webcam/images/ring_binder/frame_0011.jpg 24 43 | office31//webcam/images/ring_binder/frame_0012.jpg 24 44 | office31//webcam/images/ring_binder/frame_0013.jpg 24 45 | office31//webcam/images/ring_binder/frame_0014.jpg 24 46 | office31//webcam/images/ring_binder/frame_0015.jpg 24 47 | office31//webcam/images/ring_binder/frame_0016.jpg 24 48 | office31//webcam/images/ring_binder/frame_0017.jpg 24 49 | office31//webcam/images/ring_binder/frame_0018.jpg 24 50 | office31//webcam/images/ring_binder/frame_0019.jpg 24 51 | office31//webcam/images/ring_binder/frame_0020.jpg 24 52 | office31//webcam/images/ring_binder/frame_0021.jpg 24 53 | office31//webcam/images/ring_binder/frame_0022.jpg 24 54 | office31//webcam/images/ring_binder/frame_0023.jpg 24 55 | office31//webcam/images/ring_binder/frame_0024.jpg 24 56 | office31//webcam/images/ring_binder/frame_0025.jpg 24 57 | office31//webcam/images/ring_binder/frame_0026.jpg 24 58 | office31//webcam/images/ring_binder/frame_0027.jpg 24 59 | office31//webcam/images/ring_binder/frame_0028.jpg 24 60 | office31//webcam/images/ring_binder/frame_0029.jpg 24 61 | office31//webcam/images/ring_binder/frame_0030.jpg 24 62 | office31//webcam/images/ring_binder/frame_0031.jpg 24 63 | office31//webcam/images/ring_binder/frame_0032.jpg 24 64 | office31//webcam/images/ring_binder/frame_0033.jpg 24 65 | office31//webcam/images/ring_binder/frame_0034.jpg 24 66 | office31//webcam/images/ring_binder/frame_0035.jpg 24 67 | office31//webcam/images/ring_binder/frame_0036.jpg 24 68 | office31//webcam/images/ring_binder/frame_0037.jpg 24 69 | office31//webcam/images/ring_binder/frame_0038.jpg 24 70 | office31//webcam/images/ring_binder/frame_0039.jpg 24 71 | office31//webcam/images/ring_binder/frame_0040.jpg 24 72 | office31//webcam/images/printer/frame_0001.jpg 21 73 | office31//webcam/images/printer/frame_0002.jpg 
21 74 | office31//webcam/images/printer/frame_0003.jpg 21 75 | office31//webcam/images/printer/frame_0004.jpg 21 76 | office31//webcam/images/printer/frame_0005.jpg 21 77 | office31//webcam/images/printer/frame_0006.jpg 21 78 | office31//webcam/images/printer/frame_0007.jpg 21 79 | office31//webcam/images/printer/frame_0008.jpg 21 80 | office31//webcam/images/printer/frame_0009.jpg 21 81 | office31//webcam/images/printer/frame_0010.jpg 21 82 | office31//webcam/images/printer/frame_0011.jpg 21 83 | office31//webcam/images/printer/frame_0012.jpg 21 84 | office31//webcam/images/printer/frame_0013.jpg 21 85 | office31//webcam/images/printer/frame_0014.jpg 21 86 | office31//webcam/images/printer/frame_0015.jpg 21 87 | office31//webcam/images/printer/frame_0016.jpg 21 88 | office31//webcam/images/printer/frame_0017.jpg 21 89 | office31//webcam/images/printer/frame_0018.jpg 21 90 | office31//webcam/images/printer/frame_0019.jpg 21 91 | office31//webcam/images/printer/frame_0020.jpg 21 92 | office31//webcam/images/keyboard/frame_0001.jpg 11 93 | office31//webcam/images/keyboard/frame_0002.jpg 11 94 | office31//webcam/images/keyboard/frame_0003.jpg 11 95 | office31//webcam/images/keyboard/frame_0004.jpg 11 96 | office31//webcam/images/keyboard/frame_0005.jpg 11 97 | office31//webcam/images/keyboard/frame_0006.jpg 11 98 | office31//webcam/images/keyboard/frame_0007.jpg 11 99 | office31//webcam/images/keyboard/frame_0008.jpg 11 100 | office31//webcam/images/keyboard/frame_0009.jpg 11 101 | office31//webcam/images/keyboard/frame_0010.jpg 11 102 | office31//webcam/images/keyboard/frame_0011.jpg 11 103 | office31//webcam/images/keyboard/frame_0012.jpg 11 104 | office31//webcam/images/keyboard/frame_0013.jpg 11 105 | office31//webcam/images/keyboard/frame_0014.jpg 11 106 | office31//webcam/images/keyboard/frame_0015.jpg 11 107 | office31//webcam/images/keyboard/frame_0016.jpg 11 108 | office31//webcam/images/keyboard/frame_0017.jpg 11 109 | 
office31//webcam/images/keyboard/frame_0018.jpg 11 110 | office31//webcam/images/keyboard/frame_0019.jpg 11 111 | office31//webcam/images/keyboard/frame_0020.jpg 11 112 | office31//webcam/images/keyboard/frame_0021.jpg 11 113 | office31//webcam/images/keyboard/frame_0022.jpg 11 114 | office31//webcam/images/keyboard/frame_0023.jpg 11 115 | office31//webcam/images/keyboard/frame_0024.jpg 11 116 | office31//webcam/images/keyboard/frame_0025.jpg 11 117 | office31//webcam/images/keyboard/frame_0026.jpg 11 118 | office31//webcam/images/keyboard/frame_0027.jpg 11 119 | office31//webcam/images/scissors/frame_0001.jpg 26 120 | office31//webcam/images/scissors/frame_0002.jpg 26 121 | office31//webcam/images/scissors/frame_0003.jpg 26 122 | office31//webcam/images/scissors/frame_0004.jpg 26 123 | office31//webcam/images/scissors/frame_0005.jpg 26 124 | office31//webcam/images/scissors/frame_0006.jpg 26 125 | office31//webcam/images/scissors/frame_0007.jpg 26 126 | office31//webcam/images/scissors/frame_0008.jpg 26 127 | office31//webcam/images/scissors/frame_0009.jpg 26 128 | office31//webcam/images/scissors/frame_0010.jpg 26 129 | office31//webcam/images/scissors/frame_0011.jpg 26 130 | office31//webcam/images/scissors/frame_0012.jpg 26 131 | office31//webcam/images/scissors/frame_0013.jpg 26 132 | office31//webcam/images/scissors/frame_0014.jpg 26 133 | office31//webcam/images/scissors/frame_0015.jpg 26 134 | office31//webcam/images/scissors/frame_0016.jpg 26 135 | office31//webcam/images/scissors/frame_0017.jpg 26 136 | office31//webcam/images/scissors/frame_0018.jpg 26 137 | office31//webcam/images/scissors/frame_0019.jpg 26 138 | office31//webcam/images/scissors/frame_0020.jpg 26 139 | office31//webcam/images/scissors/frame_0021.jpg 26 140 | office31//webcam/images/scissors/frame_0022.jpg 26 141 | office31//webcam/images/scissors/frame_0023.jpg 26 142 | office31//webcam/images/scissors/frame_0024.jpg 26 143 | office31//webcam/images/scissors/frame_0025.jpg 26 144 | 
office31//webcam/images/laptop_computer/frame_0001.jpg 12 145 | office31//webcam/images/laptop_computer/frame_0002.jpg 12 146 | office31//webcam/images/laptop_computer/frame_0003.jpg 12 147 | office31//webcam/images/laptop_computer/frame_0004.jpg 12 148 | office31//webcam/images/laptop_computer/frame_0005.jpg 12 149 | office31//webcam/images/laptop_computer/frame_0006.jpg 12 150 | office31//webcam/images/laptop_computer/frame_0007.jpg 12 151 | office31//webcam/images/laptop_computer/frame_0008.jpg 12 152 | office31//webcam/images/laptop_computer/frame_0009.jpg 12 153 | office31//webcam/images/laptop_computer/frame_0010.jpg 12 154 | office31//webcam/images/laptop_computer/frame_0011.jpg 12 155 | office31//webcam/images/laptop_computer/frame_0012.jpg 12 156 | office31//webcam/images/laptop_computer/frame_0013.jpg 12 157 | office31//webcam/images/laptop_computer/frame_0014.jpg 12 158 | office31//webcam/images/laptop_computer/frame_0015.jpg 12 159 | office31//webcam/images/laptop_computer/frame_0016.jpg 12 160 | office31//webcam/images/laptop_computer/frame_0017.jpg 12 161 | office31//webcam/images/laptop_computer/frame_0018.jpg 12 162 | office31//webcam/images/laptop_computer/frame_0019.jpg 12 163 | office31//webcam/images/laptop_computer/frame_0020.jpg 12 164 | office31//webcam/images/laptop_computer/frame_0021.jpg 12 165 | office31//webcam/images/laptop_computer/frame_0022.jpg 12 166 | office31//webcam/images/laptop_computer/frame_0023.jpg 12 167 | office31//webcam/images/laptop_computer/frame_0024.jpg 12 168 | office31//webcam/images/laptop_computer/frame_0025.jpg 12 169 | office31//webcam/images/laptop_computer/frame_0026.jpg 12 170 | office31//webcam/images/laptop_computer/frame_0027.jpg 12 171 | office31//webcam/images/laptop_computer/frame_0028.jpg 12 172 | office31//webcam/images/laptop_computer/frame_0029.jpg 12 173 | office31//webcam/images/laptop_computer/frame_0030.jpg 12 174 | office31//webcam/images/mouse/frame_0001.jpg 16 175 | 
office31//webcam/images/mouse/frame_0002.jpg 16 176 | office31//webcam/images/mouse/frame_0003.jpg 16 177 | office31//webcam/images/mouse/frame_0004.jpg 16 178 | office31//webcam/images/mouse/frame_0005.jpg 16 179 | office31//webcam/images/mouse/frame_0006.jpg 16 180 | office31//webcam/images/mouse/frame_0007.jpg 16 181 | office31//webcam/images/mouse/frame_0008.jpg 16 182 | office31//webcam/images/mouse/frame_0009.jpg 16 183 | office31//webcam/images/mouse/frame_0010.jpg 16 184 | office31//webcam/images/mouse/frame_0011.jpg 16 185 | office31//webcam/images/mouse/frame_0012.jpg 16 186 | office31//webcam/images/mouse/frame_0013.jpg 16 187 | office31//webcam/images/mouse/frame_0014.jpg 16 188 | office31//webcam/images/mouse/frame_0015.jpg 16 189 | office31//webcam/images/mouse/frame_0016.jpg 16 190 | office31//webcam/images/mouse/frame_0017.jpg 16 191 | office31//webcam/images/mouse/frame_0018.jpg 16 192 | office31//webcam/images/mouse/frame_0019.jpg 16 193 | office31//webcam/images/mouse/frame_0020.jpg 16 194 | office31//webcam/images/mouse/frame_0021.jpg 16 195 | office31//webcam/images/mouse/frame_0022.jpg 16 196 | office31//webcam/images/mouse/frame_0023.jpg 16 197 | office31//webcam/images/mouse/frame_0024.jpg 16 198 | office31//webcam/images/mouse/frame_0025.jpg 16 199 | office31//webcam/images/mouse/frame_0026.jpg 16 200 | office31//webcam/images/mouse/frame_0027.jpg 16 201 | office31//webcam/images/mouse/frame_0028.jpg 16 202 | office31//webcam/images/mouse/frame_0029.jpg 16 203 | office31//webcam/images/mouse/frame_0030.jpg 16 204 | office31//webcam/images/monitor/frame_0001.jpg 15 205 | office31//webcam/images/monitor/frame_0002.jpg 15 206 | office31//webcam/images/monitor/frame_0003.jpg 15 207 | office31//webcam/images/monitor/frame_0004.jpg 15 208 | office31//webcam/images/monitor/frame_0005.jpg 15 209 | office31//webcam/images/monitor/frame_0006.jpg 15 210 | office31//webcam/images/monitor/frame_0007.jpg 15 211 | 
office31//webcam/images/monitor/frame_0008.jpg 15 212 | office31//webcam/images/monitor/frame_0009.jpg 15 213 | office31//webcam/images/monitor/frame_0010.jpg 15 214 | office31//webcam/images/monitor/frame_0011.jpg 15 215 | office31//webcam/images/monitor/frame_0012.jpg 15 216 | office31//webcam/images/monitor/frame_0013.jpg 15 217 | office31//webcam/images/monitor/frame_0014.jpg 15 218 | office31//webcam/images/monitor/frame_0015.jpg 15 219 | office31//webcam/images/monitor/frame_0016.jpg 15 220 | office31//webcam/images/monitor/frame_0017.jpg 15 221 | office31//webcam/images/monitor/frame_0018.jpg 15 222 | office31//webcam/images/monitor/frame_0019.jpg 15 223 | office31//webcam/images/monitor/frame_0020.jpg 15 224 | office31//webcam/images/monitor/frame_0021.jpg 15 225 | office31//webcam/images/monitor/frame_0022.jpg 15 226 | office31//webcam/images/monitor/frame_0023.jpg 15 227 | office31//webcam/images/monitor/frame_0024.jpg 15 228 | office31//webcam/images/monitor/frame_0025.jpg 15 229 | office31//webcam/images/monitor/frame_0026.jpg 15 230 | office31//webcam/images/monitor/frame_0027.jpg 15 231 | office31//webcam/images/monitor/frame_0028.jpg 15 232 | office31//webcam/images/monitor/frame_0029.jpg 15 233 | office31//webcam/images/monitor/frame_0030.jpg 15 234 | office31//webcam/images/monitor/frame_0031.jpg 15 235 | office31//webcam/images/monitor/frame_0032.jpg 15 236 | office31//webcam/images/monitor/frame_0033.jpg 15 237 | office31//webcam/images/monitor/frame_0034.jpg 15 238 | office31//webcam/images/monitor/frame_0035.jpg 15 239 | office31//webcam/images/monitor/frame_0036.jpg 15 240 | office31//webcam/images/monitor/frame_0037.jpg 15 241 | office31//webcam/images/monitor/frame_0038.jpg 15 242 | office31//webcam/images/monitor/frame_0039.jpg 15 243 | office31//webcam/images/monitor/frame_0040.jpg 15 244 | office31//webcam/images/monitor/frame_0041.jpg 15 245 | office31//webcam/images/monitor/frame_0042.jpg 15 246 | 
office31//webcam/images/monitor/frame_0043.jpg 15 247 | office31//webcam/images/mug/frame_0001.jpg 17 248 | office31//webcam/images/mug/frame_0002.jpg 17 249 | office31//webcam/images/mug/frame_0003.jpg 17 250 | office31//webcam/images/mug/frame_0004.jpg 17 251 | office31//webcam/images/mug/frame_0005.jpg 17 252 | office31//webcam/images/mug/frame_0006.jpg 17 253 | office31//webcam/images/mug/frame_0007.jpg 17 254 | office31//webcam/images/mug/frame_0008.jpg 17 255 | office31//webcam/images/mug/frame_0009.jpg 17 256 | office31//webcam/images/mug/frame_0010.jpg 17 257 | office31//webcam/images/mug/frame_0011.jpg 17 258 | office31//webcam/images/mug/frame_0012.jpg 17 259 | office31//webcam/images/mug/frame_0013.jpg 17 260 | office31//webcam/images/mug/frame_0014.jpg 17 261 | office31//webcam/images/mug/frame_0015.jpg 17 262 | office31//webcam/images/mug/frame_0016.jpg 17 263 | office31//webcam/images/mug/frame_0017.jpg 17 264 | office31//webcam/images/mug/frame_0018.jpg 17 265 | office31//webcam/images/mug/frame_0019.jpg 17 266 | office31//webcam/images/mug/frame_0020.jpg 17 267 | office31//webcam/images/mug/frame_0021.jpg 17 268 | office31//webcam/images/mug/frame_0022.jpg 17 269 | office31//webcam/images/mug/frame_0023.jpg 17 270 | office31//webcam/images/mug/frame_0024.jpg 17 271 | office31//webcam/images/mug/frame_0025.jpg 17 272 | office31//webcam/images/mug/frame_0026.jpg 17 273 | office31//webcam/images/mug/frame_0027.jpg 17 274 | office31//webcam/images/tape_dispenser/frame_0001.jpg 29 275 | office31//webcam/images/tape_dispenser/frame_0002.jpg 29 276 | office31//webcam/images/tape_dispenser/frame_0003.jpg 29 277 | office31//webcam/images/tape_dispenser/frame_0004.jpg 29 278 | office31//webcam/images/tape_dispenser/frame_0005.jpg 29 279 | office31//webcam/images/tape_dispenser/frame_0006.jpg 29 280 | office31//webcam/images/tape_dispenser/frame_0007.jpg 29 281 | office31//webcam/images/tape_dispenser/frame_0008.jpg 29 282 | 
office31//webcam/images/tape_dispenser/frame_0009.jpg 29 283 | office31//webcam/images/tape_dispenser/frame_0010.jpg 29 284 | office31//webcam/images/tape_dispenser/frame_0011.jpg 29 285 | office31//webcam/images/tape_dispenser/frame_0012.jpg 29 286 | office31//webcam/images/tape_dispenser/frame_0013.jpg 29 287 | office31//webcam/images/tape_dispenser/frame_0014.jpg 29 288 | office31//webcam/images/tape_dispenser/frame_0015.jpg 29 289 | office31//webcam/images/tape_dispenser/frame_0016.jpg 29 290 | office31//webcam/images/tape_dispenser/frame_0017.jpg 29 291 | office31//webcam/images/tape_dispenser/frame_0018.jpg 29 292 | office31//webcam/images/tape_dispenser/frame_0019.jpg 29 293 | office31//webcam/images/tape_dispenser/frame_0020.jpg 29 294 | office31//webcam/images/tape_dispenser/frame_0021.jpg 29 295 | office31//webcam/images/tape_dispenser/frame_0022.jpg 29 296 | office31//webcam/images/tape_dispenser/frame_0023.jpg 29 297 | office31//webcam/images/pen/frame_0001.jpg 19 298 | office31//webcam/images/pen/frame_0002.jpg 19 299 | office31//webcam/images/pen/frame_0003.jpg 19 300 | office31//webcam/images/pen/frame_0004.jpg 19 301 | office31//webcam/images/pen/frame_0005.jpg 19 302 | office31//webcam/images/pen/frame_0006.jpg 19 303 | office31//webcam/images/pen/frame_0007.jpg 19 304 | office31//webcam/images/pen/frame_0008.jpg 19 305 | office31//webcam/images/pen/frame_0009.jpg 19 306 | office31//webcam/images/pen/frame_0010.jpg 19 307 | office31//webcam/images/pen/frame_0011.jpg 19 308 | office31//webcam/images/pen/frame_0012.jpg 19 309 | office31//webcam/images/pen/frame_0013.jpg 19 310 | office31//webcam/images/pen/frame_0014.jpg 19 311 | office31//webcam/images/pen/frame_0015.jpg 19 312 | office31//webcam/images/pen/frame_0016.jpg 19 313 | office31//webcam/images/pen/frame_0017.jpg 19 314 | office31//webcam/images/pen/frame_0018.jpg 19 315 | office31//webcam/images/pen/frame_0019.jpg 19 316 | office31//webcam/images/pen/frame_0020.jpg 19 317 | 
office31//webcam/images/pen/frame_0021.jpg 19 318 | office31//webcam/images/pen/frame_0022.jpg 19 319 | office31//webcam/images/pen/frame_0023.jpg 19 320 | office31//webcam/images/pen/frame_0024.jpg 19 321 | office31//webcam/images/pen/frame_0025.jpg 19 322 | office31//webcam/images/pen/frame_0026.jpg 19 323 | office31//webcam/images/pen/frame_0027.jpg 19 324 | office31//webcam/images/pen/frame_0028.jpg 19 325 | office31//webcam/images/pen/frame_0029.jpg 19 326 | office31//webcam/images/pen/frame_0030.jpg 19 327 | office31//webcam/images/pen/frame_0031.jpg 19 328 | office31//webcam/images/pen/frame_0032.jpg 19 329 | office31//webcam/images/bike/frame_0001.jpg 1 330 | office31//webcam/images/bike/frame_0002.jpg 1 331 | office31//webcam/images/bike/frame_0003.jpg 1 332 | office31//webcam/images/bike/frame_0004.jpg 1 333 | office31//webcam/images/bike/frame_0005.jpg 1 334 | office31//webcam/images/bike/frame_0006.jpg 1 335 | office31//webcam/images/bike/frame_0007.jpg 1 336 | office31//webcam/images/bike/frame_0008.jpg 1 337 | office31//webcam/images/bike/frame_0009.jpg 1 338 | office31//webcam/images/bike/frame_0010.jpg 1 339 | office31//webcam/images/bike/frame_0011.jpg 1 340 | office31//webcam/images/bike/frame_0012.jpg 1 341 | office31//webcam/images/bike/frame_0013.jpg 1 342 | office31//webcam/images/bike/frame_0014.jpg 1 343 | office31//webcam/images/bike/frame_0015.jpg 1 344 | office31//webcam/images/bike/frame_0016.jpg 1 345 | office31//webcam/images/bike/frame_0017.jpg 1 346 | office31//webcam/images/bike/frame_0018.jpg 1 347 | office31//webcam/images/bike/frame_0019.jpg 1 348 | office31//webcam/images/bike/frame_0020.jpg 1 349 | office31//webcam/images/bike/frame_0021.jpg 1 350 | office31//webcam/images/punchers/frame_0001.jpg 23 351 | office31//webcam/images/punchers/frame_0002.jpg 23 352 | office31//webcam/images/punchers/frame_0003.jpg 23 353 | office31//webcam/images/punchers/frame_0004.jpg 23 354 | office31//webcam/images/punchers/frame_0005.jpg 23 355 
| office31//webcam/images/punchers/frame_0006.jpg 23 356 | office31//webcam/images/punchers/frame_0007.jpg 23 357 | office31//webcam/images/punchers/frame_0008.jpg 23 358 | office31//webcam/images/punchers/frame_0009.jpg 23 359 | office31//webcam/images/punchers/frame_0010.jpg 23 360 | office31//webcam/images/punchers/frame_0011.jpg 23 361 | office31//webcam/images/punchers/frame_0012.jpg 23 362 | office31//webcam/images/punchers/frame_0013.jpg 23 363 | office31//webcam/images/punchers/frame_0014.jpg 23 364 | office31//webcam/images/punchers/frame_0015.jpg 23 365 | office31//webcam/images/punchers/frame_0016.jpg 23 366 | office31//webcam/images/punchers/frame_0017.jpg 23 367 | office31//webcam/images/punchers/frame_0018.jpg 23 368 | office31//webcam/images/punchers/frame_0019.jpg 23 369 | office31//webcam/images/punchers/frame_0020.jpg 23 370 | office31//webcam/images/punchers/frame_0021.jpg 23 371 | office31//webcam/images/punchers/frame_0022.jpg 23 372 | office31//webcam/images/punchers/frame_0023.jpg 23 373 | office31//webcam/images/punchers/frame_0024.jpg 23 374 | office31//webcam/images/punchers/frame_0025.jpg 23 375 | office31//webcam/images/punchers/frame_0026.jpg 23 376 | office31//webcam/images/punchers/frame_0027.jpg 23 377 | office31//webcam/images/back_pack/frame_0001.jpg 0 378 | office31//webcam/images/back_pack/frame_0002.jpg 0 379 | office31//webcam/images/back_pack/frame_0003.jpg 0 380 | office31//webcam/images/back_pack/frame_0004.jpg 0 381 | office31//webcam/images/back_pack/frame_0005.jpg 0 382 | office31//webcam/images/back_pack/frame_0006.jpg 0 383 | office31//webcam/images/back_pack/frame_0007.jpg 0 384 | office31//webcam/images/back_pack/frame_0008.jpg 0 385 | office31//webcam/images/back_pack/frame_0009.jpg 0 386 | office31//webcam/images/back_pack/frame_0010.jpg 0 387 | office31//webcam/images/back_pack/frame_0011.jpg 0 388 | office31//webcam/images/back_pack/frame_0012.jpg 0 389 | office31//webcam/images/back_pack/frame_0013.jpg 0 390 | 
office31//webcam/images/back_pack/frame_0014.jpg 0 391 | office31//webcam/images/back_pack/frame_0015.jpg 0 392 | office31//webcam/images/back_pack/frame_0016.jpg 0 393 | office31//webcam/images/back_pack/frame_0017.jpg 0 394 | office31//webcam/images/back_pack/frame_0018.jpg 0 395 | office31//webcam/images/back_pack/frame_0019.jpg 0 396 | office31//webcam/images/back_pack/frame_0020.jpg 0 397 | office31//webcam/images/back_pack/frame_0021.jpg 0 398 | office31//webcam/images/back_pack/frame_0022.jpg 0 399 | office31//webcam/images/back_pack/frame_0023.jpg 0 400 | office31//webcam/images/back_pack/frame_0024.jpg 0 401 | office31//webcam/images/back_pack/frame_0025.jpg 0 402 | office31//webcam/images/back_pack/frame_0026.jpg 0 403 | office31//webcam/images/back_pack/frame_0027.jpg 0 404 | office31//webcam/images/back_pack/frame_0028.jpg 0 405 | office31//webcam/images/back_pack/frame_0029.jpg 0 406 | office31//webcam/images/desktop_computer/frame_0001.jpg 8 407 | office31//webcam/images/desktop_computer/frame_0002.jpg 8 408 | office31//webcam/images/desktop_computer/frame_0003.jpg 8 409 | office31//webcam/images/desktop_computer/frame_0004.jpg 8 410 | office31//webcam/images/desktop_computer/frame_0005.jpg 8 411 | office31//webcam/images/desktop_computer/frame_0006.jpg 8 412 | office31//webcam/images/desktop_computer/frame_0007.jpg 8 413 | office31//webcam/images/desktop_computer/frame_0008.jpg 8 414 | office31//webcam/images/desktop_computer/frame_0009.jpg 8 415 | office31//webcam/images/desktop_computer/frame_0010.jpg 8 416 | office31//webcam/images/desktop_computer/frame_0011.jpg 8 417 | office31//webcam/images/desktop_computer/frame_0012.jpg 8 418 | office31//webcam/images/desktop_computer/frame_0013.jpg 8 419 | office31//webcam/images/desktop_computer/frame_0014.jpg 8 420 | office31//webcam/images/desktop_computer/frame_0015.jpg 8 421 | office31//webcam/images/desktop_computer/frame_0016.jpg 8 422 | office31//webcam/images/desktop_computer/frame_0017.jpg 8 423 | 
office31//webcam/images/desktop_computer/frame_0018.jpg 8 424 | office31//webcam/images/desktop_computer/frame_0019.jpg 8 425 | office31//webcam/images/desktop_computer/frame_0020.jpg 8 426 | office31//webcam/images/desktop_computer/frame_0021.jpg 8 427 | office31//webcam/images/speaker/frame_0001.jpg 27 428 | office31//webcam/images/speaker/frame_0002.jpg 27 429 | office31//webcam/images/speaker/frame_0003.jpg 27 430 | office31//webcam/images/speaker/frame_0004.jpg 27 431 | office31//webcam/images/speaker/frame_0005.jpg 27 432 | office31//webcam/images/speaker/frame_0006.jpg 27 433 | office31//webcam/images/speaker/frame_0007.jpg 27 434 | office31//webcam/images/speaker/frame_0008.jpg 27 435 | office31//webcam/images/speaker/frame_0009.jpg 27 436 | office31//webcam/images/speaker/frame_0010.jpg 27 437 | office31//webcam/images/speaker/frame_0011.jpg 27 438 | office31//webcam/images/speaker/frame_0012.jpg 27 439 | office31//webcam/images/speaker/frame_0013.jpg 27 440 | office31//webcam/images/speaker/frame_0014.jpg 27 441 | office31//webcam/images/speaker/frame_0015.jpg 27 442 | office31//webcam/images/speaker/frame_0016.jpg 27 443 | office31//webcam/images/speaker/frame_0017.jpg 27 444 | office31//webcam/images/speaker/frame_0018.jpg 27 445 | office31//webcam/images/speaker/frame_0019.jpg 27 446 | office31//webcam/images/speaker/frame_0020.jpg 27 447 | office31//webcam/images/speaker/frame_0021.jpg 27 448 | office31//webcam/images/speaker/frame_0022.jpg 27 449 | office31//webcam/images/speaker/frame_0023.jpg 27 450 | office31//webcam/images/speaker/frame_0024.jpg 27 451 | office31//webcam/images/speaker/frame_0025.jpg 27 452 | office31//webcam/images/speaker/frame_0026.jpg 27 453 | office31//webcam/images/speaker/frame_0027.jpg 27 454 | office31//webcam/images/speaker/frame_0028.jpg 27 455 | office31//webcam/images/speaker/frame_0029.jpg 27 456 | office31//webcam/images/speaker/frame_0030.jpg 27 457 | office31//webcam/images/mobile_phone/frame_0001.jpg 14 458 | 
office31//webcam/images/mobile_phone/frame_0002.jpg 14 459 | office31//webcam/images/mobile_phone/frame_0003.jpg 14 460 | office31//webcam/images/mobile_phone/frame_0004.jpg 14 461 | office31//webcam/images/mobile_phone/frame_0005.jpg 14 462 | office31//webcam/images/mobile_phone/frame_0006.jpg 14 463 | office31//webcam/images/mobile_phone/frame_0007.jpg 14 464 | office31//webcam/images/mobile_phone/frame_0008.jpg 14 465 | office31//webcam/images/mobile_phone/frame_0009.jpg 14 466 | office31//webcam/images/mobile_phone/frame_0010.jpg 14 467 | office31//webcam/images/mobile_phone/frame_0011.jpg 14 468 | office31//webcam/images/mobile_phone/frame_0012.jpg 14 469 | office31//webcam/images/mobile_phone/frame_0013.jpg 14 470 | office31//webcam/images/mobile_phone/frame_0014.jpg 14 471 | office31//webcam/images/mobile_phone/frame_0015.jpg 14 472 | office31//webcam/images/mobile_phone/frame_0016.jpg 14 473 | office31//webcam/images/mobile_phone/frame_0017.jpg 14 474 | office31//webcam/images/mobile_phone/frame_0018.jpg 14 475 | office31//webcam/images/mobile_phone/frame_0019.jpg 14 476 | office31//webcam/images/mobile_phone/frame_0020.jpg 14 477 | office31//webcam/images/mobile_phone/frame_0021.jpg 14 478 | office31//webcam/images/mobile_phone/frame_0022.jpg 14 479 | office31//webcam/images/mobile_phone/frame_0023.jpg 14 480 | office31//webcam/images/mobile_phone/frame_0024.jpg 14 481 | office31//webcam/images/mobile_phone/frame_0025.jpg 14 482 | office31//webcam/images/mobile_phone/frame_0026.jpg 14 483 | office31//webcam/images/mobile_phone/frame_0027.jpg 14 484 | office31//webcam/images/mobile_phone/frame_0028.jpg 14 485 | office31//webcam/images/mobile_phone/frame_0029.jpg 14 486 | office31//webcam/images/mobile_phone/frame_0030.jpg 14 487 | office31//webcam/images/paper_notebook/frame_0001.jpg 18 488 | office31//webcam/images/paper_notebook/frame_0002.jpg 18 489 | office31//webcam/images/paper_notebook/frame_0003.jpg 18 490 | 
office31//webcam/images/paper_notebook/frame_0004.jpg 18 491 | office31//webcam/images/paper_notebook/frame_0005.jpg 18 492 | office31//webcam/images/paper_notebook/frame_0006.jpg 18 493 | office31//webcam/images/paper_notebook/frame_0007.jpg 18 494 | office31//webcam/images/paper_notebook/frame_0008.jpg 18 495 | office31//webcam/images/paper_notebook/frame_0009.jpg 18 496 | office31//webcam/images/paper_notebook/frame_0010.jpg 18 497 | office31//webcam/images/paper_notebook/frame_0011.jpg 18 498 | office31//webcam/images/paper_notebook/frame_0012.jpg 18 499 | office31//webcam/images/paper_notebook/frame_0013.jpg 18 500 | office31//webcam/images/paper_notebook/frame_0014.jpg 18 501 | office31//webcam/images/paper_notebook/frame_0015.jpg 18 502 | office31//webcam/images/paper_notebook/frame_0016.jpg 18 503 | office31//webcam/images/paper_notebook/frame_0017.jpg 18 504 | office31//webcam/images/paper_notebook/frame_0018.jpg 18 505 | office31//webcam/images/paper_notebook/frame_0019.jpg 18 506 | office31//webcam/images/paper_notebook/frame_0020.jpg 18 507 | office31//webcam/images/paper_notebook/frame_0021.jpg 18 508 | office31//webcam/images/paper_notebook/frame_0022.jpg 18 509 | office31//webcam/images/paper_notebook/frame_0023.jpg 18 510 | office31//webcam/images/paper_notebook/frame_0024.jpg 18 511 | office31//webcam/images/paper_notebook/frame_0025.jpg 18 512 | office31//webcam/images/paper_notebook/frame_0026.jpg 18 513 | office31//webcam/images/paper_notebook/frame_0027.jpg 18 514 | office31//webcam/images/paper_notebook/frame_0028.jpg 18 515 | office31//webcam/images/ruler/frame_0001.jpg 25 516 | office31//webcam/images/ruler/frame_0002.jpg 25 517 | office31//webcam/images/ruler/frame_0003.jpg 25 518 | office31//webcam/images/ruler/frame_0004.jpg 25 519 | office31//webcam/images/ruler/frame_0005.jpg 25 520 | office31//webcam/images/ruler/frame_0006.jpg 25 521 | office31//webcam/images/ruler/frame_0007.jpg 25 522 | office31//webcam/images/ruler/frame_0008.jpg 
25 523 | office31//webcam/images/ruler/frame_0009.jpg 25 524 | office31//webcam/images/ruler/frame_0010.jpg 25 525 | office31//webcam/images/ruler/frame_0011.jpg 25 526 | office31//webcam/images/letter_tray/frame_0001.jpg 13 527 | office31//webcam/images/letter_tray/frame_0002.jpg 13 528 | office31//webcam/images/letter_tray/frame_0003.jpg 13 529 | office31//webcam/images/letter_tray/frame_0004.jpg 13 530 | office31//webcam/images/letter_tray/frame_0005.jpg 13 531 | office31//webcam/images/letter_tray/frame_0006.jpg 13 532 | office31//webcam/images/letter_tray/frame_0007.jpg 13 533 | office31//webcam/images/letter_tray/frame_0008.jpg 13 534 | office31//webcam/images/letter_tray/frame_0009.jpg 13 535 | office31//webcam/images/letter_tray/frame_0010.jpg 13 536 | office31//webcam/images/letter_tray/frame_0011.jpg 13 537 | office31//webcam/images/letter_tray/frame_0012.jpg 13 538 | office31//webcam/images/letter_tray/frame_0013.jpg 13 539 | office31//webcam/images/letter_tray/frame_0014.jpg 13 540 | office31//webcam/images/letter_tray/frame_0015.jpg 13 541 | office31//webcam/images/letter_tray/frame_0016.jpg 13 542 | office31//webcam/images/letter_tray/frame_0017.jpg 13 543 | office31//webcam/images/letter_tray/frame_0018.jpg 13 544 | office31//webcam/images/letter_tray/frame_0019.jpg 13 545 | office31//webcam/images/file_cabinet/frame_0001.jpg 9 546 | office31//webcam/images/file_cabinet/frame_0002.jpg 9 547 | office31//webcam/images/file_cabinet/frame_0003.jpg 9 548 | office31//webcam/images/file_cabinet/frame_0004.jpg 9 549 | office31//webcam/images/file_cabinet/frame_0005.jpg 9 550 | office31//webcam/images/file_cabinet/frame_0006.jpg 9 551 | office31//webcam/images/file_cabinet/frame_0007.jpg 9 552 | office31//webcam/images/file_cabinet/frame_0008.jpg 9 553 | office31//webcam/images/file_cabinet/frame_0009.jpg 9 554 | office31//webcam/images/file_cabinet/frame_0010.jpg 9 555 | office31//webcam/images/file_cabinet/frame_0011.jpg 9 556 | 
office31//webcam/images/file_cabinet/frame_0012.jpg 9 557 | office31//webcam/images/file_cabinet/frame_0013.jpg 9 558 | office31//webcam/images/file_cabinet/frame_0014.jpg 9 559 | office31//webcam/images/file_cabinet/frame_0015.jpg 9 560 | office31//webcam/images/file_cabinet/frame_0016.jpg 9 561 | office31//webcam/images/file_cabinet/frame_0017.jpg 9 562 | office31//webcam/images/file_cabinet/frame_0018.jpg 9 563 | office31//webcam/images/file_cabinet/frame_0019.jpg 9 564 | office31//webcam/images/phone/frame_0001.jpg 20 565 | office31//webcam/images/phone/frame_0002.jpg 20 566 | office31//webcam/images/phone/frame_0003.jpg 20 567 | office31//webcam/images/phone/frame_0004.jpg 20 568 | office31//webcam/images/phone/frame_0005.jpg 20 569 | office31//webcam/images/phone/frame_0006.jpg 20 570 | office31//webcam/images/phone/frame_0007.jpg 20 571 | office31//webcam/images/phone/frame_0008.jpg 20 572 | office31//webcam/images/phone/frame_0009.jpg 20 573 | office31//webcam/images/phone/frame_0010.jpg 20 574 | office31//webcam/images/phone/frame_0011.jpg 20 575 | office31//webcam/images/phone/frame_0012.jpg 20 576 | office31//webcam/images/phone/frame_0013.jpg 20 577 | office31//webcam/images/phone/frame_0014.jpg 20 578 | office31//webcam/images/phone/frame_0015.jpg 20 579 | office31//webcam/images/phone/frame_0016.jpg 20 580 | office31//webcam/images/bookcase/frame_0001.jpg 3 581 | office31//webcam/images/bookcase/frame_0002.jpg 3 582 | office31//webcam/images/bookcase/frame_0003.jpg 3 583 | office31//webcam/images/bookcase/frame_0004.jpg 3 584 | office31//webcam/images/bookcase/frame_0005.jpg 3 585 | office31//webcam/images/bookcase/frame_0006.jpg 3 586 | office31//webcam/images/bookcase/frame_0007.jpg 3 587 | office31//webcam/images/bookcase/frame_0008.jpg 3 588 | office31//webcam/images/bookcase/frame_0009.jpg 3 589 | office31//webcam/images/bookcase/frame_0010.jpg 3 590 | office31//webcam/images/bookcase/frame_0011.jpg 3 591 | 
office31//webcam/images/bookcase/frame_0012.jpg 3 592 | office31//webcam/images/projector/frame_0001.jpg 22 593 | office31//webcam/images/projector/frame_0002.jpg 22 594 | office31//webcam/images/projector/frame_0003.jpg 22 595 | office31//webcam/images/projector/frame_0004.jpg 22 596 | office31//webcam/images/projector/frame_0005.jpg 22 597 | office31//webcam/images/projector/frame_0006.jpg 22 598 | office31//webcam/images/projector/frame_0007.jpg 22 599 | office31//webcam/images/projector/frame_0008.jpg 22 600 | office31//webcam/images/projector/frame_0009.jpg 22 601 | office31//webcam/images/projector/frame_0010.jpg 22 602 | office31//webcam/images/projector/frame_0011.jpg 22 603 | office31//webcam/images/projector/frame_0012.jpg 22 604 | office31//webcam/images/projector/frame_0013.jpg 22 605 | office31//webcam/images/projector/frame_0014.jpg 22 606 | office31//webcam/images/projector/frame_0015.jpg 22 607 | office31//webcam/images/projector/frame_0016.jpg 22 608 | office31//webcam/images/projector/frame_0017.jpg 22 609 | office31//webcam/images/projector/frame_0018.jpg 22 610 | office31//webcam/images/projector/frame_0019.jpg 22 611 | office31//webcam/images/projector/frame_0020.jpg 22 612 | office31//webcam/images/projector/frame_0021.jpg 22 613 | office31//webcam/images/projector/frame_0022.jpg 22 614 | office31//webcam/images/projector/frame_0023.jpg 22 615 | office31//webcam/images/projector/frame_0024.jpg 22 616 | office31//webcam/images/projector/frame_0025.jpg 22 617 | office31//webcam/images/projector/frame_0026.jpg 22 618 | office31//webcam/images/projector/frame_0027.jpg 22 619 | office31//webcam/images/projector/frame_0028.jpg 22 620 | office31//webcam/images/projector/frame_0029.jpg 22 621 | office31//webcam/images/projector/frame_0030.jpg 22 622 | office31//webcam/images/stapler/frame_0001.jpg 28 623 | office31//webcam/images/stapler/frame_0002.jpg 28 624 | office31//webcam/images/stapler/frame_0003.jpg 28 625 | 
office31//webcam/images/stapler/frame_0004.jpg 28 626 | office31//webcam/images/stapler/frame_0005.jpg 28 627 | office31//webcam/images/stapler/frame_0006.jpg 28 628 | office31//webcam/images/stapler/frame_0007.jpg 28 629 | office31//webcam/images/stapler/frame_0008.jpg 28 630 | office31//webcam/images/stapler/frame_0009.jpg 28 631 | office31//webcam/images/stapler/frame_0010.jpg 28 632 | office31//webcam/images/stapler/frame_0011.jpg 28 633 | office31//webcam/images/stapler/frame_0012.jpg 28 634 | office31//webcam/images/stapler/frame_0013.jpg 28 635 | office31//webcam/images/stapler/frame_0014.jpg 28 636 | office31//webcam/images/stapler/frame_0015.jpg 28 637 | office31//webcam/images/stapler/frame_0016.jpg 28 638 | office31//webcam/images/stapler/frame_0017.jpg 28 639 | office31//webcam/images/stapler/frame_0018.jpg 28 640 | office31//webcam/images/stapler/frame_0019.jpg 28 641 | office31//webcam/images/stapler/frame_0020.jpg 28 642 | office31//webcam/images/stapler/frame_0021.jpg 28 643 | office31//webcam/images/stapler/frame_0022.jpg 28 644 | office31//webcam/images/stapler/frame_0023.jpg 28 645 | office31//webcam/images/stapler/frame_0024.jpg 28 646 | office31//webcam/images/trash_can/frame_0001.jpg 30 647 | office31//webcam/images/trash_can/frame_0002.jpg 30 648 | office31//webcam/images/trash_can/frame_0003.jpg 30 649 | office31//webcam/images/trash_can/frame_0004.jpg 30 650 | office31//webcam/images/trash_can/frame_0005.jpg 30 651 | office31//webcam/images/trash_can/frame_0006.jpg 30 652 | office31//webcam/images/trash_can/frame_0007.jpg 30 653 | office31//webcam/images/trash_can/frame_0008.jpg 30 654 | office31//webcam/images/trash_can/frame_0009.jpg 30 655 | office31//webcam/images/trash_can/frame_0010.jpg 30 656 | office31//webcam/images/trash_can/frame_0011.jpg 30 657 | office31//webcam/images/trash_can/frame_0012.jpg 30 658 | office31//webcam/images/trash_can/frame_0013.jpg 30 659 | office31//webcam/images/trash_can/frame_0014.jpg 30 660 | 
office31//webcam/images/trash_can/frame_0015.jpg 30 661 | office31//webcam/images/trash_can/frame_0016.jpg 30 662 | office31//webcam/images/trash_can/frame_0017.jpg 30 663 | office31//webcam/images/trash_can/frame_0018.jpg 30 664 | office31//webcam/images/trash_can/frame_0019.jpg 30 665 | office31//webcam/images/trash_can/frame_0020.jpg 30 666 | office31//webcam/images/trash_can/frame_0021.jpg 30 667 | office31//webcam/images/bike_helmet/frame_0001.jpg 2 668 | office31//webcam/images/bike_helmet/frame_0002.jpg 2 669 | office31//webcam/images/bike_helmet/frame_0003.jpg 2 670 | office31//webcam/images/bike_helmet/frame_0004.jpg 2 671 | office31//webcam/images/bike_helmet/frame_0005.jpg 2 672 | office31//webcam/images/bike_helmet/frame_0006.jpg 2 673 | office31//webcam/images/bike_helmet/frame_0007.jpg 2 674 | office31//webcam/images/bike_helmet/frame_0008.jpg 2 675 | office31//webcam/images/bike_helmet/frame_0009.jpg 2 676 | office31//webcam/images/bike_helmet/frame_0010.jpg 2 677 | office31//webcam/images/bike_helmet/frame_0011.jpg 2 678 | office31//webcam/images/bike_helmet/frame_0012.jpg 2 679 | office31//webcam/images/bike_helmet/frame_0013.jpg 2 680 | office31//webcam/images/bike_helmet/frame_0014.jpg 2 681 | office31//webcam/images/bike_helmet/frame_0015.jpg 2 682 | office31//webcam/images/bike_helmet/frame_0016.jpg 2 683 | office31//webcam/images/bike_helmet/frame_0017.jpg 2 684 | office31//webcam/images/bike_helmet/frame_0018.jpg 2 685 | office31//webcam/images/bike_helmet/frame_0019.jpg 2 686 | office31//webcam/images/bike_helmet/frame_0020.jpg 2 687 | office31//webcam/images/bike_helmet/frame_0021.jpg 2 688 | office31//webcam/images/bike_helmet/frame_0022.jpg 2 689 | office31//webcam/images/bike_helmet/frame_0023.jpg 2 690 | office31//webcam/images/bike_helmet/frame_0024.jpg 2 691 | office31//webcam/images/bike_helmet/frame_0025.jpg 2 692 | office31//webcam/images/bike_helmet/frame_0026.jpg 2 693 | office31//webcam/images/bike_helmet/frame_0027.jpg 2 694 | 
office31//webcam/images/bike_helmet/frame_0028.jpg 2 695 | office31//webcam/images/headphones/frame_0001.jpg 10 696 | office31//webcam/images/headphones/frame_0002.jpg 10 697 | office31//webcam/images/headphones/frame_0003.jpg 10 698 | office31//webcam/images/headphones/frame_0004.jpg 10 699 | office31//webcam/images/headphones/frame_0005.jpg 10 700 | office31//webcam/images/headphones/frame_0006.jpg 10 701 | office31//webcam/images/headphones/frame_0007.jpg 10 702 | office31//webcam/images/headphones/frame_0008.jpg 10 703 | office31//webcam/images/headphones/frame_0009.jpg 10 704 | office31//webcam/images/headphones/frame_0010.jpg 10 705 | office31//webcam/images/headphones/frame_0011.jpg 10 706 | office31//webcam/images/headphones/frame_0012.jpg 10 707 | office31//webcam/images/headphones/frame_0013.jpg 10 708 | office31//webcam/images/headphones/frame_0014.jpg 10 709 | office31//webcam/images/headphones/frame_0015.jpg 10 710 | office31//webcam/images/headphones/frame_0016.jpg 10 711 | office31//webcam/images/headphones/frame_0017.jpg 10 712 | office31//webcam/images/headphones/frame_0018.jpg 10 713 | office31//webcam/images/headphones/frame_0019.jpg 10 714 | office31//webcam/images/headphones/frame_0020.jpg 10 715 | office31//webcam/images/headphones/frame_0021.jpg 10 716 | office31//webcam/images/headphones/frame_0022.jpg 10 717 | office31//webcam/images/headphones/frame_0023.jpg 10 718 | office31//webcam/images/headphones/frame_0024.jpg 10 719 | office31//webcam/images/headphones/frame_0025.jpg 10 720 | office31//webcam/images/headphones/frame_0026.jpg 10 721 | office31//webcam/images/headphones/frame_0027.jpg 10 722 | office31//webcam/images/desk_lamp/frame_0001.jpg 7 723 | office31//webcam/images/desk_lamp/frame_0002.jpg 7 724 | office31//webcam/images/desk_lamp/frame_0003.jpg 7 725 | office31//webcam/images/desk_lamp/frame_0004.jpg 7 726 | office31//webcam/images/desk_lamp/frame_0005.jpg 7 727 | office31//webcam/images/desk_lamp/frame_0006.jpg 7 728 | 
office31//webcam/images/desk_lamp/frame_0007.jpg 7 729 | office31//webcam/images/desk_lamp/frame_0008.jpg 7 730 | office31//webcam/images/desk_lamp/frame_0009.jpg 7 731 | office31//webcam/images/desk_lamp/frame_0010.jpg 7 732 | office31//webcam/images/desk_lamp/frame_0011.jpg 7 733 | office31//webcam/images/desk_lamp/frame_0012.jpg 7 734 | office31//webcam/images/desk_lamp/frame_0013.jpg 7 735 | office31//webcam/images/desk_lamp/frame_0014.jpg 7 736 | office31//webcam/images/desk_lamp/frame_0015.jpg 7 737 | office31//webcam/images/desk_lamp/frame_0016.jpg 7 738 | office31//webcam/images/desk_lamp/frame_0017.jpg 7 739 | office31//webcam/images/desk_lamp/frame_0018.jpg 7 740 | office31//webcam/images/desk_chair/frame_0001.jpg 6 741 | office31//webcam/images/desk_chair/frame_0002.jpg 6 742 | office31//webcam/images/desk_chair/frame_0003.jpg 6 743 | office31//webcam/images/desk_chair/frame_0004.jpg 6 744 | office31//webcam/images/desk_chair/frame_0005.jpg 6 745 | office31//webcam/images/desk_chair/frame_0006.jpg 6 746 | office31//webcam/images/desk_chair/frame_0007.jpg 6 747 | office31//webcam/images/desk_chair/frame_0008.jpg 6 748 | office31//webcam/images/desk_chair/frame_0009.jpg 6 749 | office31//webcam/images/desk_chair/frame_0010.jpg 6 750 | office31//webcam/images/desk_chair/frame_0011.jpg 6 751 | office31//webcam/images/desk_chair/frame_0012.jpg 6 752 | office31//webcam/images/desk_chair/frame_0013.jpg 6 753 | office31//webcam/images/desk_chair/frame_0014.jpg 6 754 | office31//webcam/images/desk_chair/frame_0015.jpg 6 755 | office31//webcam/images/desk_chair/frame_0016.jpg 6 756 | office31//webcam/images/desk_chair/frame_0017.jpg 6 757 | office31//webcam/images/desk_chair/frame_0018.jpg 6 758 | office31//webcam/images/desk_chair/frame_0019.jpg 6 759 | office31//webcam/images/desk_chair/frame_0020.jpg 6 760 | office31//webcam/images/desk_chair/frame_0021.jpg 6 761 | office31//webcam/images/desk_chair/frame_0022.jpg 6 762 | 
office31//webcam/images/desk_chair/frame_0023.jpg 6 763 | office31//webcam/images/desk_chair/frame_0024.jpg 6 764 | office31//webcam/images/desk_chair/frame_0025.jpg 6 765 | office31//webcam/images/desk_chair/frame_0026.jpg 6 766 | office31//webcam/images/desk_chair/frame_0027.jpg 6 767 | office31//webcam/images/desk_chair/frame_0028.jpg 6 768 | office31//webcam/images/desk_chair/frame_0029.jpg 6 769 | office31//webcam/images/desk_chair/frame_0030.jpg 6 770 | office31//webcam/images/desk_chair/frame_0031.jpg 6 771 | office31//webcam/images/desk_chair/frame_0032.jpg 6 772 | office31//webcam/images/desk_chair/frame_0033.jpg 6 773 | office31//webcam/images/desk_chair/frame_0034.jpg 6 774 | office31//webcam/images/desk_chair/frame_0035.jpg 6 775 | office31//webcam/images/desk_chair/frame_0036.jpg 6 776 | office31//webcam/images/desk_chair/frame_0037.jpg 6 777 | office31//webcam/images/desk_chair/frame_0038.jpg 6 778 | office31//webcam/images/desk_chair/frame_0039.jpg 6 779 | office31//webcam/images/desk_chair/frame_0040.jpg 6 780 | office31//webcam/images/bottle/frame_0001.jpg 4 781 | office31//webcam/images/bottle/frame_0002.jpg 4 782 | office31//webcam/images/bottle/frame_0003.jpg 4 783 | office31//webcam/images/bottle/frame_0004.jpg 4 784 | office31//webcam/images/bottle/frame_0005.jpg 4 785 | office31//webcam/images/bottle/frame_0006.jpg 4 786 | office31//webcam/images/bottle/frame_0007.jpg 4 787 | office31//webcam/images/bottle/frame_0008.jpg 4 788 | office31//webcam/images/bottle/frame_0009.jpg 4 789 | office31//webcam/images/bottle/frame_0010.jpg 4 790 | office31//webcam/images/bottle/frame_0011.jpg 4 791 | office31//webcam/images/bottle/frame_0012.jpg 4 792 | office31//webcam/images/bottle/frame_0013.jpg 4 793 | office31//webcam/images/bottle/frame_0014.jpg 4 794 | office31//webcam/images/bottle/frame_0015.jpg 4 795 | office31//webcam/images/bottle/frame_0016.jpg 4 796 | -------------------------------------------------------------------------------- 
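Each entry in the list above is an image path, relative to the data root, followed by an integer class index. For example, `ruler` is class 25 and `bottle` is class 4. The classes are not equally sized: in this excerpt `ruler` has 11 frames while `desk_chair` has 40. The sketch below is not part of the repository. It parses such lines the way the single-label branch of `make_dataset` in `dataset/image_list.py` does, then counts images per class. The helper names `parse_image_list` and `class_counts` are hypothetical. Unlike `make_dataset`, the sketch skips blank lines.

```python
from collections import Counter
from typing import List, Tuple


def parse_image_list(lines: List[str]) -> List[Tuple[str, int]]:
    """Parse 'relative/path.jpg label' lines into (path, label) pairs.

    Hypothetical helper mirroring the single-label branch of make_dataset():
    the first whitespace-separated token is the path, the second the class index.
    """
    samples = []
    for line in lines:
        parts = line.split()
        if not parts:  # skip blank lines, e.g. a trailing newline at end of file
            continue
        samples.append((parts[0], int(parts[1])))
    return samples


def class_counts(samples: List[Tuple[str, int]]) -> Counter:
    """Return the number of images per class index."""
    return Counter(label for _, label in samples)


if __name__ == "__main__":
    demo = [
        "office31//webcam/images/ruler/frame_0001.jpg 25\n",
        "office31//webcam/images/ruler/frame_0002.jpg 25\n",
        "office31//webcam/images/bottle/frame_0001.jpg 4\n",
        "\n",
    ]
    samples = parse_image_list(demo)
    print(samples[0])             # ('office31//webcam/images/ruler/frame_0001.jpg', 25)
    print(class_counts(samples))  # Counter({25: 2, 4: 1})
```

Per-class counts like these show how unbalanced a target domain is. That imbalance matters when labeled target data is augmented "in a class-balanced manner", as the README describes.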
/dataset/ASDADataset.py:
--------------------------------------------------------------------------------
import torch
import os
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler
from .image_list import ImageList


def _read_list(path):
    # Read an image-list file ("relative/path label" per line); the `with` block
    # closes the file handle, unlike a bare open(...).readlines().
    with open(path) as f:
        return f.readlines()


class ASDADataset:
    # Active Semi-supervised DA Dataset class
    def __init__(self, dataset, domain, data_dir, num_classes, batch_size=128, num_workers=4, transforms=None):
        self.dataset = dataset
        self.domain = domain
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.num_classes = num_classes
        self.train_loader, self.valid_loader, self.test_loader = None, None, None

        self.build_dsets(transforms)

    def get_num_classes(self):
        return self.num_classes

    def get_dsets(self):
        return self.train_dataset, self.query_dataset, self.valid_dataset, self.test_dataset

    def build_dsets(self, transforms=None):
        assert transforms is not None

        list_dir = os.path.join(self.data_dir, "image_list", self.dataset)
        if self.dataset == "domainnet":
            # DomainNet provides separate train/test lists for each domain
            train_list = _read_list(os.path.join(list_dir, self.domain + "_train.txt"))
            test_list = _read_list(os.path.join(list_dir, self.domain + "_test.txt"))
            valid_list = train_list.copy()
        else:
            # The other benchmarks use a single list per domain for train, valid and test
            train_list = _read_list(os.path.join(list_dir, self.domain + ".txt"))
            test_list = train_list.copy()
            valid_list = train_list.copy()

        # 'train' and 'query' views cover the same images but use different transforms
        self.train_dataset = ImageList(train_list, root=self.data_dir, transform=transforms['train'])
        self.query_dataset = ImageList(train_list, root=self.data_dir, transform=transforms['query'])
        self.valid_dataset = ImageList(valid_list, root=self.data_dir, transform=transforms['test'])
        self.test_dataset = ImageList(test_list, root=self.data_dir, transform=transforms['test'])

    def get_loaders(self, valid_type='val', valid_ratio=1.0, rebuilt=False):
        # Reuse the cached loaders unless a rebuild is requested
        if (self.train_loader is not None and self.valid_loader is not None
                and self.test_loader is not None and not rebuilt):
            return self.train_loader, self.valid_loader, self.test_loader

        num_train = len(self.train_dataset)
        self.train_size = num_train

        if valid_type == 'split':
            # Hold out a random `valid_ratio` fraction of the training indices for validation
            indices = list(range(num_train))
            split = int(np.floor(valid_ratio * num_train))
            np.random.shuffle(indices)
            train_idx, valid_idx = indices[split:], indices[:split]

        elif valid_type == 'val':
            # Train on all samples; validate on the validation list (or a random fraction of it)
            train_idx = np.arange(len(self.train_dataset))
            if valid_ratio == 1.0:
                valid_idx = np.arange(len(self.valid_dataset))
            else:
                indices = list(range(len(self.valid_dataset)))
                split = int(np.floor(valid_ratio * len(indices)))
                np.random.shuffle(indices)
                valid_idx = indices[:split]
        else:
            raise NotImplementedError

        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, sampler=train_sampler,
                                                        batch_size=self.batch_size, num_workers=self.num_workers)
        self.valid_loader = torch.utils.data.DataLoader(self.valid_dataset, sampler=valid_sampler,
                                                        batch_size=self.batch_size, num_workers=self.num_workers)
        self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size,
                                                       num_workers=self.num_workers)

        self.train_idx = train_idx
        self.valid_idx = valid_idx

        return self.train_loader, self.valid_loader, self.test_loader
--------------------------------------------------------------------------------
/dataset/image_list.py:
--------------------------------------------------------------------------------
from PIL import Image
import os.path as osp
import numpy as np


def make_dataset(image_list, labels):
    # Explicit None check: `if labels:` raises ValueError for a multi-element NumPy array
    if labels is not None:
        len_ = len(image_list)
        images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
    else:
        # Already-parsed (path, label) tuples are returned unchanged
        if isinstance(image_list[0], tuple):
            return image_list

        if len(image_list[0].split()) > 2:
            # Multi-label line: "path l1 l2 ..."
            images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
        else:
            # Single-label line: "path label"
            images = [(val.split()[0], int(val.split()[1])) for val in image_list]
    return images


def pil_loader(root, path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(osp.join(root, path), 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


class ImageList(object):
    """A generic dataset built from an image list, one ``path label`` entry per line, e.g.::

        office31//webcam/images/bike/frame_0001.jpg 1
        office31//webcam/images/back_pack/frame_0001.jpg 0

    Args:
        image_list (list): Lines of an image-list file, or already parsed (path, label) tuples.
        labels (array, optional): 2-D label array aligned with ``image_list``; when given,
            each entry of ``image_list`` is treated as a bare image path.
        root (string): Root directory that the image paths are relative to.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        rand_transform (callable, optional): An extra random transform; when set, each item
            also returns ``rand_num`` (0 by default) randomly augmented views.
    Attributes:
        samples (list): List of (image path, class_index) tuples.
    """

    def __init__(self, image_list, labels=None, root='data', transform=None, target_transform=None, rand_transform=None):
        samples = make_dataset(image_list, labels)
        self.samples = samples
        self.transform = transform
        self.target_transform = target_transform
        self.rand_transform = rand_transform
        self.rand_num = 0
        self.loader = pil_loader
        self.root = root

    def __getitem__(self, index):
        path, target = self.samples[index]
        target = int(target)

        sample_ = self.loader(self.root, path)

        # Fall back to the raw PIL image when no transform is set
        # (otherwise `sample` would be undefined below).
        sample = self.transform(sample_) if self.transform is not None else sample_
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.rand_transform is not None:
            # Extra randomly augmented views of the same image
            rand_sample = [self.rand_transform(sample_) for _ in range(self.rand_num)]
            return sample, target, index, *rand_sample
        else:
            return sample, target, index

    def __len__(self):
        return len(self.samples)

    def add_item(self, addition):
        # self.samples = np.concatenate((self.samples, addition), axis=0)
        self.samples.extend(addition)
        return self.samples

    def remove_item(self, reduced):
        # Note: np.delete returns a NumPy array, so add_item() (which relies on
        # list.extend) can no longer be used on this instance afterwards.
        self.samples = np.delete(self.samples, reduced, axis=0)
        return self.samples
--------------------------------------------------------------------------------
/dataset/randaugment.py:
--------------------------------------------------------------------------------
# code in this file is adapted from rpmcruz/autoaugment
# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
import random

import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
import numpy as np
import torch
from PIL import Image


def ShearX(img, v):  # [-0.3, 0.3]
    assert -0.3 <= v <= 0.3
    if random.random() > 0.5:
        v = -v
15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 16 | 17 | 18 | def ShearY(img, v): # [-0.3, 0.3] 19 | assert -0.3 <= v <= 0.3 20 | if random.random() > 0.5: 21 | v = -v 22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 23 | 24 | 25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 26 | assert -0.45 <= v <= 0.45 27 | if random.random() > 0.5: 28 | v = -v 29 | v = v * img.size[0] 30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 31 | 32 | 33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 34 | assert 0 <= v 35 | if random.random() > 0.5: 36 | v = -v 37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 38 | 39 | 40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 41 | assert -0.45 <= v <= 0.45 42 | if random.random() > 0.5: 43 | v = -v 44 | v = v * img.size[1] 45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 46 | 47 | 48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert 0 <= v 50 | if random.random() > 0.5: 51 | v = -v 52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 53 | 54 | 55 | def Rotate(img, v): # [-30, 30] 56 | assert -30 <= v <= 30 57 | if random.random() > 0.5: 58 | v = -v 59 | return img.rotate(v) 60 | 61 | 62 | def AutoContrast(img, _): 63 | return PIL.ImageOps.autocontrast(img) 64 | 65 | 66 | def Invert(img, _): 67 | return PIL.ImageOps.invert(img) 68 | 69 | 70 | def Equalize(img, _): 71 | return PIL.ImageOps.equalize(img) 72 | 73 | 74 | def Flip(img, _): # not from the paper 75 | return PIL.ImageOps.mirror(img) 76 | 77 | 78 | def Solarize(img, v): # [0, 256] 79 | assert 0 <= v <= 256 80 | return PIL.ImageOps.solarize(img, v) 81 | 82 | 83 | def SolarizeAdd(img, addition=0, threshold=128): 84 | img_np = np.array(img).astype(np.int) 85 | img_np = img_np + addition 86 | img_np = np.clip(img_np, 0, 255) 87 | img_np = 
img_np.astype(np.uint8) 88 | img = Image.fromarray(img_np) 89 | return PIL.ImageOps.solarize(img, threshold) 90 | 91 | 92 | def Posterize(img, v): # [4, 8] 93 | v = int(v) 94 | v = max(1, v) 95 | return PIL.ImageOps.posterize(img, v) 96 | 97 | 98 | def Contrast(img, v): # [0.1,1.9] 99 | assert 0.1 <= v <= 1.9 100 | return PIL.ImageEnhance.Contrast(img).enhance(v) 101 | 102 | 103 | def Color(img, v): # [0.1,1.9] 104 | assert 0.1 <= v <= 1.9 105 | return PIL.ImageEnhance.Color(img).enhance(v) 106 | 107 | 108 | def Brightness(img, v): # [0.1,1.9] 109 | assert 0.1 <= v <= 1.9 110 | return PIL.ImageEnhance.Brightness(img).enhance(v) 111 | 112 | 113 | def Sharpness(img, v): # [0.1,1.9] 114 | assert 0.1 <= v <= 1.9 115 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 116 | 117 | 118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 119 | assert 0.0 <= v <= 0.2 120 | if v <= 0.: 121 | return img 122 | 123 | v = v * img.size[0] 124 | return CutoutAbs(img, v) 125 | 126 | 127 | def CutoutAbs(img, v): # v: side length of the square patch in pixels 128 | # assert 0 <= v <= 20 129 | if v < 0: 130 | return img 131 | w, h = img.size 132 | x0 = np.random.uniform(0, w) 133 | y0 = np.random.uniform(0, h) 134 | 135 | x0 = int(max(0, x0 - v / 2.)) 136 | y0 = int(max(0, y0 - v / 2.)) 137 | x1 = min(w, x0 + v) 138 | y1 = min(h, y0 + v) 139 | 140 | xy = (x0, y0, x1, y1) 141 | color = (125, 123, 114) 142 | # color = (0, 0, 0) 143 | img = img.copy() 144 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 145 | return img 146 | 147 | 148 | def SamplePairing(imgs): # [0, 0.4] 149 | def f(img1, v): 150 | i = np.random.choice(len(imgs)) 151 | img2 = PIL.Image.fromarray(imgs[i]) 152 | return PIL.Image.blend(img1, img2, v) 153 | 154 | return f 155 | 156 | 157 | def Identity(img, v): 158 | return img 159 | 160 | 161 | def augment_list(): # 16 operations and their ranges 162 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 163 | # l = [ 164 | # (Identity, 0., 1.0), 165 | #
(ShearX, 0., 0.3), # 0 166 | # (ShearY, 0., 0.3), # 1 167 | # (TranslateX, 0., 0.33), # 2 168 | # (TranslateY, 0., 0.33), # 3 169 | # (Rotate, 0, 30), # 4 170 | # (AutoContrast, 0, 1), # 5 171 | # (Invert, 0, 1), # 6 172 | # (Equalize, 0, 1), # 7 173 | # (Solarize, 0, 110), # 8 174 | # (Posterize, 4, 8), # 9 175 | # # (Contrast, 0.1, 1.9), # 10 176 | # (Color, 0.1, 1.9), # 11 177 | # (Brightness, 0.1, 1.9), # 12 178 | # (Sharpness, 0.1, 1.9), # 13 179 | # # (Cutout, 0, 0.2), # 14 180 | # # (SamplePairing(imgs), 0, 0.4), # 15 181 | # ] 182 | 183 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 184 | l = [ 185 | (AutoContrast, 0, 1), 186 | (Equalize, 0, 1), 187 | (Invert, 0, 1), 188 | (Rotate, 0, 30), 189 | (Posterize, 0, 4), 190 | (Solarize, 0, 256), 191 | (SolarizeAdd, 0, 110), 192 | (Color, 0.1, 1.9), 193 | (Contrast, 0.1, 1.9), 194 | (Brightness, 0.1, 1.9), 195 | (Sharpness, 0.1, 1.9), 196 | (ShearX, 0., 0.3), 197 | (ShearY, 0., 0.3), 198 | (CutoutAbs, 0, 40), 199 | (TranslateXabs, 0., 100), 200 | (TranslateYabs, 0., 100), 201 | ] 202 | 203 | return l 204 | 205 | 206 | # class Lighting(object): 207 | # """Lighting noise(AlexNet - style PCA - based noise)""" 208 | # 209 | # def __init__(self, alphastd, eigval, eigvec): 210 | # self.alphastd = alphastd 211 | # self.eigval = torch.Tensor(eigval) 212 | # self.eigvec = torch.Tensor(eigvec) 213 | # 214 | # def __call__(self, img): 215 | # if self.alphastd == 0: 216 | # return img 217 | # 218 | # alpha = img.new().resize_(3).normal_(0, self.alphastd) 219 | # rgb = self.eigvec.type_as(img).clone() \ 220 | # .mul(alpha.view(1, 3).expand(3, 3)) \ 221 | # .mul(self.eigval.view(1, 3).expand(3, 3)) \ 222 | # .sum(1).squeeze() 223 | # 224 | # return img.add(rgb.view(3, 1, 1).expand_as(img)) 225 | 226 | 227 | # class CutoutDefault(object): 228 | # """ 229 | # Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 230 | # """ 231 | 
# def __init__(self, length): 232 | # self.length = length 233 | # 234 | # def __call__(self, img): 235 | # h, w = img.size(1), img.size(2) 236 | # mask = np.ones((h, w), np.float32) 237 | # y = np.random.randint(h) 238 | # x = np.random.randint(w) 239 | # 240 | # y1 = np.clip(y - self.length // 2, 0, h) 241 | # y2 = np.clip(y + self.length // 2, 0, h) 242 | # x1 = np.clip(x - self.length // 2, 0, w) 243 | # x2 = np.clip(x + self.length // 2, 0, w) 244 | # 245 | # mask[y1: y2, x1: x2] = 0. 246 | # mask = torch.from_numpy(mask) 247 | # mask = mask.expand_as(img) 248 | # img *= mask 249 | # return img 250 | 251 | 252 | class RandAugment: 253 | def __init__(self, n, m): 254 | self.n = n 255 | self.m = m # [0, 30] 256 | self.augment_list = augment_list() 257 | 258 | def __call__(self, img): 259 | 260 | if self.n == 0: 261 | return img 262 | 263 | ops = random.choices(self.augment_list, k=self.n) 264 | for op, minval, maxval in ops: 265 | val = (float(self.m) / 30) * float(maxval - minval) + minval 266 | img = op(img, val) 267 | 268 | return img -------------------------------------------------------------------------------- /dataset/randaugmentMC.py: -------------------------------------------------------------------------------- 1 | # code in this file is adapted from 2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py 4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py 5 | import logging 6 | import random 7 | 8 | import numpy as np 9 | import PIL 10 | import PIL.ImageOps 11 | import PIL.ImageEnhance 12 | import PIL.ImageDraw 13 | from PIL import Image 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | PARAMETER_MAX = 10 18 | 19 | 20 | def AutoContrast(img, **kwarg): 21 | return PIL.ImageOps.autocontrast(img) 22 | 23 | 24 | def Brightness(img, v, max_v, bias=0): 25 | v =
_float_parameter(v, max_v) + bias 26 | return PIL.ImageEnhance.Brightness(img).enhance(v) 27 | 28 | 29 | def Color(img, v, max_v, bias=0): 30 | v = _float_parameter(v, max_v) + bias 31 | return PIL.ImageEnhance.Color(img).enhance(v) 32 | 33 | 34 | def Contrast(img, v, max_v, bias=0): 35 | v = _float_parameter(v, max_v) + bias 36 | return PIL.ImageEnhance.Contrast(img).enhance(v) 37 | 38 | 39 | def Cutout(img, v, max_v, bias=0): 40 | if v == 0: 41 | return img 42 | v = _float_parameter(v, max_v) + bias 43 | v = int(v * min(img.size)) 44 | return CutoutAbs(img, v) 45 | 46 | 47 | def CutoutAbs(img, v, **kwarg): 48 | w, h = img.size 49 | x0 = np.random.uniform(0, w) 50 | y0 = np.random.uniform(0, h) 51 | x0 = int(max(0, x0 - v / 2.)) 52 | y0 = int(max(0, y0 - v / 2.)) 53 | x1 = int(min(w, x0 + v)) 54 | y1 = int(min(h, y0 + v)) 55 | xy = (x0, y0, x1, y1) 56 | # gray 57 | color = (127, 127, 127) 58 | img = img.copy() 59 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 60 | return img 61 | 62 | 63 | def Equalize(img, **kwarg): 64 | return PIL.ImageOps.equalize(img) 65 | 66 | 67 | def Identity(img, **kwarg): 68 | return img 69 | 70 | 71 | def Invert(img, **kwarg): 72 | return PIL.ImageOps.invert(img) 73 | 74 | 75 | def Posterize(img, v, max_v, bias=0): 76 | v = _int_parameter(v, max_v) + bias 77 | return PIL.ImageOps.posterize(img, v) 78 | 79 | 80 | def Rotate(img, v, max_v, bias=0): 81 | v = _int_parameter(v, max_v) + bias 82 | if random.random() < 0.5: 83 | v = -v 84 | return img.rotate(v) 85 | 86 | 87 | def Sharpness(img, v, max_v, bias=0): 88 | v = _float_parameter(v, max_v) + bias 89 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 90 | 91 | 92 | def ShearX(img, v, max_v, bias=0): 93 | v = _float_parameter(v, max_v) + bias 94 | if random.random() < 0.5: 95 | v = -v 96 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 97 | 98 | 99 | def ShearY(img, v, max_v, bias=0): 100 | v = _float_parameter(v, max_v) + bias 101 | if random.random() < 0.5: 102 | 
v = -v 103 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 104 | 105 | 106 | def Solarize(img, v, max_v, bias=0): 107 | v = _int_parameter(v, max_v) + bias 108 | return PIL.ImageOps.solarize(img, 256 - v) 109 | 110 | 111 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128): 112 | v = _int_parameter(v, max_v) + bias 113 | if random.random() < 0.5: 114 | v = -v 115 | img_np = np.array(img).astype(int) # np.int was removed in NumPy 1.24 116 | img_np = img_np + v 117 | img_np = np.clip(img_np, 0, 255) 118 | img_np = img_np.astype(np.uint8) 119 | img = Image.fromarray(img_np) 120 | return PIL.ImageOps.solarize(img, threshold) 121 | 122 | 123 | def TranslateX(img, v, max_v, bias=0): 124 | v = _float_parameter(v, max_v) + bias 125 | if random.random() < 0.5: 126 | v = -v 127 | v = int(v * img.size[0]) 128 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 129 | 130 | 131 | def TranslateY(img, v, max_v, bias=0): 132 | v = _float_parameter(v, max_v) + bias 133 | if random.random() < 0.5: 134 | v = -v 135 | v = int(v * img.size[1]) 136 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 137 | 138 | 139 | def _float_parameter(v, max_v): 140 | return float(v) * max_v / PARAMETER_MAX 141 | 142 | 143 | def _int_parameter(v, max_v): 144 | return int(v * max_v / PARAMETER_MAX) 145 | 146 | 147 | def fixmatch_augment_pool(): 148 | # FixMatch paper 149 | augs = [(AutoContrast, None, None), 150 | (Brightness, 0.9, 0.05), 151 | (Color, 0.9, 0.05), 152 | (Contrast, 0.9, 0.05), 153 | (Equalize, None, None), 154 | (Identity, None, None), 155 | (Posterize, 4, 4), 156 | (Rotate, 30, 0), 157 | (Sharpness, 0.9, 0.05), 158 | (ShearX, 0.3, 0), 159 | (ShearY, 0.3, 0), 160 | (Solarize, 256, 0), 161 | (TranslateX, 0.3, 0), 162 | (TranslateY, 0.3, 0)] 163 | return augs 164 | 165 | 166 | def my_augment_pool(): 167 | # Test 168 | augs = [(AutoContrast, None, None), 169 | (Brightness, 1.8, 0.1), 170 | (Color, 1.8, 
0.1), 171 | (Contrast, 1.8, 0.1), 172 | (Cutout, 0.2,
0), 173 | (Equalize, None, None), 174 | (Invert, None, None), 175 | (Posterize, 4, 4), 176 | (Rotate, 30, 0), 177 | (Sharpness, 1.8, 0.1), 178 | (ShearX, 0.3, 0), 179 | (ShearY, 0.3, 0), 180 | (Solarize, 256, 0), 181 | (SolarizeAdd, 110, 0), 182 | (TranslateX, 0.45, 0), 183 | (TranslateY, 0.45, 0)] 184 | return augs 185 | 186 | 187 | class RandAugmentPC(object): 188 | def __init__(self, n, m): 189 | assert n >= 1 190 | assert 1 <= m <= 10 191 | self.n = n 192 | self.m = m 193 | self.augment_pool = my_augment_pool() 194 | 195 | def __call__(self, img): 196 | ops = random.choices(self.augment_pool, k=self.n) 197 | for op, max_v, bias in ops: 198 | prob = np.random.uniform(0.2, 0.8) 199 | if random.random() + prob >= 1: 200 | img = op(img, v=self.m, max_v=max_v, bias=bias) 201 | img = CutoutAbs(img, 16) 202 | return img 203 | 204 | 205 | class RandAugmentMC(object): 206 | def __init__(self, n, m): 207 | assert n >= 1 208 | assert 1 <= m <= 10 209 | self.n = n 210 | self.m = m 211 | self.augment_pool = fixmatch_augment_pool() 212 | 213 | def __call__(self, img): 214 | ops = random.choices(self.augment_pool, k=self.n) 215 | for op, max_v, bias in ops: 216 | v = np.random.randint(1, self.m) 217 | if random.random() < 0.5: 218 | img = op(img, v=v, max_v=max_v, bias=bias) 219 | img = CutoutAbs(img, 16) 220 | return img 221 | -------------------------------------------------------------------------------- /dataset/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | from .randaugment import RandAugment 4 | from torchvision.transforms import (Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop, 5 | RandomResizedCrop, RandomHorizontalFlip) 6 | from .randaugmentMC import RandAugmentMC 7 | 8 | class ResizeImage(): 9 | def __init__(self, size): 10 | if isinstance(size, int): 11 | self.size = (int(size), int(size)) 12 | else: 13 | self.size = size 14 | 15 | def 
__call__(self, img): 16 | th, tw = self.size # (height, width) 17 | return img.resize((tw, th)) # PIL's resize expects (width, height) 18 | 19 | 20 | def build_transforms(cfg, domain='source'): 21 | if cfg.DATASET.NAME in ["mnist", "svhn"]: 22 | transforms = None 23 | else: 24 | # train 25 | choices = cfg.DATASET.SOURCE_TRANSFORMS if domain=='source' else cfg.DATASET.TARGET_TRANSFORMS 26 | train_transform = build_transform(choices) 27 | # query 28 | choices = cfg.DATASET.QUERY_TRANSFORMS 29 | query_transform = build_transform(choices) 30 | # test 31 | choices = cfg.DATASET.TEST_TRANSFORMS 32 | test_transform = build_transform(choices) 33 | 34 | transforms = {'train':train_transform, 'query':query_transform, 'test':test_transform} 35 | 36 | return transforms 37 | 38 | def build_transform(choices): 39 | transform = [] 40 | if 'Resize' in choices: 41 | transform += [Resize((256, 256))] # resize to equal height and width 42 | 43 | if 'ResizeImage' in choices: 44 | transform += [ResizeImage(256)] 45 | 46 | if 'RandomHorizontalFlip' in choices: 47 | transform += [RandomHorizontalFlip(p=0.5)] 48 | 49 | if 'RandomCrop' in choices: 50 | transform += [RandomCrop(224)] 51 | 52 | if 'RandomResizedCrop' in choices: 53 | transform += [RandomResizedCrop(224)] 54 | 55 | if 'CenterCrop' in choices: 56 | transform += [CenterCrop(224)] 57 | 58 | 59 | transform += [ToTensor()] 60 | 61 | if 'Normalize' in choices: 62 | normalize = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 63 | 64 | transform += [normalize] 65 | 66 | return Compose(transform) 67 | 68 | 69 | rand_transform = transforms.Compose([ 70 | RandAugment(1, 2.0), 71 | transforms.Resize(256), 72 | transforms.RandomHorizontalFlip(0.5), 73 | transforms.RandomCrop(224), 74 | transforms.ToTensor(), 75 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 76 | std=[0.229, 0.224, 0.225]), 77 | ]) 78 | 79 | rand_transform2 = transforms.Compose([ 80 | ResizeImage(256), 81 | transforms.RandomHorizontalFlip(), 82 | transforms.RandomCrop(224), 83 | 
RandAugmentMC(n=2, m=10), 84 |
| transforms.ToTensor(), 85 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 86 | ]) -------------------------------------------------------------------------------- /fig/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tsun/LADA/5de718ecd0b1eccc337ff54d4cd854ffb580d7e2/fig/framework.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import logging 4 | import os 5 | import time 6 | import copy 7 | import pprint as pp 8 | from tqdm import tqdm, trange 9 | from collections import defaultdict 10 | import numpy as np 11 | import shutil 12 | import socket 13 | hostName = socket.gethostname() 14 | pid = os.getpid() 15 | 16 | from config.defaults import _C as cfg 17 | from utils.logger import init_logger 18 | from utils.utils import resetRNGseed 19 | from dataset.ASDADataset import ASDADataset 20 | from dataset.transform import build_transforms 21 | 22 | from model import get_model 23 | import utils.utils as utils 24 | 25 | from active.sampler import get_strategy 26 | from active.budget import BudgetAllocator 27 | 28 | 29 | # repeatability 30 | torch.backends.cudnn.deterministic = True 31 | torch.backends.cudnn.benchmark = False 32 | 33 | def run_active_adaptation(cfg, source_model, src_dset, num_classes, device): 34 | source = cfg.DATASET.SOURCE_DOMAIN 35 | target = cfg.DATASET.TARGET_DOMAIN 36 | da_strat = cfg.ADA.DA 37 | al_strat = cfg.ADA.AL 38 | 39 | transforms = build_transforms(cfg, 'target') 40 | tgt_dset = ASDADataset(cfg.DATASET.NAME, cfg.DATASET.TARGET_DOMAIN, data_dir=cfg.DATASET.ROOT, 41 | num_classes=cfg.DATASET.NUM_CLASS, batch_size=cfg.DATALOADER.BATCH_SIZE, 42 | num_workers=cfg.DATALOADER.NUM_WORKERS, transforms=transforms) 43 | 44 | target_test_loader = tgt_dset.get_loaders()[2] 45 | 46 | # 
Evaluate source model on target test 47 | if cfg.TRAINER.EVAL_ACC: 48 | transfer_perf, _ = utils.test(source_model, device, target_test_loader) 49 | logging.info('{}->{} performance (Before {}): Task={:.2f}'.format(source, target, da_strat, transfer_perf)) 50 | 51 | 52 | # Main Active DA loop 53 | logging.info('------------------------------------------------------') 54 | model_init = 'source' if cfg.TRAINER.TRAIN_ON_SOURCE else 'scratch' 55 | logging.info('Running strategy: Init={} AL={} DA={}'.format(model_init, al_strat, da_strat)) 56 | logging.info('------------------------------------------------------') 57 | 58 | # Run unsupervised DA at round 0, where applicable 59 | start_perf = 0. 60 | 61 | # Instantiate active sampling strategy 62 | sampling_strategy = get_strategy(al_strat, src_dset, tgt_dset, source_model, device, num_classes, cfg) 63 | del source_model 64 | 65 | if cfg.TRAINER.MAX_UDA_EPOCHS > 0: 66 | target_model = sampling_strategy.train_uda(epochs=cfg.TRAINER.MAX_UDA_EPOCHS) 67 | 68 | # Evaluate adapted source model on target test 69 | if cfg.TRAINER.EVAL_ACC: 70 | start_perf, _ = utils.test(target_model, device, target_test_loader) 71 | logging.info('{}->{} performance (After {}): {:.2f}'.format(source, target, da_strat, start_perf)) 72 | logging.info('------------------------------------------------------') 73 | 74 | 75 | # Run Active DA 76 | # Keep track of labeled vs unlabeled data 77 | idxs_lb = np.zeros(len(tgt_dset.train_idx), dtype=bool) 78 | 79 | budget = np.round(len(tgt_dset.train_idx) * cfg.ADA.BUDGET) if cfg.ADA.BUDGET <= 1.0 else np.round(cfg.ADA.BUDGET) 80 | budget_allocator = BudgetAllocator(budget=budget, cfg=cfg) 81 | 82 | tqdm_rat = trange(cfg.TRAINER.MAX_EPOCHS) 83 | target_accs = defaultdict(list) 84 | target_accs[0.0].append(start_perf) 85 | 86 | for epoch in tqdm_rat: 87 | 88 | curr_budget, used_budget = budget_allocator.get_budget(epoch) 89 | tqdm_rat.set_description('# Target labels={} Allowed labels={}'.format(used_budget, 
curr_budget)) 90 | tqdm_rat.refresh() 91 | 92 | # Select instances via AL strategy 93 | if curr_budget > 0: 94 | logging.info('Selecting instances...') 95 | idxs = sampling_strategy.query(curr_budget, epoch) 96 | idxs_lb[idxs] = True 97 | sampling_strategy.update(idxs_lb) 98 | else: 99 | logging.info('No budget for current epoch, skipped...') 100 | 101 | # Update model with new data via DA strategy 102 | target_model = sampling_strategy.train(epoch=epoch) 103 | 104 | # Evaluate on target test and train splits 105 | if cfg.TRAINER.EVAL_ACC: 106 | test_perf, _ = utils.test(target_model, device, target_test_loader) 107 | out_str = '{}->{} Test performance (Epoch {}, # Target labels={:d}): {:.2f}'.format(source, target, epoch, 108 | int(curr_budget+used_budget), test_perf) 109 | logging.info(out_str) 110 | logging.info('------------------------------------------------------') 111 | 112 | target_accs[curr_budget+used_budget].append(test_perf) 113 | 114 | 115 | logging.info("\n{}".format(target_accs)) 116 | 117 | return target_accs 118 | 119 | 120 | def ADAtrain(cfg, task): 121 | logging.info("Running task: {}".format(task)) 122 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 123 | transforms = build_transforms(cfg, 'source') 124 | src_dset = ASDADataset(cfg.DATASET.NAME, cfg.DATASET.SOURCE_DOMAIN, data_dir=cfg.DATASET.ROOT, 125 | num_classes=cfg.DATASET.NUM_CLASS, batch_size=cfg.DATALOADER.BATCH_SIZE, 126 | num_workers=cfg.DATALOADER.NUM_WORKERS, transforms=transforms) 127 | 128 | src_train_loader, src_valid_loader, src_test_loader = src_dset.get_loaders(valid_type=cfg.DATASET.SOURCE_VALID_TYPE, 129 | valid_ratio=cfg.DATASET.SOURCE_VALID_RATIO) 130 | 131 | # model 132 | source_model = get_model(cfg.MODEL.BACKBONE.NAME, num_cls=cfg.DATASET.NUM_CLASS, normalize=cfg.MODEL.NORMALIZE, temp=cfg.MODEL.TEMP).to(device) 133 | source_file = '{}_{}_source_{}.pth'.format(cfg.DATASET.SOURCE_DOMAIN, cfg.MODEL.BACKBONE.NAME, 
cfg.TRAINER.MAX_SOURCE_EPOCHS) 134 | source_dir = os.path.join('checkpoints', 'source') 135 | if not os.path.exists(source_dir): 136 | os.makedirs(source_dir) 137 | source_path = os.path.join(source_dir, source_file) 138 | best_source_file = '{}_{}_source_best_{}.pth'.format(cfg.DATASET.SOURCE_DOMAIN, cfg.MODEL.BACKBONE.NAME, cfg.TRAINER.MAX_SOURCE_EPOCHS) 139 | best_source_path = os.path.join(source_dir, best_source_file) 140 | 141 | if cfg.TRAINER.TRAIN_ON_SOURCE and cfg.TRAINER.MAX_SOURCE_EPOCHS>0: 142 | if cfg.TRAINER.LOAD_FROM_CHECKPOINT and os.path.exists(source_path): 143 | logging.info('Loading source checkpoint: {}'.format(source_path)) 144 | source_model.load_state_dict(torch.load(source_path, map_location=device), strict=False) 145 | best_source_model = source_model 146 | else: 147 | logging.info('Training {} model...'.format(cfg.DATASET.SOURCE_DOMAIN)) 148 | best_val_acc, best_source_model = 0.0, None 149 | source_optimizer = utils.get_optim(cfg.OPTIM.SOURCE_NAME, source_model.parameters(cfg.OPTIM.SOURCE_LR, cfg.OPTIM.BASE_LR_MULT), lr=cfg.OPTIM.SOURCE_LR) 150 | 151 | for epoch in range(cfg.TRAINER.MAX_SOURCE_EPOCHS): 152 | utils.train(source_model, device, src_train_loader, source_optimizer, epoch) 153 | 154 | val_acc, _ = utils.test(source_model, device, src_valid_loader, split="source valid") 155 | logging.info('[Epoch: {}] Valid Accuracy: {:.3f} '.format(epoch, val_acc)) 156 | 157 | if (val_acc > best_val_acc): 158 | best_val_acc = val_acc 159 | best_source_model = copy.deepcopy(source_model) 160 | torch.save(best_source_model.state_dict(), best_source_path) 161 | 162 | del source_model 163 | # rename file in case of abnormal exit 164 | shutil.move(best_source_path, source_path) 165 | else: 166 | best_source_model = source_model 167 | 168 | # Evaluate on source test set 169 | if cfg.TRAINER.EVAL_ACC: 170 | test_acc, _ = utils.test(best_source_model, device, src_test_loader, split="source test") 171 | logging.info('{} Test Accuracy: {:.3f} 
'.format(cfg.DATASET.SOURCE_DOMAIN, test_acc)) 172 | 173 | # Run active adaptation experiments 174 | target_accs = run_active_adaptation(cfg, best_source_model, src_dset, cfg.DATASET.NUM_CLASS, device) 175 | pp.pprint(target_accs) 176 | 177 | 178 | 179 | def main(): 180 | parser = argparse.ArgumentParser(description='Optimal Budget Allocation for Active Domain Adaptation') 181 | parser.add_argument('--cfg', default='', metavar='FILE', help='path to config file', type=str) 182 | parser.add_argument('--timestamp', default=time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()), 183 | type=str, help='timestamp') 184 | parser.add_argument('--gpu', default='0', type=str, help='which gpu to use') 185 | parser.add_argument('--note', default=None, type=str, help='note recorded in the experiment log') 186 | parser.add_argument('--log', default='./log', type=str, help='logging directory') 187 | parser.add_argument('--nolog', action='store_true', help='disable the experiment logger') 188 | parser.add_argument("opts", help="Modify config options using the command-line", 189 | default=None, nargs=argparse.REMAINDER) 190 | 191 | args = parser.parse_args() 192 | 193 | cfg.merge_from_file(args.cfg) 194 | cfg.merge_from_list(args.opts) 195 | 196 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 197 | 198 | logger = "{}_{}_{}_{}".format(args.timestamp, cfg.DATASET.NAME, cfg.ADA.AL, cfg.ADA.DA) if not args.nolog else None 199 | init_logger(logger, dir=args.log) 200 | 201 | if args.note is not None: 202 | logging.info("Experiment note : {}".format(args.note)) 203 | logging.info('------------------------------------------------------') 204 | logging.info("Running on {} gpu={} pid={}".format(hostName, args.gpu, pid)) 205 | logging.info(cfg) 206 | logging.info('------------------------------------------------------') 207 | 208 | if type(cfg.SEED) is tuple or type(cfg.SEED) is list: 209 | seeds = cfg.SEED 210 | else: 211 | seeds = [cfg.SEED] 212 | 213 | for seed in seeds: 214 | logging.info("Using random seed:
{}".format(seed)) 215 | resetRNGseed(seed) 216 | 217 | if cfg.ADA.TASKS is not None: 218 | ada_tasks = cfg.ADA.TASKS 219 | else: 220 | ada_tasks = [[source, target] for source in cfg.DATASET.SOURCE_DOMAINS 221 | for target in cfg.DATASET.TARGET_DOMAINS if source != target] 222 | 223 | for [source, target] in ada_tasks: 224 | cfg.DATASET.SOURCE_DOMAIN = source 225 | cfg.DATASET.TARGET_DOMAIN = target 226 | 227 | cfg.freeze() 228 | ADAtrain(cfg, task=source + '-->' + target) 229 | cfg.defrost() 230 | 231 | 232 | if __name__ == '__main__': 233 | main() -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import get_model 2 | from .adaptor import AdaptNet 3 | from .network import TaskNet -------------------------------------------------------------------------------- /model/adaptor.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .models import register_model, get_model 6 | 7 | @register_model('AdaptNet') 8 | class AdaptNet(nn.Module): 9 | "Defines an Adapt Network." 
10 | 11 | def __init__(self, num_cls=10, model='LeNet', src_weights_init=None, weights_init=None, weight_sharing='full', 12 | normalize=False, temp=0.05): 13 | super(AdaptNet, self).__init__() 14 | self.name = 'AdaptNet' 15 | self.base_model = model 16 | 17 | self.num_cls = num_cls 18 | self.cls_criterion = nn.CrossEntropyLoss() 19 | self.gan_criterion = nn.CrossEntropyLoss() 20 | self.weight_sharing = weight_sharing 21 | self.normalize = normalize 22 | self.temp = temp 23 | self.setup_net() 24 | if weights_init is not None: 25 | self.load(weights_init) 26 | elif src_weights_init is not None: 27 | self.load_src_net(src_weights_init) 28 | else: 29 | raise Exception('AdaptNet must be initialized with weights.') 30 | 31 | def custom_copy(self, src_net, weight_sharing): 32 | """ 33 | Vary degree of weight sharing between source and target CNN's 34 | """ 35 | tgt_net = copy.deepcopy(src_net) 36 | if weight_sharing != 'None': 37 | if weight_sharing == 'classifier': 38 | tgt_net.classifier = src_net.classifier 39 | elif weight_sharing == 'full': 40 | tgt_net = src_net 41 | return tgt_net 42 | 43 | def setup_net(self): 44 | """Setup source, target and discriminator networks.""" 45 | self.src_net = get_model(self.base_model, num_cls=self.num_cls, normalize=self.normalize, temp=self.temp) 46 | self.tgt_net = self.custom_copy(self.src_net, self.weight_sharing) 47 | 48 | input_dim = self.num_cls 49 | self.discrim = nn.Sequential( 50 | nn.Linear(input_dim, 500), 51 | nn.ReLU(), 52 | nn.Linear(500, 500), 53 | nn.ReLU(), 54 | nn.Linear(500, 2), 55 | ) 56 | 57 | self.image_size = self.src_net.image_size 58 | self.num_channels = self.src_net.num_channels 59 | 60 | 61 | def load(self, init_path): 62 | "Loads full src and tgt models." 
63 | net_init_dict = torch.load(init_path, map_location=torch.device('cpu')) 64 | self.load_state_dict(net_init_dict, strict=False) 65 | 66 | def load_src_net(self, init_path): 67 | """Initialize source and target with source 68 | weights.""" 69 | if type(init_path) is str: 70 | self.src_net.load(init_path) 71 | self.tgt_net.load(init_path) 72 | else: 73 | # initialize from model 74 | self.src_net.load_state_dict(init_path.state_dict()) 75 | self.tgt_net.load_state_dict(init_path.state_dict()) 76 | 77 | def save(self, out_path): 78 | torch.save(self.state_dict(), out_path) 79 | 80 | def save_tgt_net(self, out_path): 81 | torch.save(self.tgt_net.state_dict(), out_path) -------------------------------------------------------------------------------- /model/grl.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any, Tuple 2 | import numpy as np 3 | import torch.nn as nn 4 | from torch.autograd import Function 5 | import torch 6 | 7 | 8 | class GradientReverseFunction(Function): 9 | 10 | @staticmethod 11 | def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor: 12 | ctx.coeff = coeff 13 | output = input * 1.0 14 | return output 15 | 16 | @staticmethod 17 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]: 18 | return grad_output.neg() * ctx.coeff, None 19 | 20 | 21 | class GradientReverseLayer(nn.Module): 22 | def __init__(self): 23 | super(GradientReverseLayer, self).__init__() 24 | 25 | def forward(self, *input): 26 | return GradientReverseFunction.apply(*input) 27 | 28 | 29 | class WarmStartGradientReverseLayer(nn.Module): 30 | r"""Gradient Reverse Layer :math:`\mathcal{R}(x)` with warm start 31 | 32 | The forward and backward behaviours are: 33 | 34 | .. math:: 35 | \mathcal{R}(x) = x, 36 | 37 | \dfrac{ d\mathcal{R}} {dx} = - \lambda I.
38 | 39 | :math:`\lambda` is initialized at :math:`lo` and is gradually changed to :math:`hi` using the following schedule: 40 | 41 | .. math:: 42 | \lambda = \dfrac{2(hi-lo)}{1+\exp(- α \dfrac{i}{N})} - (hi-lo) + lo 43 | 44 | where :math:`i` is the iteration step. 45 | 46 | Args: 47 | alpha (float, optional): :math:`α`. Default: 1.0 48 | lo (float, optional): Initial value of :math:`\lambda`. Default: 0.0 49 | hi (float, optional): Final value of :math:`\lambda`. Default: 1.0 50 | max_iters (int, optional): :math:`N`. Default: 1000 51 | auto_step (bool, optional): If True, increase :math:`i` each time `forward` is called. 52 | Otherwise use function `step` to increase :math:`i`. Default: False 53 | """ 54 | 55 | def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1., 56 | max_iters: Optional[int] = 1000, auto_step: Optional[bool] = False): 57 | super(WarmStartGradientReverseLayer, self).__init__() 58 | self.alpha = alpha 59 | self.lo = lo 60 | self.hi = hi 61 | self.iter_num = 0 62 | self.max_iters = max_iters 63 | self.auto_step = auto_step 64 | 65 | def forward(self, input: torch.Tensor) -> torch.Tensor: 66 | """Apply gradient reversal scaled by the current warm-start coefficient.""" 67 | coeff = float(2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters)) 68 | - (self.hi - self.lo) + self.lo) # np.float was removed in NumPy 1.24 69 | if self.auto_step: 70 | self.step() 71 | return GradientReverseFunction.apply(input, coeff) 72 | 73 | def step(self): 74 | """Increase iteration number :math:`i` by 1""" 75 | self.iter_num += 1 76 | -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | models = {} 4 | def register_model(name): 5 | def decorator(cls): 6 | models[name] = cls 7 | return cls 8 | return decorator 9 | 10 | def get_model(name, num_cls=10, **args): 11 | net = models[name](num_cls=num_cls, 
**args) 12 | if torch.cuda.is_available():
13 | net = net.cuda() 14 | return net 15 | -------------------------------------------------------------------------------- /model/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models 4 | import torch.nn.functional as F 5 | from .grl import GradientReverseFunction 6 | from .models import register_model 7 | 8 | 9 | class TaskNet(nn.Module): 10 | num_channels = 3 11 | image_size = 32 12 | name = 'TaskNet' 13 | 14 | def __init__(self, num_cls=10, normalize=False, temp=0.05): 15 | super(TaskNet, self).__init__() 16 | self.num_cls = num_cls 17 | self.setup_net() 18 | self.criterion = nn.CrossEntropyLoss() 19 | self.normalize = normalize 20 | self.temp = temp 21 | 22 | def forward(self, x, with_emb=False, reverse_grad=False): 23 | x = self.conv_params(x) 24 | x = x.view(x.size(0), -1) 25 | #x = x.clone() 26 | emb = self.fc_params(x) 27 | 28 | if isinstance(self.classifier, nn.Sequential): # LeNet 29 | emb = self.classifier[:-1](emb) 30 | if reverse_grad: emb = GradientReverseFunction.apply(emb) 31 | if self.normalize: emb = F.normalize(emb) / self.temp 32 | score = self.classifier[-1](emb) 33 | else: # ResNet 34 | if reverse_grad: emb = GradientReverseFunction.apply(emb) 35 | if self.normalize: emb = F.normalize(emb) / self.temp 36 | score = self.classifier(emb) 37 | 38 | if with_emb: 39 | return score, emb 40 | else: 41 | return score 42 | 43 | def setup_net(self): 44 | """Method to be implemented in each class.""" 45 | pass 46 | 47 | def load(self, init_path): 48 | net_init_dict = torch.load(init_path) 49 | self.load_state_dict(net_init_dict, strict=False) 50 | 51 | def save(self, out_path): 52 | torch.save(self.state_dict(), out_path) 53 | 54 | def parameters(self, lr, lr_scalar=0.1): 55 | parameter_list = [ 56 | {'params': self.conv_params.parameters(), 'lr': lr * lr_scalar}, 57 | {'params': self.fc_params.parameters(), 'lr': lr}, 58 | {'params': 
self.classifier.parameters(), 'lr': lr}, 59 | ] 60 | 61 | return parameter_list 62 | 63 | 64 | @register_model('ResNet34Fc') 65 | class ResNet34Fc(TaskNet): 66 | num_channels = 3 67 | name = 'ResNet34Fc' 68 | 69 | def setup_net(self): 70 | model = torchvision.models.resnet34(pretrained=True) 71 | model.fc = nn.Identity() 72 | self.conv_params = model 73 | self.fc_params = nn.Linear(512, 512) 74 | self.classifier = nn.Linear(512, self.num_cls, bias=False) 75 | 76 | 77 | class BatchNorm1d(nn.Module): 78 | def __init__(self, dim): 79 | super(BatchNorm1d, self).__init__() 80 | self.BatchNorm1d = nn.BatchNorm1d(dim) 81 | 82 | def forward(self, x): 83 | if x.size(0) == 1: # training-mode BatchNorm1d cannot use a single sample, so duplicate it and keep the first row 84 | x = torch.cat((x,x), 0) 85 | x = self.BatchNorm1d(x)[:1] 86 | else: 87 | x = self.BatchNorm1d(x) 88 | return x 89 | 90 | 91 | @register_model('ResNet50Fc') 92 | class ResNet50Fc(TaskNet): 93 | num_channels = 3 94 | name = 'ResNet50Fc' 95 | 96 | def setup_net(self): 97 | model = torchvision.models.resnet50(pretrained=True) 98 | model.fc = nn.Identity() 99 | self.conv_params = model 100 | self.fc_params = nn.Sequential(nn.Linear(2048, 256), BatchNorm1d(256)) 101 | self.classifier = nn.Linear(256, self.num_cls) 102 | 103 | 104 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # Table 1 2 | # office-home 5%-budget 3 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA ft 4 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA mme 5 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA RAA 6 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA LAA 7 | 8 | # office-home 10%-budget 9 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA RAA LADA.S_M 5 ADA.BUDGET 0.1 10 | python main.py --cfg 
configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA LAA LADA.S_M 5 ADA.BUDGET 0.1 11 | 12 | # office-home rsut 10%-budget 13 | python main.py --cfg configs/officehome_RSUT.yaml --gpu 0 --log log/oh_RSUT/LADA ADA.AL LAS ADA.DA RAA LADA.S_M 5 ADA.BUDGET 0.1 14 | python main.py --cfg configs/officehome_RSUT.yaml --gpu 0 --log log/oh_RSUT/LADA ADA.AL LAS ADA.DA LAA LADA.S_M 5 ADA.BUDGET 0.1 15 | 16 | # office-31 5%-budget 17 | python main.py --cfg configs/office31.yaml --gpu 0 --log log/office31/LADA ADA.AL LAS ADA.DA LAA 18 | 19 | # visda 5%-budget 20 | python main.py --cfg configs/visda.yaml --gpu 0 --log log/visda/LADA ADA.AL LAS ADA.DA LAA 21 | 22 | 23 | -------------------------------------------------------------------------------- /solver/CDACsolver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from dataset.image_list import ImageList 5 | from .solver import BaseSolver, register_solver 6 | from collections import defaultdict 7 | import numpy as np 8 | from dataset.transform import rand_transform2 9 | 10 | def get_losses_unlabeled(net, im_data, im_data_bar, im_data_bar2, target, BCE, w_cons, device): 11 | """ Get losses for unlabeled samples.""" 12 | 13 | output, feat = net(im_data, with_emb=True, reverse_grad=True) 14 | output_bar, feat_bar = net(im_data_bar, with_emb=True, reverse_grad=True) 15 | prob, prob_bar = F.softmax(output, dim=1), F.softmax(output_bar, dim=1) 16 | 17 | # loss for adversarial adaptive clustering 18 | aac_loss = advbce_unlabeled(target=target, feat=feat, prob=prob, prob_bar=prob_bar, device=device, bce=BCE) 19 | 20 | output = net.forward_emb(feat) 21 | output_bar = net.forward_emb(feat_bar) 22 | output_bar2 = net(im_data_bar2) 23 | 24 | prob = F.softmax(output, dim=1) 25 | prob_bar = F.softmax(output_bar, dim=1) 26 | prob_bar2 = F.softmax(output_bar2, dim=1) 27 | 28 | max_probs, pseudo_labels = 
torch.max(prob.detach_(), dim=-1) 29 | mask = max_probs.ge(0.95).float() 30 | 31 | # loss for pseudo labeling 32 | pl_loss = (F.cross_entropy(output_bar2, pseudo_labels, reduction='none') * mask).mean() 33 | 34 | # loss for consistency 35 | con_loss = w_cons * F.mse_loss(prob_bar, prob_bar2) 36 | 37 | return aac_loss, pl_loss, con_loss 38 | 39 | 40 | def advbce_unlabeled(target, feat, prob, prob_bar, device, bce): 41 | """ Construct adversarial adaptive clustering loss.""" 42 | target_ulb = pairwise_target(feat, target, device) 43 | prob_bottleneck_row, _ = PairEnum2D(prob) 44 | _, prob_bottleneck_col = PairEnum2D(prob_bar) 45 | adv_bce_loss = -bce(prob_bottleneck_row, prob_bottleneck_col, target_ulb) 46 | return adv_bce_loss 47 | 48 | 49 | def pairwise_target(feat, target, device, topk=5): 50 | """ Produce pairwise similarity label.""" 51 | feat_detach = feat.detach() 52 | # For unlabeled data 53 | if target is None: 54 | rank_feat = feat_detach 55 | rank_idx = torch.argsort(rank_feat, dim=1, descending=True) 56 | rank_idx1, rank_idx2 = PairEnum2D(rank_idx) 57 | rank_idx1, rank_idx2 = rank_idx1[:, :topk], rank_idx2[:, :topk] 58 | rank_idx1, _ = torch.sort(rank_idx1, dim=1) 59 | rank_idx2, _ = torch.sort(rank_idx2, dim=1) 60 | rank_diff = rank_idx1 - rank_idx2 61 | rank_diff = torch.sum(torch.abs(rank_diff), dim=1) 62 | target_ulb = torch.ones_like(rank_diff).float().to(device) 63 | target_ulb[rank_diff > 0] = 0 64 | # For labeled data 65 | elif target is not None: 66 | target_row, target_col = PairEnum1D(target) 67 | target_ulb = torch.zeros(target.size(0) * target.size(0)).float().to(device) 68 | target_ulb[target_row == target_col] = 1 69 | else: 70 | raise ValueError('Please check your target.') 71 | return target_ulb 72 | 73 | 74 | def PairEnum1D(x): 75 | """ Enumerate all pairs of features in x with 1 dimension.""" 76 | assert x.ndimension() == 1, 'Input dimension must be 1' 77 | x1 = x.repeat(x.size(0)) 78 | x2 = x.repeat(x.size(0)).view(-1, 
x.size(0)).transpose(1, 0).reshape(-1) 79 | return x1, x2 80 | 81 | 82 | def PairEnum2D(x): 83 | """ Enumerate all pairs of features in x with 2 dimensions.""" 84 | assert x.ndimension() == 2, 'Input dimension must be 2' 85 | x1 = x.repeat(x.size(0), 1) 86 | x2 = x.repeat(1, x.size(0)).view(-1, x.size(1)) 87 | return x1, x2 88 | 89 | 90 | def sigmoid_rampup(current, rampup_length): 91 | """ Exponential rampup from https://arxiv.org/abs/1610.02242""" 92 | if rampup_length == 0: 93 | return 1.0 94 | else: 95 | current = np.clip(current, 0.0, rampup_length) 96 | phase = 1.0 - current / rampup_length 97 | return float(np.exp(-5.0 * phase * phase)) 98 | 99 | 100 | class BCE(nn.Module): 101 | eps = 1e-7 102 | 103 | def forward(self, prob1, prob2, simi): 104 | P = prob1.mul_(prob2) 105 | P = P.sum(1) 106 | P.mul_(simi).add_(simi.eq(-1).type_as(P)) 107 | neglogP = -P.add_(BCE.eps).log_() 108 | return neglogP.mean() 109 | 110 | 111 | class BCE_softlabels(nn.Module): 112 | """ Construct binary cross-entropy loss.""" 113 | eps = 1e-7 114 | 115 | def forward(self, prob1, prob2, simi): 116 | P = prob1.mul_(prob2) 117 | P = P.sum(1) 118 | neglogP = - (simi * torch.log(P + BCE.eps) + (1. - simi) * torch.log(1. 
- P + BCE.eps)) 119 | return neglogP.mean() 120 | 121 | 122 | def inv_lr_scheduler(param_lr, optimizer, iter_num, gamma=0.0001, 123 | power=0.75, init_lr=0.001): 124 | """Inverse-decay schedule: lr = init_lr * (1 + gamma * iter_num) ** (-power), scaled per parameter group by param_lr.""" 125 | lr = init_lr * (1 + gamma * iter_num) ** (- power) 126 | i = 0 127 | for param_group in optimizer.param_groups: 128 | param_group['lr'] = lr * param_lr[i] 129 | i += 1 130 | return optimizer 131 | 132 | 133 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0): 134 | return float(2.0 * (high - low) / 135 | (1.0 + np.exp(- alpha * iter_num / max_iter)) - 136 | (high - low) + low) 137 | 138 | 139 | @register_solver('CDAC') 140 | class CDACSolver(BaseSolver): 141 | """ 142 | Implements Cross-Domain Adaptive Clustering for Semi-Supervised Domain Adaptation: https://arxiv.org/abs/2104.09415 143 | https://github.com/lijichang/CVPR2021-SSDA 144 | """ 145 | 146 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, 147 | ada_stage, device, cfg, **kwargs): 148 | super(CDACSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, 149 | joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs) 150 | 151 | def solve(self, epoch): 152 | self.net.train() 153 | 154 | self.tgt_unsup_loader.dataset.rand_transform = rand_transform2 155 | self.tgt_unsup_loader.dataset.rand_num = 2 156 | 157 | data_iter_s = iter(self.src_loader) 158 | data_iter_t = iter(self.tgt_sup_loader) 159 | data_iter_t_unl = iter(self.tgt_unsup_loader) 160 | 161 | len_train_source = len(self.src_loader) 162 | len_train_target = len(self.tgt_sup_loader) 163 | len_train_target_semi = len(self.tgt_unsup_loader) 164 | 165 | BCE = BCE_softlabels().to(self.device) 166 | criterion = nn.CrossEntropyLoss().to(self.device) 167 | 168 | iter_per_epoch = len(self.src_loader) 169 | for batch_idx in range(iter_per_epoch): 170 | rampup = 
sigmoid_rampup(batch_idx+epoch*iter_per_epoch, 20000) 171 | w_cons = 30.0 * rampup 172 | 173 | self.tgt_opt = inv_lr_scheduler([0.1, 1.0, 1.0], self.tgt_opt, batch_idx+epoch*iter_per_epoch, 174 | init_lr=0.01) 175 | 176 | if len(self.tgt_sup_loader) > 0: 177 | if batch_idx % len_train_target == 0: 178 | data_iter_t = iter(self.tgt_sup_loader) 179 | if batch_idx % len_train_target_semi == 0: 180 | data_iter_t_unl = iter(self.tgt_unsup_loader) 181 | if batch_idx % len_train_source == 0: 182 | data_iter_s = iter(self.src_loader) 183 | data_t = next(data_iter_t) 184 | data_t_unl = next(data_iter_t_unl) 185 | data_s = next(data_iter_s) 186 | 187 | # load labeled source data 188 | x_s, target_s = data_s[0], data_s[1] 189 | im_data_s = x_s.to(self.device) 190 | gt_labels_s = target_s.to(self.device) 191 | 192 | # load labeled target data 193 | x_t, target_t = data_t[0], data_t[1] 194 | im_data_t = x_t.to(self.device) 195 | gt_labels_t = target_t.to(self.device) 196 | 197 | # load unlabeled target data 198 | x_tu, x_bar_tu, x_bar2_tu = data_t_unl[0], data_t_unl[3], data_t_unl[4] 199 | im_data_tu = x_tu.to(self.device) 200 | im_data_bar_tu = x_bar_tu.to(self.device) 201 | im_data_bar2_tu = x_bar2_tu.to(self.device) 202 | 203 | self.tgt_opt.zero_grad() 204 | # construct losses for overall labeled data 205 | data = torch.cat((im_data_s, im_data_t), 0) 206 | target = torch.cat((gt_labels_s, gt_labels_t), 0) 207 | out1 = self.net(data) 208 | ce_loss = criterion(out1, target) 209 | 210 | ce_loss.backward(retain_graph=True) 211 | self.tgt_opt.step() 212 | self.tgt_opt.zero_grad() 213 | 214 | # construct losses for unlabeled target data 215 | aac_loss, pl_loss, con_loss = get_losses_unlabeled(self.net, im_data=im_data_tu, im_data_bar=im_data_bar_tu, 216 | im_data_bar2=im_data_bar2_tu, target=None, BCE=BCE, 217 | w_cons=w_cons, device=self.device) 218 | loss = (aac_loss + pl_loss + con_loss) * self.cfg.ADA.UNSUP_WT * 10 219 | else: 220 | if batch_idx % len_train_source == 0: 221 | 
data_iter_s = iter(self.src_loader) 222 | data_s, label_s, _ = next(data_iter_s) 223 | data_s, label_s = data_s.to(self.device), label_s.to(self.device) 224 | 225 | self.tgt_opt.zero_grad() 226 | output_s = self.net(data_s) 227 | loss = nn.CrossEntropyLoss()(output_s, label_s) * self.cfg.ADA.SRC_SUP_WT 228 | 229 | loss.backward() 230 | self.tgt_opt.step() 231 | 232 | 233 | 234 | 235 | -------------------------------------------------------------------------------- /solver/MCCsolver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import copy 5 | import logging 6 | import numpy as np 7 | from .solver import BaseSolver, register_solver 8 | 9 | def Entropy(input_): 10 | bs = input_.size(0) 11 | epsilon = 1e-5 12 | entropy = -input_ * torch.log(input_ + epsilon) 13 | entropy = torch.sum(entropy, dim=1) 14 | return entropy 15 | 16 | 17 | @register_solver('MCC') 18 | class MCCSolver(BaseSolver): 19 | """ 20 | Implements MCC from Minimum Class Confusion for Versatile Domain Adaptation: https://arxiv.org/abs/1912.03699 21 | https://github.com/thuml/Versatile-Domain-Adaptation 22 | """ 23 | 24 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, 25 | ada_stage, device, cfg, **kwargs): 26 | super(MCCSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, 27 | joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs) 28 | 29 | def solve(self, epoch): 30 | src_iter = iter(self.src_loader) 31 | tgt_un_iter = iter(self.tgt_unsup_loader) 32 | tgt_s_iter = iter(self.tgt_sup_loader) 33 | iter_per_epoch = len(self.src_loader) 34 | 35 | self.net.train() 36 | 37 | for batch_idx in range(iter_per_epoch): 38 | if batch_idx % len(self.src_loader) == 0: 39 | src_iter = iter(self.src_loader) 40 | 41 | if batch_idx % len(self.tgt_unsup_loader) == 0: 42 | tgt_un_iter = 
iter(self.tgt_unsup_loader) 43 | 44 | data_s, label_s, _ = next(src_iter) 45 | data_s, label_s = data_s.to(self.device), label_s.to(self.device) 46 | 47 | self.tgt_opt.zero_grad() 48 | output_s = self.net(data_s) 49 | loss = nn.CrossEntropyLoss()(output_s, label_s) * self.cfg.ADA.SRC_SUP_WT 50 | 51 | if len(self.tgt_sup_loader) > 0: 52 | try: 53 | data_ts, label_ts, idx_ts = next(tgt_s_iter) 54 | except: 55 | tgt_s_iter = iter(self.tgt_sup_loader) 56 | data_ts, label_ts, idx_ts = next(tgt_s_iter) 57 | 58 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device) 59 | output_ts = self.net(data_ts) 60 | 61 | loss += nn.CrossEntropyLoss()(output_ts, label_ts) 62 | 63 | data_tu, label_tu, _ = next(tgt_un_iter) 64 | data_tu, label_tu = data_tu.to(self.device), label_tu.to(self.device) 65 | output_tu = self.net(data_tu) 66 | 67 | outputs_target_temp = output_tu / self.cfg.MODEL.TEMP 68 | target_softmax_out_temp = nn.Softmax(dim=1)(outputs_target_temp) 69 | target_entropy_weight = Entropy(target_softmax_out_temp).detach() 70 | target_entropy_weight = 1 + torch.exp(-target_entropy_weight) 71 | target_entropy_weight = self.cfg.DATALOADER.BATCH_SIZE * target_entropy_weight / torch.sum(target_entropy_weight) 72 | cov_matrix_t = target_softmax_out_temp.mul(target_entropy_weight.view(-1, 1)).transpose(1, 0).mm( 73 | target_softmax_out_temp) 74 | cov_matrix_t = cov_matrix_t / (torch.sum(cov_matrix_t, dim=1)+1e-12) 75 | mcc_loss = (torch.sum(cov_matrix_t) - torch.trace(cov_matrix_t)) / self.cfg.DATASET.NUM_CLASS 76 | 77 | loss += mcc_loss 78 | 79 | loss.backward() 80 | self.tgt_opt.step() 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /solver/MMEsolver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .solver import BaseSolver, register_solver 6 | 7 | @register_solver('mme') 8 | class 
MMESolver(BaseSolver): 9 | """ 10 | Implements MME from Semi-supervised Domain Adaptation via Minimax Entropy: https://arxiv.org/abs/1904.06487 11 | """ 12 | 13 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs): 14 | super(MMESolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, 15 | device, cfg, **kwargs) 16 | 17 | def solve(self, epoch): 18 | """ 19 | Semi-supervised adaptation via MME: XE on labeled source + XE on labeled target + \ 20 | adversarial ent. minimization on unlabeled target 21 | """ 22 | self.net.train() 23 | 24 | if not self.ada_stage: 25 | src_sup_wt, lambda_unsup = 1.0, 0.1 26 | else: 27 | src_sup_wt, lambda_unsup = self.cfg.ADA.SRC_SUP_WT, self.cfg.ADA.UNSUP_WT 28 | 29 | tgt_sup_iter = iter(self.tgt_sup_loader) 30 | 31 | joint_loader = zip(self.src_loader, self.tgt_unsup_loader) # pair each source batch with an unlabeled target batch 32 | for batch_idx, ((data_s, label_s, _), (data_tu, label_tu, _)) in enumerate(joint_loader): 33 | data_s, label_s = data_s.to(self.device), label_s.to(self.device) 34 | data_tu = data_tu.to(self.device) 35 | 36 | if self.ada_stage: 37 | try: 38 | data_ts, label_ts, _ = next(tgt_sup_iter) 39 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device) 40 | except: 41 | # no labeled target data 42 | try: 43 | tgt_sup_iter = iter(self.tgt_sup_loader) 44 | data_ts, label_ts, _ = next(tgt_sup_iter) 45 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device) 46 | except: 47 | data_ts, label_ts = None, None 48 | 49 | # zero gradients for optimizer 50 | self.tgt_opt.zero_grad() 51 | 52 | # extract features 53 | score_s = self.net(data_s) 54 | xeloss_src = src_sup_wt * nn.CrossEntropyLoss()(score_s, label_s) 55 | 56 | xeloss_tgt = 0 57 | if self.ada_stage and data_ts is not None: 58 | score_ts = self.net(data_ts) 59 | 
xeloss_tgt = nn.CrossEntropyLoss()(score_ts, label_ts) 60 | 61 | xeloss = xeloss_src + xeloss_tgt 62 | 63 | xeloss.backward() 64 | self.tgt_opt.step() 65 | 66 | # Add adversarial entropy 67 | self.tgt_opt.zero_grad() 68 | 69 | score_tu = self.net(data_tu, reverse_grad=True) 70 | probs_tu = F.softmax(score_tu, dim=1) 71 | loss_adent = lambda_unsup * torch.mean(torch.sum(probs_tu * (torch.log(probs_tu + 1e-5)), 1)) 72 | loss_adent.backward() 73 | 74 | self.tgt_opt.step() 75 | 76 | 77 | -------------------------------------------------------------------------------- /solver/PAAsolver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import copy 5 | import numpy as np 6 | 7 | from dataset.image_list import ImageList 8 | from dataset.transform import rand_transform 9 | from .solver import BaseSolver, register_solver 10 | 11 | 12 | @register_solver('RAA') 13 | class RAASolver(BaseSolver): 14 | """ 15 | Implement Random Anchor set Augmentation (RAA) 16 | """ 17 | 18 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, 19 | ada_stage, device, cfg, **kwargs): 20 | super(RAASolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, 21 | joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs) 22 | 23 | def solve(self, epoch, seq_query_loader): 24 | K = self.cfg.LADA.A_K 25 | th = self.cfg.LADA.A_TH 26 | 27 | # create an anchor set 28 | if len(self.tgt_sup_loader) > 0: 29 | tgt_sup_dataset = self.tgt_sup_loader.dataset 30 | tgt_sup_samples = [tgt_sup_dataset.samples[i] for i in self.tgt_sup_loader.sampler.indices] 31 | seed_dataset = ImageList(tgt_sup_samples, root=tgt_sup_dataset.root, transform=tgt_sup_dataset.transform) 32 | seed_dataset.rand_transform = rand_transform 33 | seed_dataset.rand_num = self.cfg.LADA.A_RAND_NUM 34 | seed_loader = 
torch.utils.data.DataLoader(seed_dataset, shuffle=True, 35 | batch_size=self.tgt_sup_loader.batch_size, num_workers=self.tgt_sup_loader.num_workers) 36 | seed_idxs = self.tgt_sup_loader.sampler.indices.tolist() 37 | seed_iter = iter(seed_loader) 38 | seed_labels = [seed_dataset.samples[i][1] for i in range(len(seed_dataset))] 39 | 40 | if K > 0: 41 | # build nearest neighbors 42 | self.net.eval() 43 | tgt_idxs = [] 44 | tgt_embs = [] 45 | tgt_labels = [] 46 | tgt_data = [] 47 | seq_query_loader = copy.deepcopy(seq_query_loader) 48 | seq_query_loader.dataset.transform = copy.deepcopy(self.tgt_loader.dataset.transform) 49 | with torch.no_grad(): 50 | for sample_ in seq_query_loader: 51 | sample = copy.deepcopy(sample_) 52 | del sample_ 53 | data, label, idx = sample[0], sample[1], sample[2] 54 | data, label = data.to(self.device), label.to(self.device) 55 | score, emb = self.net(data, with_emb=True) 56 | tgt_embs.append(F.normalize(emb).detach().clone().cpu()) 57 | tgt_labels.append(label.cpu()) 58 | tgt_idxs.append(idx.cpu()) 59 | tgt_data.append(data.cpu()) 60 | 61 | tgt_embs = torch.cat(tgt_embs) 62 | tgt_data = torch.cat(tgt_data) 63 | tgt_idxs = torch.cat(tgt_idxs) 64 | 65 | self.net.train() 66 | 67 | src_iter = iter(self.src_loader) 68 | iter_per_epoch = len(self.src_loader) 69 | 70 | for batch_idx in range(iter_per_epoch): 71 | if batch_idx % len(self.src_loader) == 0: 72 | src_iter = iter(self.src_loader) 73 | 74 | data_s, label_s, _ = next(src_iter) 75 | data_s, label_s = data_s.to(self.device), label_s.to(self.device) 76 | 77 | self.tgt_opt.zero_grad() 78 | output_s = self.net(data_s) 79 | loss = nn.CrossEntropyLoss()(output_s, label_s) 80 | 81 | if len(self.tgt_sup_loader) > 0: 82 | try: 83 | data_ts, label_ts, idx_ts, *data_rand_ts = next(seed_iter) 84 | except: 85 | seed_iter = iter(seed_loader) 86 | data_ts, label_ts, idx_ts, *data_rand_ts = next(seed_iter) 87 | 88 | 89 | if len(data_rand_ts)>0: 90 | for i, r_data in enumerate(data_rand_ts): 91 | alpha 
= 0.2 92 | mask = torch.FloatTensor(np.random.beta(alpha, alpha, size=(data_ts.shape[0], 1, 1, 1))) 93 | data_ts = (data_ts * mask) + (r_data * (1 - mask)) 94 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device) 95 | output_ts, emb_ts = self.net(data_ts, with_emb=True) 96 | loss += nn.CrossEntropyLoss()(output_ts, label_ts) 97 | else: 98 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device) 99 | output_ts, emb_ts = self.net(data_ts, with_emb=True) 100 | loss += nn.CrossEntropyLoss()(output_ts, label_ts) 101 | 102 | loss.backward() 103 | self.tgt_opt.step() 104 | 105 | if len(self.tgt_sup_loader) > 0 and K > 0 and len(seed_idxs) < tgt_embs.shape[0]: 106 | nn_idxs = torch.randint(0, tgt_data.shape[0], (data_ts.shape[0],)).to(self.device) 107 | 108 | data_nn = tgt_data[nn_idxs].to(self.device) 109 | 110 | with torch.no_grad(): 111 | output_nn, emb_nn = self.net(data_nn, with_emb=True) 112 | prob_nn = torch.softmax(output_nn, dim=-1) 113 | tgt_embs[nn_idxs] = F.normalize(emb_nn).detach().clone().cpu() 114 | 115 | conf_samples = [] 116 | conf_idx = [] 117 | conf_pl = [] 118 | dist = np.eye(prob_nn.shape[-1])[np.array(seed_labels)].sum(0) + 1 119 | sp = 1 - dist / dist.max() + dist.min() / dist.max() 120 | 121 | for i in range(prob_nn.shape[0]): 122 | idx = tgt_idxs[nn_idxs[i]].item() 123 | pl_i = prob_nn[i].argmax(-1).item() 124 | if np.random.random() <= sp[pl_i] and prob_nn[i].max(-1)[0] >= th and idx not in seed_idxs: 125 | conf_samples.append((self.tgt_loader.dataset.samples[idx][0], pl_i)) 126 | conf_idx.append(idx) 127 | conf_pl.append(pl_i) 128 | 129 | seed_dataset.add_item(conf_samples) 130 | seed_idxs.extend(conf_idx) 131 | seed_labels.extend(conf_pl) 132 | 133 | 134 | @register_solver('LAA') 135 | class LAASolver(BaseSolver): 136 | """ 137 | Local context-aware Anchor set Augmentation (LAA) 138 | """ 139 | 140 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, 141 
| ada_stage, device, cfg, **kwargs): 142 | super(LAASolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, 143 | joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs) 144 | 145 | def solve(self, epoch, seq_query_loader): 146 | K = self.cfg.LADA.A_K 147 | th = self.cfg.LADA.A_TH 148 | 149 | # create an anchor set 150 | if len(self.tgt_sup_loader) > 0: 151 | tgt_sup_dataset = self.tgt_sup_loader.dataset 152 | tgt_sup_samples = [tgt_sup_dataset.samples[i] for i in self.tgt_sup_loader.sampler.indices] 153 | seed_dataset = ImageList(tgt_sup_samples, root=tgt_sup_dataset.root, transform=tgt_sup_dataset.transform) 154 | seed_dataset.rand_transform = rand_transform 155 | seed_dataset.rand_num = self.cfg.LADA.A_RAND_NUM 156 | seed_loader = torch.utils.data.DataLoader(seed_dataset, shuffle=True, 157 | batch_size=self.tgt_sup_loader.batch_size, num_workers=self.tgt_sup_loader.num_workers) 158 | seed_idxs = self.tgt_sup_loader.sampler.indices.tolist() 159 | seed_iter = iter(seed_loader) 160 | seed_labels = [seed_dataset.samples[i][1] for i in range(len(seed_dataset))] 161 | 162 | if K > 0: 163 | # build nearest neighbors 164 | self.net.eval() 165 | tgt_idxs = [] 166 | tgt_embs = [] 167 | tgt_labels = [] 168 | tgt_data = [] 169 | seq_query_loader = copy.deepcopy(seq_query_loader) 170 | seq_query_loader.dataset.transform = copy.deepcopy(self.tgt_loader.dataset.transform) 171 | with torch.no_grad(): 172 | for sample_ in seq_query_loader: 173 | sample = copy.deepcopy(sample_) 174 | del sample_ 175 | data, label, idx = sample[0], sample[1], sample[2] 176 | data, label = data.to(self.device), label.to(self.device) 177 | score, emb = self.net(data, with_emb=True) 178 | tgt_embs.append(F.normalize(emb).detach().clone().cpu()) 179 | tgt_labels.append(label.cpu()) 180 | tgt_idxs.append(idx.cpu()) 181 | tgt_data.append(data.cpu()) 182 | 183 | tgt_embs = torch.cat(tgt_embs) 184 | tgt_data = torch.cat(tgt_data) 185 | tgt_idxs = torch.cat(tgt_idxs) 186 
| 187 | self.net.train() 188 | 189 | src_iter = iter(self.src_loader) 190 | iter_per_epoch = len(self.src_loader) 191 | 192 | for batch_idx in range(iter_per_epoch): 193 | if batch_idx % len(self.src_loader) == 0: 194 | src_iter = iter(self.src_loader) 195 | 196 | data_s, label_s, _ = next(src_iter) 197 | data_s, label_s = data_s.to(self.device), label_s.to(self.device) 198 | 199 | self.tgt_opt.zero_grad() 200 | output_s = self.net(data_s) 201 | loss = nn.CrossEntropyLoss()(output_s, label_s) 202 | 203 | if len(self.tgt_sup_loader) > 0: 204 | try: 205 | data_ts, label_ts, idx_ts, *data_rand_ts = next(seed_iter) 206 | except: 207 | seed_iter = iter(seed_loader) 208 | data_ts, label_ts, idx_ts, *data_rand_ts = next(seed_iter) 209 | 210 | if len(data_rand_ts) > 0: 211 | for i, r_data in enumerate(data_rand_ts): 212 | alpha = 0.2 213 | mask = torch.FloatTensor(np.random.beta(alpha, alpha, size=(data_ts.shape[0], 1, 1, 1))) 214 | data_ts = (data_ts * mask) + (r_data * (1 - mask)) 215 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device) 216 | output_ts, emb_ts = self.net(data_ts, with_emb=True) 217 | loss += nn.CrossEntropyLoss()(output_ts, label_ts) 218 | else: 219 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device) 220 | output_ts, emb_ts = self.net(data_ts, with_emb=True) 221 | loss += nn.CrossEntropyLoss()(output_ts, label_ts) 222 | 223 | loss.backward() 224 | self.tgt_opt.step() 225 | 226 | if len(self.tgt_sup_loader) > 0 and K > 0 and len(seed_idxs) < tgt_embs.shape[0]: 227 | mask = torch.ones(tgt_embs.shape[0]) 228 | re_idxs = tgt_idxs[mask == 1] 229 | 230 | sim = F.normalize(emb_ts.cpu()).mm(tgt_embs[re_idxs].transpose(1, 0)) 231 | sim_topk, topk = torch.topk(sim, k=K, dim=1) 232 | 233 | rand_nn = torch.randint(0, topk.shape[1], (topk.shape[0], 1)) 234 | nn_idxs = torch.gather(topk, dim=-1, index=rand_nn).squeeze(1) 235 | nn_idxs = re_idxs[nn_idxs] 236 | 237 | data_nn = tgt_data[nn_idxs].to(self.device) 238 | 239 | with 
torch.no_grad(): 240 | output_nn, emb_nn = self.net(data_nn, with_emb=True) 241 | prob_nn = torch.softmax(output_nn, dim=-1) 242 | tgt_embs[nn_idxs] = F.normalize(emb_nn).detach().clone().cpu() 243 | 244 | conf_samples = [] 245 | conf_idx = [] 246 | conf_pl = [] 247 | dist = np.eye(prob_nn.shape[-1])[np.array(seed_labels)].sum(0) + 1 248 | dist = dist / dist.max() 249 | sp = 1 - dist / dist.max() + dist.min() / dist.max() 250 | 251 | for i in range(prob_nn.shape[0]): 252 | idx = tgt_idxs[nn_idxs[i]].item() 253 | pl_i = prob_nn[i].argmax(-1).item() 254 | if np.random.random() <= sp[pl_i] and prob_nn[i].max(-1)[0] >= th and idx not in seed_idxs: 255 | conf_samples.append((self.tgt_loader.dataset.samples[idx][0], pl_i)) 256 | conf_idx.append(idx) 257 | conf_pl.append(pl_i) 258 | 259 | seed_dataset.add_item(conf_samples) 260 | seed_idxs.extend(conf_idx) 261 | seed_labels.extend(conf_pl) 262 | 263 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .solver import * 2 | from .MMEsolver import * 3 | from .PAAsolver import * 4 | from .MCCsolver import * 5 | from .CDACsolver import * -------------------------------------------------------------------------------- /solver/solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .utils import ConditionalEntropyLoss 5 | from model.grl import GradientReverseFunction 6 | 7 | 8 | solvers = {} 9 | def register_solver(name): 10 | def decorator(cls): 11 | solvers[name] = cls 12 | return cls 13 | return decorator 14 | 15 | def get_solver(name, *args, kwargs={}): 16 | solver = solvers[name](*args, **kwargs) 17 | return solver 18 | 19 | class BaseSolver: 20 | """ 21 | Base DA solver class 22 | """ 23 | 24 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, 
tgt_opt, ada_stage, device, cfg): 25 | self.net = net 26 | self.src_loader = src_loader 27 | self.tgt_loader = tgt_loader 28 | self.tgt_sup_loader = tgt_sup_loader 29 | self.tgt_unsup_loader = tgt_unsup_loader 30 | self.joint_sup_loader = joint_sup_loader 31 | self.tgt_opt = tgt_opt 32 | self.ada_stage = ada_stage 33 | self.device = device 34 | self.cfg = cfg 35 | 36 | def solve(self, epoch): 37 | pass 38 | 39 | @register_solver('ft_joint') 40 | class JointFTSolver(BaseSolver): 41 | """ 42 | Finetune jointly on source and target labels 43 | """ 44 | 45 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs): 46 | super(JointFTSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, 47 | ada_stage, device, cfg, **kwargs) 48 | 49 | def solve(self, epoch): 50 | """ 51 | Finetune on source and target labels jointly 52 | """ 53 | self.net.train() 54 | joint_sup_iter = iter(self.joint_sup_loader) 55 | 56 | while True: 57 | try: 58 | data, target, _ = next(joint_sup_iter) 59 | data, target = data.to(self.device), target.to(self.device) 60 | except: 61 | break 62 | 63 | self.tgt_opt.zero_grad() 64 | output = self.net(data) 65 | loss = nn.CrossEntropyLoss()(output, target) 66 | 67 | loss.backward() 68 | self.tgt_opt.step() 69 | 70 | 71 | @register_solver('ft_tgt') 72 | class TargetFTSolver(BaseSolver): 73 | """ 74 | Finetune on target labels 75 | """ 76 | 77 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs): 78 | super(TargetFTSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, 79 | ada_stage, device, cfg, **kwargs) 80 | 81 | def solve(self, epoch): 82 | """ 83 | Finetune on target labels 84 | """ 85 | self.net.train() 86 | if self.ada_stage: tgt_sup_iter = iter(self.tgt_sup_loader) 87 | 88 |
while True: 89 | try: 90 | data_t, target_t, _ = next(tgt_sup_iter) 91 | data_t, target_t = data_t.to(self.device), target_t.to(self.device) 92 | except: 93 | break 94 | 95 | self.tgt_opt.zero_grad() 96 | output = self.net(data_t) 97 | loss = nn.CrossEntropyLoss()(output, target_t) 98 | loss.backward() 99 | self.tgt_opt.step() 100 | 101 | 102 | @register_solver('ft') 103 | class FTSolver(BaseSolver): 104 | """ 105 | Finetune on source and target labels with separate loaders 106 | """ 107 | 108 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs): 109 | super(FTSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, 110 | device, cfg, **kwargs) 111 | 112 | def solve(self, epoch): 113 | self.net.train() 114 | 115 | if self.ada_stage: 116 | src_sup_wt = 1.0 117 | else: 118 | src_sup_wt = self.cfg.ADA.SRC_SUP_WT 119 | 120 | tgt_sup_wt = self.cfg.ADA.TGT_SUP_WT 121 | 122 | tgt_sup_iter = iter(self.tgt_sup_loader) 123 | 124 | for batch_idx, (data_s, label_s, _) in enumerate(self.src_loader): 125 | data_s, label_s = data_s.to(self.device), label_s.to(self.device) 126 | 127 | if self.ada_stage: 128 | try: 129 | data_ts, label_ts, _ = next(tgt_sup_iter) 130 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device) 131 | except: 132 | # no labeled target data 133 | try: 134 | tgt_sup_iter = iter(self.tgt_sup_loader) 135 | data_ts, label_ts, _ = next(tgt_sup_iter) 136 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device) 137 | except: 138 | data_ts, label_ts = None, None 139 | 140 | # zero gradients for optimizer 141 | self.tgt_opt.zero_grad() 142 | 143 | # extract features 144 | score_s = self.net(data_s) 145 | xeloss_src = src_sup_wt * nn.CrossEntropyLoss()(score_s, label_s) 146 | 147 | xeloss_tgt = 0 148 | if self.ada_stage and data_ts is not None: 149 | score_ts = self.net(data_ts) 
                xeloss_tgt = tgt_sup_wt * nn.CrossEntropyLoss()(score_ts, label_ts)

            xeloss = xeloss_src + xeloss_tgt
            xeloss.backward()
            self.tgt_opt.step()


@register_solver('dann')
class DANNSolver(BaseSolver):
    """
    Implements DANN from Unsupervised Domain Adaptation by Backpropagation: https://arxiv.org/abs/1409.7495
    """

    def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs):
        super(DANNSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt,
                                         ada_stage, device, cfg, **kwargs)

    def solve(self, epoch, disc, disc_opt):
        """
        Semi-supervised adaptation via DANN: cross-entropy on labeled source + cross-entropy on labeled target
        + entropy minimization on unlabeled target + adversarial domain alignment (source <-> target)
        """
        gan_criterion = nn.CrossEntropyLoss()
        cent = ConditionalEntropyLoss().to(self.device)

        self.net.train()
        disc.train()

        if not self.ada_stage:
            src_sup_wt, lambda_unsup, lambda_cent = 1.0, 0.1, 0.01  # hardcoded for unsupervised DA
        else:
            src_sup_wt, lambda_unsup, lambda_cent = self.cfg.ADA.SRC_SUP_WT, self.cfg.ADA.UNSUP_WT, self.cfg.ADA.CEN_WT
            tgt_sup_iter = iter(self.tgt_sup_loader)

        joint_loader = zip(self.src_loader, self.tgt_loader)  # changed to tgt_loader to be consistent with CLUE implementation
        for batch_idx, ((data_s, label_s, _), (data_tu, label_tu, _)) in enumerate(joint_loader):
            data_s, label_s = data_s.to(self.device), label_s.to(self.device)
            data_tu = data_tu.to(self.device)

            if self.ada_stage:
                try:
                    data_ts, label_ts, _ = next(tgt_sup_iter)
                except StopIteration:
                    # labeled target iterator exhausted: restart it
                    tgt_sup_iter = iter(self.tgt_sup_loader)
                    try:
                        data_ts, label_ts, _ = next(tgt_sup_iter)
                    except StopIteration:
                        # no labeled target data at all
                        data_ts, label_ts = None, None
                if data_ts is not None:
                    data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)

            # zero gradients for optimizers
            self.tgt_opt.zero_grad()
            disc_opt.zero_grad()

            # supervised loss on labeled source data
            score_s = self.net(data_s)
            xeloss_src = src_sup_wt * nn.CrossEntropyLoss()(score_s, label_s)

            # supervised loss on queried (labeled) target data, if any
            xeloss_tgt = 0
            if self.ada_stage and data_ts is not None:
                score_ts = self.net(data_ts)
                xeloss_tgt = nn.CrossEntropyLoss()(score_ts, label_ts)

            # concatenate the outputs of self.net on source and unlabeled target data
            score_tu = self.net(data_tu)
            f = torch.cat((score_s, score_tu), 0)

            # predict the domain with the discriminator behind a gradient reversal layer
            f_rev = GradientReverseFunction.apply(f)
            pred_concat = disc(f_rev)

            # domain labels: 1 = source, 0 = target
            target_dom_s = torch.ones(len(data_s), dtype=torch.long, device=self.device)
            target_dom_t = torch.zeros(len(data_tu), dtype=torch.long, device=self.device)
            label_concat = torch.cat((target_dom_s, target_dom_t), 0)

            # compute loss for discriminator
            loss_domain = gan_criterion(pred_concat, label_concat)
            loss_cent = cent(score_tu)

            loss_final = (xeloss_src + xeloss_tgt) + (lambda_unsup * loss_domain) + (lambda_cent * loss_cent)

            loss_final.backward()

            self.tgt_opt.step()
            disc_opt.step()


--------------------------------------------------------------------------------
/solver/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

class ConditionalEntropyLoss(torch.nn.Module):
    """
    Conditional entropy loss: the mean Shannon entropy of the softmax predictions,
    -(1/N) * sum_i sum_c p_ic * log(p_ic). Minimizing it encourages confident predictions.
    """
    def __init__(self):
        super(ConditionalEntropyLoss, self).__init__()

    def forward(self, x):
        b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
        b = b.sum(dim=1)
        return -1.0 * b.mean(dim=0)
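As a quick sanity check on `ConditionalEntropyLoss`, the sketch below re-implements the same formula as a standalone function (assuming only PyTorch; the name `conditional_entropy` is illustrative and not part of the repo). Uniform logits over C classes should give the maximum entropy log C, while a near one-hot prediction should give a value close to zero. This is why the `lambda_cent * loss_cent` term in `DANNSolver` pushes unlabeled target predictions toward confident ones.

```python
import math

import torch
import torch.nn.functional as F


def conditional_entropy(logits: torch.Tensor) -> torch.Tensor:
    """Mean Shannon entropy of softmax(logits) over the batch (same formula as ConditionalEntropyLoss.forward)."""
    p_log_p = F.softmax(logits, dim=1) * F.log_softmax(logits, dim=1)
    return -p_log_p.sum(dim=1).mean(dim=0)


num_classes = 4
uniform = torch.zeros(2, num_classes)              # equal logits -> uniform prediction
confident = torch.tensor([[20.0, 0.0, 0.0, 0.0]])  # near one-hot prediction

# The uniform case should match log(C); the confident case should be close to 0.
print(conditional_entropy(uniform).item(), math.log(num_classes))
print(conditional_entropy(confident).item())
```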
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
import os
import os.path as osp
import logging
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.INFO)

logger_init = False

def init_logger(_log_file, dir='log/'):
    """
    Reset the root logger so that it logs DEBUG-level messages to the console and,
    if `_log_file` is given, also to `<dir>/<_log_file>.log`.
    """
    logger = logging.getLogger()
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    logger.setLevel('DEBUG')
    BASIC_FORMAT = "%(asctime)s:%(levelname)s:%(message)s"
    DATE_FORMAT = '%Y-%m-%d %H.%M.%S'
    formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT)
    chlr = logging.StreamHandler()
    chlr.setFormatter(formatter)
    logger.addHandler(chlr)

    if _log_file is not None:
        os.makedirs(dir, exist_ok=True)
        log_file = osp.join(dir, _log_file + '.log')
        fhlr = logging.FileHandler(log_file)
        fhlr.setFormatter(formatter)
        logger.addHandler(fhlr)

    global logger_init
    logger_init = True
--------------------------------------------------------------------------------
/utils/lr_schedule.py:
--------------------------------------------------------------------------------

def inv_lr_scheduler(param_lr, optimizer, iter_num, gamma=0.0001,
                     power=0.75, init_lr=0.001):
    """
    Inverse-decay schedule: lr = init_lr * (1 + gamma * iter_num) ** (-power),
    scaled for each parameter group i by the multiplier param_lr[i].
    """
    lr = init_lr * (1 + gamma * iter_num) ** (- power)
    for i, param_group in enumerate(optimizer.param_groups):
        param_group['lr'] = lr * param_lr[i]
    return optimizer

--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import random
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import os
# from model import get_model
import torch.optim as optim
# from solver import get_solver
import logging

def resetRNGseed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def train(model, device, train_loader, optimizer, epoch):
    """
    Train model on provided data for a single epoch; returns the average per-sample loss
    """
    model.train()
    total_loss, correct = 0.0, 0
    for batch_idx, (data, target, _) in enumerate(tqdm(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = nn.CrossEntropyLoss()(output, target)
        total_loss += loss.item() * data.size(0)  # batch-mean loss -> batch-sum loss
        pred = output.argmax(dim=1, keepdim=True)  # index of the max logit (predicted class)
        corr = pred.eq(target.view_as(pred)).sum().item()
        correct += corr
        loss.backward()
        optimizer.step()

    train_acc = 100. * correct / len(train_loader.sampler)
    avg_loss = total_loss / len(train_loader.sampler)
    logging.info('Train Epoch: {} | Avg. Loss: {:.3f} | Train Acc: {:.3f}'.format(epoch, avg_loss, train_acc))
    return avg_loss

def test(model, device, test_loader, split="target test"):
    """
    Evaluate model on provided data; returns (accuracy in %, average per-sample loss)
    """
    # logging.info('Evaluating model on {}...'.format(split))
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target, _ in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = nn.CrossEntropyLoss(reduction='sum')(output, target)
            test_loss += loss.item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # index of the max logit (predicted class)
            corr = pred.eq(target.view_as(pred)).sum().item()
            correct += corr
            del loss, output

    test_loss /= len(test_loader.sampler)
    test_acc = 100. * correct / len(test_loader.sampler)

    return test_acc, test_loss


# def run_unsupervised_da(model, src_train_loader, tgt_sup_loader, tgt_unsup_loader, train_idx, num_classes, device,
#                         cfg):
#     """
#     Unsupervised adaptation of source model to target at round 0
#     Returns:
#         Model post adaptation
#     """
#     source = cfg.DATASET.SOURCE_DOMAIN
#     target = cfg.DATASET.TARGET_DOMAIN
#     da_strat = cfg.ADA.DA
#
#     adapt_dir = os.path.join('checkpoints', 'adapt')
#     adapt_net_file = os.path.join(adapt_dir, '{}_{}_{}_{}.pth'.format(da_strat, source, target, cfg.MODEL.BACKBONE.NAME))
#
#     if not os.path.exists(adapt_dir):
#         os.makedirs(adapt_dir)
#
#     if os.path.exists(adapt_net_file):
#         logging.info('Found pretrained checkpoint, loading...')
#         adapt_model = get_model('AdaptNet', num_cls=num_classes, weights_init=adapt_net_file, model=cfg.MODEL.BACKBONE.NAME)
#     else:
#         logging.info('No pretrained checkpoint found, training...')
#         source_file = '{}_{}_source.pth'.format(source, cfg.MODEL.BACKBONE.NAME)
#         source_path = os.path.join('checkpoints', 'source', source_file)
#         adapt_model = get_model('AdaptNet', num_cls=num_classes, src_weights_init=source_path, model=cfg.MODEL.BACKBONE.NAME)
#         opt_net_tgt = optim.Adadelta(adapt_model.tgt_net.parameters(cfg.OPTIM.UDA_LR, cfg.OPTIM.BASE_LR_MULT), lr=cfg.OPTIM.UDA_LR, weight_decay=0.00001)
#         uda_solver = get_solver(da_strat, adapt_model.tgt_net, src_train_loader, tgt_sup_loader, tgt_unsup_loader,
#                                 train_idx, opt_net_tgt, 0, device, cfg)
#         for epoch in range(cfg.TRAINER.MAX_UDA_EPOCHS):
#             if da_strat == 'dann':
#                 opt_dis_adapt = optim.Adadelta(adapt_model.discriminator.parameters(), lr=cfg.OPTIM.UDA_LR, weight_decay=0.00001)
#                 uda_solver.solve(epoch, adapt_model.discriminator, opt_dis_adapt)
#             elif da_strat in ['mme', 'ft']:
#                 uda_solver.solve(epoch)
#         adapt_model.save(adapt_net_file)
#
#     model, src_model, discriminator = adapt_model.tgt_net, adapt_model.src_net, adapt_model.discriminator
#     return model, src_model, discriminator


def get_optim(name, *args, **kwargs):
    if name == 'Adadelta':
        return optim.Adadelta(*args, **kwargs)
    elif name == 'Adam':
        return optim.Adam(*args, **kwargs)
    elif name == 'SGD':
        return optim.SGD(*args, **kwargs, momentum=0.9, nesterov=True)
    raise ValueError('Unsupported optimizer: {}'.format(name))
--------------------------------------------------------------------------------
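The inverse-decay schedule in `inv_lr_scheduler` (`utils/lr_schedule.py`) can be traced by hand. The sketch below is a standalone re-implementation assuming only PyTorch; the two-group toy model and the multipliers `param_lr = [1.0, 10.0]` (a smaller rate for a backbone, a larger one for a new classifier head) are hypothetical choices for illustration, not values taken from the repo's configs. With the default `gamma=0.0001` and `power=0.75`, the base rate halves when (1 + gamma * iter_num)^0.75 = 2, i.e. near iteration 15,200.

```python
from torch import nn, optim


def inv_lr(iter_num, gamma=0.0001, power=0.75, init_lr=0.001):
    """Base learning rate of the inverse-decay schedule used by inv_lr_scheduler."""
    return init_lr * (1 + gamma * iter_num) ** (-power)


# Hypothetical toy model with two parameter groups: backbone and classifier head.
backbone, head = nn.Linear(8, 8), nn.Linear(8, 2)
optimizer = optim.SGD(
    [{"params": backbone.parameters()}, {"params": head.parameters()}],
    lr=0.001, momentum=0.9, nesterov=True,
)
param_lr = [1.0, 10.0]  # per-group multipliers applied on top of the base rate

for iter_num in (0, 10_000):
    lr = inv_lr(iter_num)
    # Same update rule as inv_lr_scheduler: group i gets lr * param_lr[i].
    for i, group in enumerate(optimizer.param_groups):
        group["lr"] = lr * param_lr[i]
    print(iter_num, [group["lr"] for group in optimizer.param_groups])
```

Because the rate is recomputed from `init_lr` and `iter_num` on every call, the schedule is stateless: calling it with the same iteration count always yields the same per-group rates.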