├── LICENSE
├── README.md
├── active
│   ├── LASsampler.py
│   ├── MHPsampler.py
│   ├── __init__.py
│   ├── budget.py
│   ├── sampler.py
│   └── utils.py
├── config
│   ├── __init__.py
│   └── defaults.py
├── configs
│   ├── office31.yaml
│   ├── officehome.yaml
│   ├── officehome_RSUT.yaml
│   └── visda.yaml
├── data
│   └── image_list
│       ├── domainnet
│       │   ├── clipart_test.txt
│       │   ├── clipart_train.txt
│       │   ├── infograph_test.txt
│       │   ├── infograph_train.txt
│       │   ├── painting_test.txt
│       │   ├── painting_train.txt
│       │   ├── quickdraw_test.txt
│       │   ├── quickdraw_train.txt
│       │   ├── real_test.txt
│       │   ├── real_train.txt
│       │   ├── sketch_test.txt
│       │   └── sketch_train.txt
│       ├── office31
│       │   ├── amazon.txt
│       │   ├── dslr.txt
│       │   └── webcam.txt
│       ├── officehome
│       │   ├── Art.txt
│       │   ├── Clipart.txt
│       │   ├── Product.txt
│       │   └── RealWorld.txt
│       ├── officehome_RSUT
│       │   ├── Clipart_RS.txt
│       │   ├── Clipart_UT.txt
│       │   ├── Product_RS.txt
│       │   ├── Product_UT.txt
│       │   ├── RealWorld_RS.txt
│       │   └── RealWorld_UT.txt
│       └── visda
│           ├── train.txt
│           └── validation.txt
├── dataset
│   ├── ASDADataset.py
│   ├── image_list.py
│   ├── randaugment.py
│   ├── randaugmentMC.py
│   └── transform.py
├── fig
│   └── framework.png
├── main.py
├── model
│   ├── __init__.py
│   ├── adaptor.py
│   ├── grl.py
│   ├── models.py
│   └── network.py
├── run.sh
├── solver
│   ├── CDACsolver.py
│   ├── MCCsolver.py
│   ├── MMEsolver.py
│   ├── PAAsolver.py
│   ├── __init__.py
│   ├── solver.py
│   └── utils.py
└── utils
    ├── logger.py
    ├── lr_schedule.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Tao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Local Context-Aware Active Domain Adaptation
2 |
3 | PyTorch implementation of LADA.
4 | > [Local Context-Aware Active Domain Adaptation](https://arxiv.org/abs/2208.12856)
5 | > Tao Sun, Cheng Lu, and Haibin Ling
6 | > *ICCV 2023*
7 | >
8 | ## Abstract
9 | Active Domain Adaptation (ADA) queries the labels of a small number of selected target samples to help adapt a model from a source domain to a target domain. The local context of the queried data is important, especially when the domain gap is large; however, this has not been fully explored by existing ADA works.
10 |
11 | In this paper, we propose a Local context-aware ADA framework, named LADA, to address this issue. To select informative target samples, we devise a novel criterion based on the local inconsistency of model predictions. Since the labeling budget is usually small, fine-tuning the model on only the queried data can be inefficient; we therefore progressively augment the labeled target data with their confident neighbors in a class-balanced manner.
12 |
13 | Experiments validate that the proposed criterion chooses more informative target samples than existing active selection strategies. Furthermore, our full method surpasses recent ADA methods on various benchmarks.
14 |
15 |
16 |
17 |
18 |
19 | ## Usage
20 | ### Prerequisites
21 | We experimented with python==3.8, pytorch==1.8.0, cudatoolkit==11.1.
22 |
23 | To start, download the [Office-31](https://faculty.cc.gatech.edu/~judy/domainadapt/), [Office-Home](https://www.hemanthdv.org/officeHomeDataset.html), and [VisDA](https://ai.bu.edu/visda-2017/) datasets and set up their paths in the ./data folder.
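
The image lists under *data/image_list* reference images by paths relative to `DATASET.ROOT` (default `data/`), e.g. `office31/dslr/images/calculator/frame_0001.jpg`. A layout along the following lines is therefore assumed (a sketch inferred from the list files; the Office-Home and VisDA folders follow the same pattern, with the expected prefixes readable from their list files):

```
data/
├── image_list/                    # the .txt lists provided in this repo (relative path + label per line)
└── office31/
    └── dslr/
        └── images/
            └── calculator/
                └── frame_0001.jpg # first entry of image_list/office31/dslr.txt
```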
24 |
25 | ### Supported methods
26 | | Active Criteria | Paper | Implementation |
27 | |-----------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:----------------------------------------:|
28 | | Random | - | [random](active/sampler.py) |
29 | | Entropy | - | [entropy](active/sampler.py) |
30 | | Margin | - | [margin](active/sampler.py) |
31 | | LeastConfidence | - | [leastConfidence](active/sampler.py) |
32 | | CoreSet | [ICLR 2018](https://openreview.net/pdf?id=H1aIuk-RW) | [coreset](active/sampler.py) |
33 | | AADA | [WACV 2020](https://openaccess.thecvf.com/content_WACV_2020/papers/Su_Active_Adversarial_Domain_Adaptation_WACV_2020_paper.pdf) | [AADA](active/sampler.py) |
34 | | BADGE | [ICLR 2020](https://openreview.net/pdf?id=ryghZJBKPS) | [BADGE](active/sampler.py) |
35 | | CLUE | [ICCV 2021](https://openaccess.thecvf.com/content/ICCV2021/papers/Prabhu_Active_Domain_Adaptation_via_Clustering_Uncertainty-Weighted_Embeddings_ICCV_2021_paper.pdf) | [CLUE](active/sampler.py) |
36 | | MHP | [CVPR 2023](https://openaccess.thecvf.com/content/CVPR2023/papers/Wang_MHPL_Minimum_Happy_Points_Learning_for_Active_Source_Free_Domain_CVPR_2023_paper.pdf) | [MHP](active/MHPsampler.py) |
37 | | LAS (ours) | [ICCV 2023](https://arxiv.org/abs/2208.12856) | [LAS](active/LASsampler.py) |
38 |
39 |
40 | | Domain Adaptation | Paper | Implementation |
41 | |-------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------:|:----------------------------:|
42 | | Fine-tuning (joint label set) | - | [ft_joint](solver/solver.py) |
43 | | Fine-tuning | - | [ft](solver/solver.py) |
44 | | DANN | [JMLR 2016](https://jmlr.org/papers/volume17/15-239/15-239.pdf) | [dann](solver/solver.py) |
45 | | MME | [ICCV 2019](https://openaccess.thecvf.com/content_ICCV_2019/papers/Saito_Semi-Supervised_Domain_Adaptation_via_Minimax_Entropy_ICCV_2019_paper.pdf) | [mme](solver/MMEsolver.py) |
46 | | MCC | [ECCV 2020](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123660460.pdf) | [MCC](solver/MCCsolver.py) |
47 | | CDAC | [CVPR 2021](https://openaccess.thecvf.com/content/CVPR2021/papers/Li_Cross-Domain_Adaptive_Clustering_for_Semi-Supervised_Domain_Adaptation_CVPR_2021_paper.pdf) | [CDAC](solver/CDACsolver.py) |
48 | | RAA (ours) | [ICCV 2023](https://arxiv.org/abs/2208.12856) | [RAA](solver/PAAsolver.py) |
49 | | LAA (ours) | [ICCV 2023](https://arxiv.org/abs/2208.12856) | [LAA](solver/PAAsolver.py) |
50 |
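
How the selection criteria and adaptation methods above fit together in one run is sketched below. This is a hypothetical condensation based on the `get_strategy` / `query` / `update` / `train` interface in *active/sampler.py* and the `BudgetAllocator` in *active/budget.py*, not a verbatim excerpt of *main.py*; the dataset, model, and config objects are assumed to be built elsewhere.

```python
# Hypothetical sketch of an active-DA run (interfaces from active/sampler.py and active/budget.py;
# src_dset, tgt_dset, source_model, device, cfg and total_budget are assumed to exist).
import numpy as np
from active.sampler import get_strategy
from active.budget import BudgetAllocator

sampler = get_strategy(cfg.ADA.AL, src_dset, tgt_dset, source_model, device,
                       cfg.DATASET.NUM_CLASS, cfg)
budget = BudgetAllocator(total_budget, cfg)              # splits the budget over ADA.ROUNDS
idxs_lb = np.zeros(len(tgt_dset.train_idx), dtype=bool)  # no target samples labeled yet

for epoch in range(cfg.TRAINER.MAX_EPOCHS):
    curr_budget, _ = budget.get_budget(epoch)
    if curr_budget > 0:                                      # an active query round
        queried = sampler.query(n=curr_budget, epoch=epoch)  # indices into tgt_dset.train_idx
        idxs_lb[queried] = True
        sampler.update(idxs_lb)
    sampler.train(epoch)                                     # runs the chosen ADA.DA solver
```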
51 |
52 |
53 |
54 | ### Training
55 | To obtain results of the baseline active selection criteria on Office-Home with a 5% labeling budget, run:
56 | ```shell
57 | for ADA_DA in 'ft' 'mme'; do
58 | for ADA_AL in 'random' 'entropy' 'margin' 'coreset' 'leastConfidence' 'BADGE' 'AADA' 'CLUE' 'MHP'; do
59 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/baseline ADA.AL $ADA_AL ADA.DA $ADA_DA
60 | done
61 | done
62 | ```
63 |
64 | To reproduce the results of LADA on Office-Home with a 5% labeling budget, run:
65 | ```shell
66 | # LAS + fine-tuning with CE loss
67 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA ft
68 | # LAS + MME model adaptation
69 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA mme
70 | # LAS + Random Anchor set Augmentation (RAA)
71 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA RAA
72 | # LAS + Local context-aware Anchor set Augmentation (LAA)
73 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA LAA
74 | ```
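
The trailing `KEY VALUE` pairs (e.g. `ADA.AL LAS ADA.DA mme`) override the yacs defaults defined in *config/defaults.py*. A minimal sketch of that mechanism, assuming *main.py* merges options in the standard yacs way:

```python
# Minimal sketch of the config override mechanism (assumption: main.py uses the standard
# yacs merge_from_file / merge_from_list pattern; only the keys below come from this repo).
from config import get_cfg_default

cfg = get_cfg_default()
cfg.merge_from_file('configs/officehome.yaml')            # dataset / trainer settings
cfg.merge_from_list(['ADA.AL', 'LAS', 'ADA.DA', 'mme'])   # command-line KEY VALUE overrides
cfg.freeze()

print(cfg.ADA.AL, cfg.ADA.DA)                             # -> LAS mme
```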
75 |
76 | More commands can be found in *run.sh*.
77 |
78 | ## Acknowledgements
79 | The pipeline and the implementations of the baseline methods are adapted from [CLUE](https://github.com/virajprabhu/CLUE) and [deep-active-learning](https://github.com/ej0cl6/deep-active-learning). The configuration files follow [EADA](https://github.com/BIT-DA/EADA).
80 |
81 |
82 | ## Citation
83 | If you find our paper and code useful for your research, please consider citing
84 | ```bibtex
85 | @inproceedings{sun2022local,
86 |   author = {Sun, Tao and Lu, Cheng and Ling, Haibin},
87 |   title = {Local Context-Aware Active Domain Adaptation},
88 |   booktitle = {IEEE/CVF International Conference on Computer Vision},
89 |   year = {2023}
90 | }
91 | ```
--------------------------------------------------------------------------------
/active/LASsampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 | from sklearn.cluster import KMeans
5 | from sklearn.metrics.pairwise import euclidean_distances
6 |
7 | from .utils import ActualSequentialSampler
8 | from .sampler import register_strategy, SamplingStrategy
9 |
10 | @register_strategy('LAS')
11 | class LASSampling(SamplingStrategy):
12 | '''
13 | Implement Local context-aware sampling (LAS)
14 | '''
15 |
16 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg):
17 | super(LASSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg)
18 |
19 | def query(self, n, epoch):
20 | idxs_unlabeled = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb]
21 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx[idxs_unlabeled])
22 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler,
23 | num_workers=self.cfg.DATALOADER.NUM_WORKERS,
24 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False)
25 | # build nearest neighbors
26 | self.model.eval()
27 | all_probs = []
28 | all_embs = []
29 | with torch.no_grad():
30 | for batch_idx, (data, target, _, *_) in enumerate(data_loader):
31 | data, target = data.to(self.device), target.to(self.device)
32 | scores, embs = self.model(data, with_emb=True)
33 | all_embs.append(embs.cpu())
34 | probs = F.softmax(scores, dim=-1)
35 | all_probs.append(probs.cpu())
36 |
37 | all_probs = torch.cat(all_probs)
38 | all_embs = F.normalize(torch.cat(all_embs), dim=-1)
39 |
40 | # get Q_score
41 | sim = all_embs.cpu().mm(all_embs.transpose(1, 0))
42 | K = self.cfg.LADA.S_K
43 | sim_topk, topk = torch.topk(sim, k=K + 1, dim=1)
44 | sim_topk, topk = sim_topk[:, 1:], topk[:, 1:]
45 | wgt_topk = (sim_topk / sim_topk.sum(dim=1, keepdim=True))
46 |
47 |         Q_score = -((all_probs[topk] * all_probs.unsqueeze(1)).sum(-1) * wgt_topk).sum(-1)  # local inconsistency: negated similarity-weighted prediction agreement with the K nearest neighbors
48 |
49 | # propagate Q_score
50 | for i in range(self.cfg.LADA.S_PROP_ITER):
51 | Q_score += (wgt_topk * Q_score[topk]).sum(-1) * self.cfg.LADA.S_PROP_COEF
52 |
53 | m_idxs = Q_score.sort(descending=True)[1]
54 |
55 | # oversample and find centroids
56 | M = self.cfg.LADA.S_M
57 | m_topk = m_idxs[:n * (1 + M)]
58 | km = KMeans(n_clusters=n)
59 | km.fit(all_embs[m_topk])
60 | dists = euclidean_distances(km.cluster_centers_, all_embs[m_topk])
61 | sort_idxs = dists.argsort(axis=1)
62 | q_idxs = []
63 | ax, rem = 0, n
64 | while rem > 0:
65 | q_idxs.extend(list(sort_idxs[:, ax][:rem]))
66 | q_idxs = list(set(q_idxs))
67 | rem = n - len(q_idxs)
68 | ax += 1
69 |
70 | q_idxs = m_idxs[q_idxs].cpu().numpy()
71 | self.query_dset.rand_transform = None
72 |
73 | return idxs_unlabeled[q_idxs]
74 |
75 |
--------------------------------------------------------------------------------
/active/MHPsampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from sklearn.cluster import KMeans
6 | from sklearn.metrics.pairwise import euclidean_distances
7 |
8 | from .utils import ActualSequentialSampler
9 | from .sampler import register_strategy, SamplingStrategy
10 |
11 |
12 | @register_strategy('MHP')
13 | class MHPSampling(SamplingStrategy):
14 | '''
15 | Implements MHPL: Minimum Happy Points Learning for Active Source Free Domain Adaptation (CVPR'23)
16 | '''
17 |
18 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg):
19 | super(MHPSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg)
20 |
21 | def query(self, n, epoch):
22 | idxs_unlabeled = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb]
23 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx[idxs_unlabeled])
24 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler,
25 | num_workers=self.cfg.DATALOADER.NUM_WORKERS,
26 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False)
27 | self.model.eval()
28 | all_probs = []
29 | all_embs = []
30 | with torch.no_grad():
31 | for batch_idx, (data, target, _) in enumerate(data_loader):
32 | data, target = data.to(self.device), target.to(self.device)
33 | scores, embs = self.model(data, with_emb=True)
34 | all_embs.append(embs.cpu())
35 | probs = F.softmax(scores, dim=-1)
36 |                 all_probs.append(probs.cpu())  # keep on CPU, consistent with all_embs and the CPU topk indices below
37 |
38 | all_probs = torch.cat(all_probs)
39 | all_embs = F.normalize(torch.cat(all_embs), dim=-1)
40 |
41 | # find KNN
42 | sim = all_embs.cpu().mm(all_embs.transpose(1, 0))
43 | K = self.cfg.LADA.S_K
44 | sim_topk, topk = torch.topk(sim, k=K + 1, dim=1)
45 | sim_topk, topk = sim_topk[:, 1:], topk[:, 1:]
46 |
47 | # get NP scores
48 | all_preds = all_probs.argmax(-1)
49 | Sp = (torch.eye(self.num_classes)[all_preds[topk]]).sum(1)
50 | Sp = Sp / Sp.sum(-1, keepdim=True)
51 | NP = -(torch.log(Sp+1e-9)*Sp).sum(-1)
52 |
53 | # get NA scores
54 | NA = sim_topk.sum(-1) / K
55 | NAU = NP*NA
56 | sort_idxs = NAU.argsort(descending=True)
57 |
58 | q_idxs = []
59 | ax, rem = 0, n
60 | while rem > 0:
61 | if topk[sort_idxs[ax]][0] not in q_idxs:
62 | q_idxs.append(sort_idxs[ax])
63 | rem = n - len(q_idxs)
64 | ax += 1
65 |
66 | q_idxs = np.array(q_idxs)
67 |
68 | return idxs_unlabeled[q_idxs]
69 |
70 |
--------------------------------------------------------------------------------
/active/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import *
2 | from .LASsampler import *
3 | from .MHPsampler import *
4 |
--------------------------------------------------------------------------------
/active/budget.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class BudgetAllocator():
4 | def __init__(self, budget, cfg):
5 | self.budget = budget
6 | self.max_epochs = cfg.TRAINER.MAX_EPOCHS
7 | self.cfg = cfg
8 | self.build_budgets()
9 |
10 | def build_budgets(self):
11 | self.budgets = np.zeros(self.cfg.TRAINER.MAX_EPOCHS, dtype=np.int32)
12 | rounds = self.cfg.ADA.ROUNDS or np.arange(0, self.cfg.TRAINER.MAX_EPOCHS, self.cfg.TRAINER.MAX_EPOCHS // self.cfg.ADA.ROUND)
13 |
14 | for r in rounds:
15 | self.budgets[r] = self.budget // len(rounds)
16 |
17 | self.budgets[rounds[-1]] += self.budget - self.budgets.sum()
18 |
19 | def get_budget(self, epoch):
20 | curr_budget = self.budgets[epoch]
21 | used_budget = self.budgets[:epoch].sum()
22 | return curr_budget, used_budget
23 |
--------------------------------------------------------------------------------
/active/sampler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Implements active learning sampling strategies
4 | Adapted from https://github.com/ej0cl6/deep-active-learning
5 | """
6 |
7 | import os
8 | import copy
9 | import random
10 | import numpy as np
11 |
12 | from sklearn.cluster import KMeans
13 | from sklearn.metrics.pairwise import euclidean_distances
14 | from sklearn.metrics import pairwise_distances
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | import torch.optim as optim
20 | from torch.utils.data.sampler import Sampler, SubsetRandomSampler
21 | from torch.utils.data import DataLoader
22 | import logging
23 |
24 | import utils.utils as utils
25 | from solver import get_solver
26 | from model import get_model
27 | from dataset.image_list import ImageList
28 | from .utils import row_norms, kmeans_plus_plus_opt, get_embedding, ActualSequentialSampler
29 |
30 | torch.manual_seed(1234)
31 | torch.cuda.manual_seed(1234)
32 | random.seed(1234)
33 | np.random.seed(1234)
34 |
35 | al_dict = {}
36 |
37 | def register_strategy(name):
38 | def decorator(cls):
39 | al_dict[name] = cls
40 | return cls
41 |
42 | return decorator
43 |
44 |
45 | def get_strategy(sample, *args):
46 | if sample not in al_dict: raise NotImplementedError
47 | return al_dict[sample](*args)
48 |
49 |
50 | class SamplingStrategy:
51 | """
52 | Sampling Strategy wrapper class
53 | """
54 |
55 | def __init__(self, src_dset, tgt_dset, source_model, device, num_classes, cfg):
56 | self.src_dset = src_dset
57 | self.tgt_dset = tgt_dset
58 | self.num_classes = num_classes
59 | self.model = copy.deepcopy(source_model) # initialized with source model
60 | self.device = device
61 | self.cfg = cfg
62 | self.discrim = nn.Sequential(
63 | nn.Linear(self.cfg.DATASET.NUM_CLASS, 500),
64 | nn.ReLU(),
65 | nn.Linear(500, 500),
66 | nn.ReLU(),
67 | nn.Linear(500, 2)).to(self.device) # should be initialized by running train_uda
68 | self.idxs_lb = np.zeros(len(self.tgt_dset.train_idx), dtype=bool)
69 | self.solver = None
70 | self.lr_scheduler = None
71 | self.opt_discrim = None
72 | self.opt_net_tgt = None
73 | self.query_dset = tgt_dset.get_dsets()[1] # change to query dataset
74 |
75 | def query(self, n, epoch):
76 | pass
77 |
78 | def update(self, idxs_lb):
79 | self.idxs_lb = idxs_lb
80 |
81 | def pred(self, idxs=None, with_emb=False):
82 | if idxs is None:
83 | idxs = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb]
84 |
85 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx[idxs])
86 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler, num_workers=self.cfg.DATALOADER.NUM_WORKERS,
87 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False)
88 | self.model.eval()
89 | all_log_probs = []
90 | all_scores = []
91 | all_embs = []
92 | with torch.no_grad():
93 | for batch_idx, (data, target, _) in enumerate(data_loader):
94 | data, target = data.to(self.device), target.to(self.device)
95 | if with_emb:
96 | scores, embs = self.model(data, with_emb=True)
97 | all_embs.append(embs.cpu())
98 | else:
99 | scores = self.model(data, with_emb=False)
100 | log_probs = nn.LogSoftmax(dim=1)(scores)
101 | all_log_probs.append(log_probs)
102 | all_scores.append(scores)
103 |
104 | all_log_probs = torch.cat(all_log_probs)
105 | all_probs = torch.exp(all_log_probs)
106 | all_scores = torch.cat(all_scores)
107 | if with_emb:
108 | all_embs = torch.cat(all_embs)
109 | return idxs, all_probs, all_log_probs, all_scores, all_embs
110 | else:
111 | return idxs, all_probs, all_log_probs, all_scores
112 |
113 | def train_uda(self, epochs=1):
114 | """
115 | Unsupervised adaptation of source model to target at round 0
116 | Returns:
117 | Model post adaptation
118 | """
119 | source = self.cfg.DATASET.SOURCE_DOMAIN
120 | target = self.cfg.DATASET.TARGET_DOMAIN
121 | uda_strat = self.cfg.ADA.UDA
122 |
123 | adapt_dir = os.path.join('checkpoints', 'adapt')
124 | adapt_net_file = os.path.join(adapt_dir, '{}_{}_{}_{}_{}.pth'.format(uda_strat, source, target,
125 | self.cfg.MODEL.BACKBONE.NAME, self.cfg.TRAINER.MAX_UDA_EPOCHS))
126 |
127 | if not os.path.exists(adapt_dir):
128 | os.makedirs(adapt_dir)
129 |
130 | if self.cfg.TRAINER.LOAD_FROM_CHECKPOINT and os.path.exists(adapt_net_file):
131 | logging.info('Found pretrained uda checkpoint, loading...')
132 | adapt_model = get_model('AdaptNet', num_cls=self.num_classes, weights_init=adapt_net_file,
133 | model=self.cfg.MODEL.BACKBONE.NAME)
134 | else:
135 | src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader = self.build_loaders()
136 |
137 | src_train_loader = self.src_dset.get_loaders()[0]
138 | target_train_dset = self.tgt_dset.get_dsets()[0]
139 | train_sampler = SubsetRandomSampler(self.tgt_dset.train_idx[self.idxs_lb])
140 | tgt_sup_loader = torch.utils.data.DataLoader(target_train_dset, sampler=train_sampler,
141 | num_workers=self.cfg.DATALOADER.NUM_WORKERS, \
142 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False)
143 | tgt_unsup_loader = torch.utils.data.DataLoader(target_train_dset, shuffle=True,
144 | num_workers=self.cfg.DATALOADER.NUM_WORKERS, \
145 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False)
146 |
147 | logging.info('No pretrained checkpoint found, training...')
148 | adapt_model = get_model('AdaptNet', num_cls=self.num_classes, src_weights_init=self.model,
149 | model=self.cfg.MODEL.BACKBONE.NAME, normalize=self.cfg.MODEL.NORMALIZE, temp=self.cfg.MODEL.TEMP)
150 | opt_net_tgt = utils.get_optim(self.cfg.OPTIM.UDA_NAME, adapt_model.tgt_net.parameters(self.cfg.OPTIM.UDA_LR, self.cfg.OPTIM.BASE_LR_MULT),
151 | lr=self.cfg.OPTIM.UDA_LR)
152 | uda_solver = get_solver(uda_strat, adapt_model.tgt_net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader,
153 | joint_sup_loader, opt_net_tgt, False, self.device, self.cfg)
154 |
155 | for epoch in range(epochs):
156 | print("Running uda epoch {}/{}".format(epoch, epochs))
157 | if uda_strat in ['dann']:
158 | opt_discrim = optim.Adadelta(adapt_model.discrim.parameters(), lr=self.cfg.OPTIM.UDA_LR)
159 | uda_solver.solve(epoch, adapt_model.discrim, opt_discrim)
160 | elif uda_strat in ['mme']:
161 | uda_solver.solve(epoch)
162 | else:
163 | logging.info('Warning: no uda training with {}, skipped'.format(uda_strat))
164 | return self.model
165 |
166 | adapt_model.save(adapt_net_file)
167 |
168 | self.model = adapt_model.tgt_net
169 | return self.model
170 |
171 | def build_loaders(self):
172 | src_loader = self.src_dset.get_loaders()[0]
173 | tgt_loader = self.tgt_dset.get_loaders()[0]
174 |
175 | target_train_dset = self.tgt_dset.get_dsets()[0]
176 | train_sampler = SubsetRandomSampler(self.tgt_dset.train_idx[self.idxs_lb])
177 | tgt_sup_loader = DataLoader(target_train_dset, sampler=train_sampler,
178 | num_workers=self.cfg.DATALOADER.NUM_WORKERS, \
179 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False)
180 | train_sampler = SubsetRandomSampler(self.tgt_dset.train_idx[~self.idxs_lb])
181 | tgt_unsup_loader = DataLoader(target_train_dset, sampler=train_sampler,
182 | num_workers=self.cfg.DATALOADER.NUM_WORKERS, \
183 | batch_size=self.cfg.DATALOADER.BATCH_SIZE*self.cfg.DATALOADER.TGT_UNSUP_BS_MUL,
184 | drop_last=False)
185 |
186 | # create joint src_tgt_sup loader as commonly used
187 | joint_list = [self.src_dset.train_dataset.samples[_] for _ in self.src_dset.train_idx] + \
188 | [self.tgt_dset.train_dataset.samples[_] for _ in self.tgt_dset.train_idx[self.idxs_lb]]
189 |
190 | # use source train transform
191 | join_transform = self.src_dset.get_dsets()[0].transform
192 | joint_train_ds = ImageList(joint_list, root=self.cfg.DATASET.ROOT, transform=join_transform)
193 | joint_sup_loader = DataLoader(joint_train_ds, batch_size=self.cfg.DATALOADER.BATCH_SIZE, shuffle=True,
194 | drop_last=False, num_workers=self.cfg.DATALOADER.NUM_WORKERS)
195 |
196 | return src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader
197 |
198 | def train(self, epoch):
199 | src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader = self.build_loaders()
200 |
201 | if self.opt_net_tgt is None:
202 | self.opt_net_tgt = utils.get_optim(self.cfg.OPTIM.NAME, self.model.parameters(self.cfg.OPTIM.ADAPT_LR,
203 | self.cfg.OPTIM.BASE_LR_MULT), lr=self.cfg.OPTIM.ADAPT_LR, weight_decay=0.00001)
204 |
205 | if self.opt_discrim is None:
206 | self.opt_discrim = utils.get_optim(self.cfg.OPTIM.NAME, self.discrim.parameters(), lr=self.cfg.OPTIM.ADAPT_LR, weight_decay=0)
207 |
208 | solver = get_solver(self.cfg.ADA.DA, self.model, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader,
209 | joint_sup_loader, self.opt_net_tgt, True, self.device, self.cfg)
210 |
211 | if self.cfg.ADA.DA in ['dann']:
212 | solver.solve(epoch, self.discrim, self.opt_discrim)
213 | elif self.cfg.ADA.DA in ['LAA', 'RAA']:
214 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx)
215 | seq_query_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler,
216 | num_workers=self.cfg.DATALOADER.NUM_WORKERS,
217 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False)
218 | solver.solve(epoch, seq_query_loader)
219 | else:
220 | solver.solve(epoch)
221 |
222 |
223 | return self.model
224 |
225 |
226 | @register_strategy('random')
227 | class RandomSampling(SamplingStrategy):
228 | """
229 | Uniform sampling
230 | """
231 |
232 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg):
233 | super(RandomSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg)
234 |
235 | def query(self, n, epoch):
236 | return np.random.choice(np.where(self.idxs_lb == 0)[0], n, replace=False)
237 |
238 |
239 |
240 | @register_strategy('entropy')
241 | class EntropySampling(SamplingStrategy):
242 | """
243 | Implements entropy based sampling
244 | """
245 |
246 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg):
247 | super(EntropySampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg)
248 |
249 | def query(self, n, epoch):
250 | idxs_unlabeled, all_probs, all_log_probs, _ = self.pred()
251 | # Compute entropy
252 | entropy = -(all_probs * all_log_probs).sum(1)
253 | q_idxs = (entropy).sort(descending=True)[1][:n]
254 | q_idxs = q_idxs.cpu().numpy()
255 | return idxs_unlabeled[q_idxs]
256 |
257 | @register_strategy('margin')
258 | class MarginSampling(SamplingStrategy):
259 | """
260 | Implements margin based sampling
261 | """
262 |
263 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg):
264 | super(MarginSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg)
265 |
266 | def query(self, n, epoch):
267 | idxs_unlabeled, all_probs, _, _ = self.pred()
268 | # Compute BvSB margin
269 | top2 = torch.topk(all_probs, 2).values
270 | BvSB_scores = 1-(top2[:,0] - top2[:,1]) # use minus for descending sorting
271 | q_idxs = (BvSB_scores).sort(descending=True)[1][:n]
272 | q_idxs = q_idxs.cpu().numpy()
273 | return idxs_unlabeled[q_idxs]
274 |
275 |
276 | @register_strategy('leastConfidence')
277 | class LeastConfidenceSampling(SamplingStrategy):
278 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg):
279 | super(LeastConfidenceSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg)
280 |
281 | def query(self, n, epoch):
282 | idxs_unlabeled, all_probs, _, _ = self.pred()
283 | confidences = -all_probs.max(1)[0] # use minus for descending sorting
284 | q_idxs = (confidences).sort(descending=True)[1][:n]
285 | q_idxs = q_idxs.cpu().numpy()
286 | return idxs_unlabeled[q_idxs]
287 |
288 |
289 | @register_strategy('coreset')
290 | class CoreSetSampling(SamplingStrategy):
291 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg):
292 | super(CoreSetSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg)
293 |
294 | def furthest_first(self, X, X_lb, n):
295 | m = np.shape(X)[0]
296 | if np.shape(X_lb)[0] == 0:
297 | min_dist = np.tile(float("inf"), m)
298 | else:
299 | dist_ctr = pairwise_distances(X, X_lb)
300 | min_dist = np.amin(dist_ctr, axis=1)
301 |
302 | idxs = []
303 |
304 | for i in range(n):
305 | idx = min_dist.argmax()
306 | idxs.append(idx)
307 | dist_new_ctr = pairwise_distances(X, X[[idx], :])
308 | for j in range(m):
309 | min_dist[j] = min(min_dist[j], dist_new_ctr[j, 0])
310 |
311 | return idxs
312 |
313 | def query(self, n, epoch):
314 | idxs = np.arange(len(self.tgt_dset.train_idx))
315 | idxs_unlabeled, _, _, _, all_embs = self.pred(idxs=idxs, with_emb=True)
316 | all_embs = all_embs.numpy()
317 | q_idxs = self.furthest_first(all_embs[~self.idxs_lb, :], all_embs[self.idxs_lb, :], n)
318 | return idxs_unlabeled[q_idxs]
319 |
320 |
321 | @register_strategy('AADA')
322 | class AADASampling(SamplingStrategy):
323 | """
324 | Implements Active Adversarial Domain Adaptation (https://arxiv.org/abs/1904.07848)
325 | """
326 |
327 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg):
328 | super(AADASampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg)
329 |
330 | def query(self, n, epoch):
331 | """
332 |         s(x) = \frac{1 - G^*_d(G_f(x))}{G^*_d(G_f(x))} [Diversity] * H(G_y(G_f(x))) [Uncertainty]
333 | """
334 | self.model.eval()
335 | idxs_unlabeled = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb]
336 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx[idxs_unlabeled])
337 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler, num_workers=4, batch_size=64,
338 | drop_last=False)
339 |
340 | # Get diversity and entropy
341 | all_log_probs, all_scores = [], []
342 | with torch.no_grad():
343 | for batch_idx, (data, target, _) in enumerate(data_loader):
344 | data, target = data.to(self.device), target.to(self.device)
345 | scores = self.model(data)
346 | log_probs = nn.LogSoftmax(dim=1)(scores)
347 | all_scores.append(scores)
348 | all_log_probs.append(log_probs)
349 |
350 | all_scores = torch.cat(all_scores)
351 | all_log_probs = torch.cat(all_log_probs)
352 |
353 | all_probs = torch.exp(all_log_probs)
354 | disc_scores = nn.Softmax(dim=1)(self.discrim(all_scores))
355 | # Compute diversity
356 | self.D = torch.div(disc_scores[:, 0], disc_scores[:, 1])
357 | # Compute entropy
358 | self.E = -(all_probs * all_log_probs).sum(1)
359 | scores = (self.D * self.E).sort(descending=True)[1]
360 | # Sample from top-2 % instances, as recommended by authors
361 | top_N = max(int(len(scores) * 0.02), n)
362 | q_idxs = np.random.choice(scores[:top_N].cpu().numpy(), n, replace=False)
363 |
364 | return idxs_unlabeled[q_idxs]
365 |
366 |
367 | @register_strategy('BADGE')
368 | class BADGESampling(SamplingStrategy):
369 | """
370 | Implements BADGE: Batch Active Learning by Diverse, Uncertain Gradient Lower Bounds (https://arxiv.org/abs/1906.03671)
371 | """
372 |
373 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg):
374 | super(BADGESampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg)
375 |
376 | def query(self, n, epoch):
377 | idxs_unlabeled = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb]
378 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx[idxs_unlabeled])
379 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler, num_workers=self.cfg.DATALOADER.NUM_WORKERS,
380 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False)
381 | self.model.eval()
382 |
383 | if 'LeNet' in self.cfg.MODEL.BACKBONE.NAME:
384 | emb_dim = 500
385 | elif 'ResNet34' in self.cfg.MODEL.BACKBONE.NAME:
386 | emb_dim = 512
387 | elif 'ResNet50' in self.cfg.MODEL.BACKBONE.NAME:
388 | emb_dim = 256
389 |
390 | tgt_emb = torch.zeros([len(data_loader.sampler), self.num_classes])
391 | tgt_pen_emb = torch.zeros([len(data_loader.sampler), emb_dim])
392 | tgt_lab = torch.zeros(len(data_loader.sampler))
393 | tgt_preds = torch.zeros(len(data_loader.sampler))
394 | batch_sz = self.cfg.DATALOADER.BATCH_SIZE
395 |
396 | with torch.no_grad():
397 | for batch_idx, (data, target, _) in enumerate(data_loader):
398 | data, target = data.to(self.device), target.to(self.device)
399 | e1, e2 = self.model(data, with_emb=True)
400 | tgt_pen_emb[batch_idx * batch_sz:batch_idx * batch_sz + min(batch_sz, e2.shape[0]), :] = e2.cpu()
401 | tgt_emb[batch_idx * batch_sz:batch_idx * batch_sz + min(batch_sz, e1.shape[0]), :] = e1.cpu()
402 | tgt_lab[batch_idx * batch_sz:batch_idx * batch_sz + min(batch_sz, e1.shape[0])] = target
403 | tgt_preds[batch_idx * batch_sz:batch_idx * batch_sz + min(batch_sz, e1.shape[0])] = e1.argmax(dim=1,
404 | keepdim=True).squeeze()
405 |
406 | # Compute uncertainty gradient
407 | tgt_scores = nn.Softmax(dim=1)(tgt_emb)
408 | tgt_scores_delta = torch.zeros_like(tgt_scores)
409 | tgt_scores_delta[torch.arange(len(tgt_scores_delta)), tgt_preds.long()] = 1
410 |
411 | # Uncertainty embedding
412 | badge_uncertainty = (tgt_scores - tgt_scores_delta)
413 |
414 | # Seed with maximum uncertainty example
415 | max_norm = row_norms(badge_uncertainty.cpu().numpy()).argmax()
416 |
417 | _, q_idxs = kmeans_plus_plus_opt(badge_uncertainty.cpu().numpy(), tgt_pen_emb.cpu().numpy(), n,
418 | init=[max_norm])
419 |
420 | return idxs_unlabeled[q_idxs]
421 |
422 |
423 | @register_strategy('kmeans')
424 | class KmeansSampling(SamplingStrategy):
425 | """
426 | Implements CLUE: CLustering via Uncertainty-weighted Embeddings
427 | """
428 |
429 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg):
430 | super(KmeansSampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg)
431 |
432 | def query(self, n, epoch):
433 | idxs_unlabeled, _, _, _, all_embs = self.pred(with_emb=True)
434 | all_embs = all_embs.numpy()
435 |
436 | # Run weighted K-means over embeddings
437 | km = KMeans(n_clusters=n)
438 | km.fit(all_embs)
439 |
440 | # use below code to match CLUE implementation
441 | # Find nearest neighbors to inferred centroids
442 | dists = euclidean_distances(km.cluster_centers_, all_embs)
443 | sort_idxs = dists.argsort(axis=1)
444 | q_idxs = []
445 | ax, rem = 0, n
446 | while rem > 0:
447 | q_idxs.extend(list(sort_idxs[:, ax][:rem]))
448 | q_idxs = list(set(q_idxs))
449 | rem = n - len(q_idxs)
450 | ax += 1
451 |
452 | return idxs_unlabeled[q_idxs]
453 |
454 |
455 | @register_strategy('CLUE')
456 | class CLUESampling(SamplingStrategy):
457 | """
458 | Implements CLUE: CLustering via Uncertainty-weighted Embeddings
459 | """
460 |
461 | def __init__(self, src_dset, tgt_dset, model, device, num_classes, cfg):
462 | super(CLUESampling, self).__init__(src_dset, tgt_dset, model, device, num_classes, cfg)
463 | self.random_state = np.random.RandomState(1234)
464 | self.T = 0.1
465 |
466 | def query(self, n, epoch):
467 | idxs_unlabeled = np.arange(len(self.tgt_dset.train_idx))[~self.idxs_lb]
468 | train_sampler = ActualSequentialSampler(self.tgt_dset.train_idx[idxs_unlabeled])
469 | data_loader = torch.utils.data.DataLoader(self.query_dset, sampler=train_sampler, num_workers=self.cfg.DATALOADER.NUM_WORKERS, \
470 | batch_size=self.cfg.DATALOADER.BATCH_SIZE, drop_last=False)
471 | self.model.eval()
472 |
473 | if 'LeNet' in self.cfg.MODEL.BACKBONE.NAME:
474 | emb_dim = 500
475 | elif 'ResNet34' in self.cfg.MODEL.BACKBONE.NAME:
476 | emb_dim = 512
477 | elif 'ResNet50' in self.cfg.MODEL.BACKBONE.NAME:
478 | emb_dim = 256
479 |
480 | # Get embedding of target instances
481 | tgt_emb, tgt_lab, tgt_preds, tgt_pen_emb = get_embedding(self.model, data_loader, self.device,
482 | self.num_classes, \
483 | self.cfg, with_emb=True, emb_dim=emb_dim)
484 | tgt_pen_emb = tgt_pen_emb.cpu().numpy()
485 | tgt_scores = torch.softmax(tgt_emb / self.T, dim=-1)
486 | tgt_scores += 1e-8
487 | sample_weights = -(tgt_scores * torch.log(tgt_scores)).sum(1).cpu().numpy()
488 |
489 | # Run weighted K-means over embeddings
490 | km = KMeans(n)
491 | km.fit(tgt_pen_emb, sample_weight=sample_weights)
492 |
493 | # Find nearest neighbors to inferred centroids
494 | dists = euclidean_distances(km.cluster_centers_, tgt_pen_emb)
495 | sort_idxs = dists.argsort(axis=1)
496 | q_idxs = []
497 | ax, rem = 0, n
498 | while rem > 0:
499 | q_idxs.extend(list(sort_idxs[:, ax][:rem]))
500 | q_idxs = list(set(q_idxs))
501 | rem = n - len(q_idxs)
502 | ax += 1
503 |
504 | return idxs_unlabeled[q_idxs]
505 |
506 |
--------------------------------------------------------------------------------
/active/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data.sampler import Sampler
4 | import copy
5 |
6 | def ActualSequentialLoader(subsetRandomLoader, indices=None, transform=None, batch_size=None):
7 | indices = indices if indices is not None else subsetRandomLoader.sampler.indices
8 | train_sampler = ActualSequentialSampler(indices)
9 | dataset = copy.deepcopy(subsetRandomLoader.dataset)
10 | if transform is not None:
11 | dataset.transform = transform
12 |
13 | batch_size = batch_size if batch_size is not None else subsetRandomLoader.batch_size
14 | actualSequentialLoader = torch.utils.data.DataLoader(dataset, sampler=train_sampler,
15 | num_workers=subsetRandomLoader.num_workers,
16 | batch_size=batch_size, drop_last=False)
17 | return actualSequentialLoader
18 |
19 | class ActualSequentialSampler(Sampler):
20 | r"""Samples elements sequentially, always in the same order.
21 |
22 | Arguments:
23 | data_source (Dataset): dataset to sample from
24 | """
25 |
26 | def __init__(self, data_source):
27 | self.data_source = data_source
28 |
29 | def __iter__(self):
30 | return iter(self.data_source)
31 |
32 | def __len__(self):
33 | return len(self.data_source)
34 |
35 | def row_norms(X, squared=False):
36 | """Row-wise (squared) Euclidean norm of X.
37 | Equivalent to np.sqrt((X * X).sum(axis=1)), but also supports sparse
38 | matrices and does not create an X.shape-sized temporary.
39 | Performs no input validation.
40 | Parameters
41 | ----------
42 | X : array_like
43 | The input array
44 | squared : bool, optional (default = False)
45 | If True, return squared norms.
46 | Returns
47 | -------
48 | array_like
49 | The row-wise (squared) Euclidean norm of X.
50 | """
51 | norms = np.einsum('ij,ij->i', X, X)
52 |
53 | if not squared:
54 | np.sqrt(norms, norms)
55 | return norms
56 |
57 | def outer_product_opt(c1, d1, c2, d2):
58 | """Computes euclidean distance between a1xb1 and a2xb2 without evaluating / storing cross products
59 | """
60 | B1, B2 = c1.shape[0], c2.shape[0]
61 | t1 = np.matmul(np.matmul(c1[:, None, :], c1[:, None, :].swapaxes(2, 1)), np.matmul(d1[:, None, :], d1[:, None, :].swapaxes(2, 1)))
62 | t2 = np.matmul(np.matmul(c2[:, None, :], c2[:, None, :].swapaxes(2, 1)), np.matmul(d2[:, None, :], d2[:, None, :].swapaxes(2, 1)))
63 | t3 = np.matmul(c1, c2.T) * np.matmul(d1, d2.T)
64 | t1 = t1.reshape(B1, 1).repeat(B2, axis=1)
65 | t2 = t2.reshape(1, B2).repeat(B1, axis=0)
66 | return t1 + t2 - 2*t3
67 |
68 | def kmeans_plus_plus_opt(X1, X2, n_clusters, init=[0], random_state=np.random.RandomState(1234), n_local_trials=None):
69 | """Init n_clusters seeds according to k-means++ (adapted from scikit-learn source code)
70 | Parameters
71 | ----------
72 | X1, X2 : array or sparse matrix
73 | The data to pick seeds for. To avoid memory copy, the input data
74 | should be double precision (dtype=np.float64).
75 | n_clusters : integer
76 | The number of seeds to choose
77 | init : list
78 | List of points already picked
79 | random_state : int, RandomState instance
80 | The generator used to initialize the centers. Use an int to make the
81 | randomness deterministic.
82 | See :term:`Glossary `.
83 | n_local_trials : integer, optional
84 | The number of seeding trials for each center (except the first),
85 | of which the one reducing inertia the most is greedily chosen.
86 | Set to None to make the number of trials depend logarithmically
87 | on the number of seeds (2+log(k)); this is the default.
88 | Notes
89 | -----
90 | Selects initial cluster centers for k-mean clustering in a smart way
91 | to speed up convergence. see: Arthur, D. and Vassilvitskii, S.
92 | "k-means++: the advantages of careful seeding". ACM-SIAM symposium
93 | on Discrete algorithms. 2007
94 | Version ported from http://www.stanford.edu/~darthur/kMeansppTest.zip,
95 | which is the implementation used in the aforementioned paper.
96 | """
97 |
98 | n_samples, n_feat1 = X1.shape
99 | _, n_feat2 = X2.shape
100 | # x_squared_norms = row_norms(X, squared=True)
101 | centers1 = np.empty((n_clusters+len(init)-1, n_feat1), dtype=X1.dtype)
102 | centers2 = np.empty((n_clusters+len(init)-1, n_feat2), dtype=X1.dtype)
103 |
104 |     idxs = np.empty((n_clusters+len(init)-1,), dtype=np.int64)
105 |
106 | # Set the number of local seeding trials if none is given
107 | if n_local_trials is None:
108 | # This is what Arthur/Vassilvitskii tried, but did not report
109 | # specific results for other than mentioning in the conclusion
110 | # that it helped.
111 | n_local_trials = 2 + int(np.log(n_clusters))
112 |
113 | # Pick first center randomly
114 | center_id = init
115 |
116 | centers1[:len(init)] = X1[center_id]
117 | centers2[:len(init)] = X2[center_id]
118 | idxs[:len(init)] = center_id
119 |
120 | # Initialize list of closest distances and calculate current potential
121 | distance_to_candidates = outer_product_opt(centers1[:len(init)], centers2[:len(init)], X1, X2).reshape(len(init), -1)
122 |
123 | candidates_pot = distance_to_candidates.sum(axis=1)
124 | best_candidate = np.argmin(candidates_pot)
125 | current_pot = candidates_pot[best_candidate]
126 | closest_dist_sq = distance_to_candidates[best_candidate]
127 |
128 | # Pick the remaining n_clusters-1 points
129 | for c in range(len(init), len(init)+n_clusters-1):
130 | # Choose center candidates by sampling with probability proportional
131 | # to the squared distance to the closest existing center
132 | rand_vals = random_state.random_sample(n_local_trials) * current_pot
133 | candidate_ids = np.searchsorted(closest_dist_sq.cumsum(),
134 | rand_vals)
135 | # XXX: numerical imprecision can result in a candidate_id out of range
136 | np.clip(candidate_ids, None, closest_dist_sq.size - 1,
137 | out=candidate_ids)
138 |
139 | # Compute distances to center candidates
140 | distance_to_candidates = outer_product_opt(X1[candidate_ids], X2[candidate_ids], X1, X2).reshape(len(candidate_ids), -1)
141 |
142 | # update closest distances squared and potential for each candidate
143 | np.minimum(closest_dist_sq, distance_to_candidates,
144 | out=distance_to_candidates)
145 | candidates_pot = distance_to_candidates.sum(axis=1)
146 |
147 | # Decide which candidate is the best
148 | best_candidate = np.argmin(candidates_pot)
149 | current_pot = candidates_pot[best_candidate]
150 | closest_dist_sq = distance_to_candidates[best_candidate]
151 | best_candidate = candidate_ids[best_candidate]
152 |
153 | idxs[c] = best_candidate
154 |
155 | return None, idxs[len(init)-1:]
156 |
157 | def get_embedding(model, loader, device, num_classes, cfg, with_emb=False, emb_dim=512):
158 | model.eval()
159 | embedding = torch.zeros([len(loader.sampler), num_classes])
160 | embedding_pen = torch.zeros([len(loader.sampler), emb_dim])
161 | labels = torch.zeros(len(loader.sampler))
162 | preds = torch.zeros(len(loader.sampler))
163 | batch_sz = cfg.DATALOADER.BATCH_SIZE
164 | with torch.no_grad():
165 | for batch_idx, (data, target, _) in enumerate(loader):
166 | data, target = data.to(device), target.to(device)
167 | if with_emb:
168 | e1, e2 = model(data, with_emb=True)
169 | embedding_pen[batch_idx*batch_sz:batch_idx*batch_sz + min(batch_sz, e2.shape[0]), :] = e2.cpu()
170 | else:
171 | e1 = model(data, with_emb=False)
172 |
173 | embedding[batch_idx*batch_sz:batch_idx*batch_sz + min(batch_sz, e1.shape[0]), :] = e1.cpu()
174 | labels[batch_idx*batch_sz:batch_idx*batch_sz + min(batch_sz, e1.shape[0])] = target
175 | preds[batch_idx*batch_sz:batch_idx*batch_sz + min(batch_sz, e1.shape[0])] = e1.argmax(dim=1, keepdim=True).squeeze()
176 |
177 | return embedding, labels, preds, embedding_pen
178 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .defaults import _C as cfg
2 |
3 |
4 | def get_cfg_default():
5 | return cfg.clone()
6 |
--------------------------------------------------------------------------------
/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | ###########################
4 | # Config definition
5 | ###########################
6 |
7 | _C = CN()
8 |
9 | _C.SEED = 0
10 | _C.NOTE = ''
11 |
12 | ###########################
13 | # Dataset
14 | ###########################
15 | _C.DATASET = CN()
16 | # Directory where datasets are stored
17 | _C.DATASET.ROOT = ''
18 | _C.DATASET.NAME = ''
19 | # List of domains
20 | _C.DATASET.SOURCE_DOMAINS = []
21 | _C.DATASET.TARGET_DOMAINS = []
22 | _C.DATASET.SOURCE_DOMAIN = ''
23 | _C.DATASET.TARGET_DOMAIN = ''
24 | _C.DATASET.SOURCE_VALID_TYPE = 'val'
25 | _C.DATASET.SOURCE_VALID_RATIO = 1.0
26 | _C.DATASET.SOURCE_TRANSFORMS = ('Resize','RandomCrop','Normalize')
27 | _C.DATASET.TARGET_TRANSFORMS = ('Resize','RandomCrop','Normalize')
28 | _C.DATASET.QUERY_TRANSFORMS = ('Resize','CenterCrop','Normalize')
29 | _C.DATASET.TEST_TRANSFORMS = ('Resize','CenterCrop','Normalize')
30 | _C.DATASET.RAND_TRANSFORMS = 'rand_transform'
31 | _C.DATASET.NUM_CLASS = 12
32 |
33 | ###########################
34 | # Dataloader
35 | ###########################
36 | _C.DATALOADER = CN()
37 | _C.DATALOADER.NUM_WORKERS = 4
38 | _C.DATALOADER.BATCH_SIZE = 32
39 | _C.DATALOADER.TGT_UNSUP_BS_MUL = 1
40 |
41 | ###########################
42 | # Model
43 | ###########################
44 | _C.MODEL = CN()
45 | # Path to model weights (for initialization)
46 | _C.MODEL.INIT_WEIGHTS = ''
47 | _C.MODEL.BACKBONE = CN()
48 | _C.MODEL.BACKBONE.NAME = 'ResNet50Fc'
49 | _C.MODEL.BACKBONE.PRETRAINED = True
50 | _C.MODEL.NORMALIZE = False
51 | _C.MODEL.TEMP = 0.05
52 |
53 | ###########################
54 | # Optimization
55 | ###########################
56 | _C.OPTIM = CN()
57 | _C.OPTIM.NAME = 'Adadelta'
58 | _C.OPTIM.SOURCE_NAME = 'Adadelta'
59 | _C.OPTIM.UDA_NAME = 'Adadelta'
60 | _C.OPTIM.SOURCE_LR = 0.1
61 | _C.OPTIM.UDA_LR = 0.1
62 | _C.OPTIM.ADAPT_LR = 0.1
63 | _C.OPTIM.BASE_LR_MULT = 0.1
64 |
65 | ###########################
66 | # Trainer specifics
67 | ###########################
68 | _C.TRAINER = CN()
69 | _C.TRAINER.LOAD_FROM_CHECKPOINT = True
70 | _C.TRAINER.TRAIN_ON_SOURCE = True
71 | _C.TRAINER.MAX_SOURCE_EPOCHS = 20
72 | _C.TRAINER.MAX_UDA_EPOCHS = 20
73 | _C.TRAINER.MAX_EPOCHS = 20
74 | _C.TRAINER.EVAL_ACC = True
75 | _C.TRAINER.ITER_PER_EPOCH = None
76 |
77 | ###########################
78 | # Active DA
79 | ###########################
80 | _C.ADA = CN()
81 | _C.ADA.TASKS = None
82 | _C.ADA.BUDGET = 0.05
83 | _C.ADA.ROUND = 5
84 | _C.ADA.ROUNDS = None
85 | _C.ADA.UDA = 'dann'
86 | _C.ADA.DA = 'ft'
87 | _C.ADA.AL = 'LAS'
88 | _C.ADA.SRC_SUP_WT = 1.0
89 | _C.ADA.TGT_SUP_WT = 1.0
90 | _C.ADA.UNSUP_WT = 0.1
91 | _C.ADA.CEN_WT = 0.1
92 |
93 | ###########################
94 | # LADA
95 | ###########################
96 | _C.LADA = CN()
97 | _C.LADA.S_K = 10
98 | _C.LADA.S_M = 10
99 | _C.LADA.S_PROP_ITER = 1
100 | _C.LADA.S_PROP_COEF = 1.0
101 | _C.LADA.A_K = 10
102 | _C.LADA.A_TH = 0.9
103 | _C.LADA.A_RAND_NUM = 1
104 |
--------------------------------------------------------------------------------
/configs/office31.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | ROOT: 'data/'
3 | NAME: 'office31'
4 | SOURCE_DOMAINS: ['webcam', 'amazon', 'dslr']
5 | TARGET_DOMAINS: ['webcam', 'amazon', 'dslr']
6 | NUM_CLASS: 31
7 |
8 | DATALOADER:
9 | BATCH_SIZE: 32
10 |
11 | OPTIM:
12 | NAME: 'Adadelta'
13 | SOURCE_LR: 0.1
14 | BASE_LR_MULT: 0.1
15 |
16 | TRAINER:
17 | MAX_EPOCHS: 40
18 | TRAIN_ON_SOURCE : False
19 | MAX_UDA_EPOCHS: 0
20 |
21 | ADA:
22 | DA : 'ft'
23 | AL : 'random'
24 | ROUNDS : [10, 12, 14, 16, 18]
25 |
26 | LADA:
27 | S_K : 5
28 | S_M : 10
29 | A_K : 5
30 |
31 | SEED: 0 # 0,1,2,3,4 for five random experiments
32 |
33 |
--------------------------------------------------------------------------------
/configs/officehome.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | ROOT: 'data/'
3 | NAME: 'officehome'
4 | SOURCE_DOMAINS: ['Art', 'Clipart', 'Product', 'RealWorld']
5 | TARGET_DOMAINS: ['Art', 'Clipart', 'Product', 'RealWorld']
6 | NUM_CLASS: 65
7 |
8 | DATALOADER:
9 | BATCH_SIZE: 32
10 |
11 | OPTIM:
12 | NAME: 'Adadelta'
13 | SOURCE_LR: 0.1
14 | BASE_LR_MULT: 0.1
15 |
16 | TRAINER:
17 | MAX_EPOCHS: 40
18 | TRAIN_ON_SOURCE : False
19 | MAX_UDA_EPOCHS: 0
20 |
21 | ADA:
22 | DA : 'ft'
23 | AL : 'random'
24 | ROUNDS : [10, 12, 14, 16, 18]
25 |
26 | SEED: 0 # 0,1,2,3,4 for five random experiments
27 |
28 |
--------------------------------------------------------------------------------
/configs/officehome_RSUT.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | ROOT: 'data/'
3 | NAME: 'officehome_RSUT'
4 | SOURCE_DOMAINS: ['Clipart_RS', 'Product_RS', 'RealWorld_RS']
5 | TARGET_DOMAINS: ['Clipart_UT', 'Product_UT', 'RealWorld_UT']
6 | NUM_CLASS: 65
7 |
8 | DATALOADER:
9 | BATCH_SIZE: 32
10 |
11 | OPTIM:
12 | NAME: 'Adadelta'
13 | SOURCE_LR: 0.1
14 | BASE_LR_MULT: 0.1
15 |
16 | TRAINER:
17 | MAX_EPOCHS: 40
18 | TRAIN_ON_SOURCE : False
19 | MAX_UDA_EPOCHS: 0
20 |
21 | ADA:
22 | TASKS: [['Clipart_RS','Product_UT'], ['Clipart_RS','RealWorld_UT'], ['Product_RS','Clipart_UT'], ['Product_RS','RealWorld_UT'], ['RealWorld_RS','Clipart_UT'], ['RealWorld_RS','Product_UT']]
23 | DA : 'ft'
24 | AL : 'random'
25 | ROUNDS : [10, 12, 14, 16, 18]
26 |
27 | SEED: 0 # 0,1,2,3,4 for five random experiments
28 |
29 |
--------------------------------------------------------------------------------
/configs/visda.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | ROOT: 'data/'
3 | NAME: 'visda'
4 | SOURCE_DOMAINS: ['train']
5 | TARGET_DOMAINS: ['validation']
6 | NUM_CLASS: 12
7 |
8 | DATALOADER:
9 | BATCH_SIZE: 32
10 |
11 | OPTIM:
12 | NAME: 'Adadelta'
13 | SOURCE_LR: 0.1
14 | BASE_LR_MULT: 0.1
15 |
16 | TRAINER:
17 | MAX_EPOCHS: 10
18 | TRAIN_ON_SOURCE : True
19 | MAX_UDA_EPOCHS: 0
20 | MAX_SOURCE_EPOCHS : 1
21 |
22 | ADA:
23 | DA : 'ft'
24 | AL : 'random'
25 | ROUNDS : [0,2,4,6,8]
26 |
27 | LADA:
28 | S_K : 10
29 | S_M : 5
30 | A_K : 10
31 |
32 | SEED: 0 # 0,1,2,3,4 for five random experiments
33 |
34 |
--------------------------------------------------------------------------------
/data/image_list/office31/dslr.txt:
--------------------------------------------------------------------------------
1 | office31/dslr/images/calculator/frame_0001.jpg 5
2 | office31/dslr/images/calculator/frame_0002.jpg 5
3 | office31/dslr/images/calculator/frame_0003.jpg 5
4 | office31/dslr/images/calculator/frame_0004.jpg 5
5 | office31/dslr/images/calculator/frame_0005.jpg 5
6 | office31/dslr/images/calculator/frame_0006.jpg 5
7 | office31/dslr/images/calculator/frame_0007.jpg 5
8 | office31/dslr/images/calculator/frame_0008.jpg 5
9 | office31/dslr/images/calculator/frame_0009.jpg 5
10 | office31/dslr/images/calculator/frame_0010.jpg 5
11 | office31/dslr/images/calculator/frame_0011.jpg 5
12 | office31/dslr/images/calculator/frame_0012.jpg 5
13 | office31/dslr/images/ring_binder/frame_0001.jpg 24
14 | office31/dslr/images/ring_binder/frame_0002.jpg 24
15 | office31/dslr/images/ring_binder/frame_0003.jpg 24
16 | office31/dslr/images/ring_binder/frame_0004.jpg 24
17 | office31/dslr/images/ring_binder/frame_0005.jpg 24
18 | office31/dslr/images/ring_binder/frame_0006.jpg 24
19 | office31/dslr/images/ring_binder/frame_0007.jpg 24
20 | office31/dslr/images/ring_binder/frame_0008.jpg 24
21 | office31/dslr/images/ring_binder/frame_0009.jpg 24
22 | office31/dslr/images/ring_binder/frame_0010.jpg 24
23 | office31/dslr/images/printer/frame_0001.jpg 21
24 | office31/dslr/images/printer/frame_0002.jpg 21
25 | office31/dslr/images/printer/frame_0003.jpg 21
26 | office31/dslr/images/printer/frame_0004.jpg 21
27 | office31/dslr/images/printer/frame_0005.jpg 21
28 | office31/dslr/images/printer/frame_0006.jpg 21
29 | office31/dslr/images/printer/frame_0007.jpg 21
30 | office31/dslr/images/printer/frame_0008.jpg 21
31 | office31/dslr/images/printer/frame_0009.jpg 21
32 | office31/dslr/images/printer/frame_0010.jpg 21
33 | office31/dslr/images/printer/frame_0011.jpg 21
34 | office31/dslr/images/printer/frame_0012.jpg 21
35 | office31/dslr/images/printer/frame_0013.jpg 21
36 | office31/dslr/images/printer/frame_0014.jpg 21
37 | office31/dslr/images/printer/frame_0015.jpg 21
38 | office31/dslr/images/keyboard/frame_0001.jpg 11
39 | office31/dslr/images/keyboard/frame_0002.jpg 11
40 | office31/dslr/images/keyboard/frame_0003.jpg 11
41 | office31/dslr/images/keyboard/frame_0004.jpg 11
42 | office31/dslr/images/keyboard/frame_0005.jpg 11
43 | office31/dslr/images/keyboard/frame_0006.jpg 11
44 | office31/dslr/images/keyboard/frame_0007.jpg 11
45 | office31/dslr/images/keyboard/frame_0008.jpg 11
46 | office31/dslr/images/keyboard/frame_0009.jpg 11
47 | office31/dslr/images/keyboard/frame_0010.jpg 11
48 | office31/dslr/images/scissors/frame_0001.jpg 26
49 | office31/dslr/images/scissors/frame_0002.jpg 26
50 | office31/dslr/images/scissors/frame_0003.jpg 26
51 | office31/dslr/images/scissors/frame_0004.jpg 26
52 | office31/dslr/images/scissors/frame_0005.jpg 26
53 | office31/dslr/images/scissors/frame_0006.jpg 26
54 | office31/dslr/images/scissors/frame_0007.jpg 26
55 | office31/dslr/images/scissors/frame_0008.jpg 26
56 | office31/dslr/images/scissors/frame_0009.jpg 26
57 | office31/dslr/images/scissors/frame_0010.jpg 26
58 | office31/dslr/images/scissors/frame_0011.jpg 26
59 | office31/dslr/images/scissors/frame_0012.jpg 26
60 | office31/dslr/images/scissors/frame_0013.jpg 26
61 | office31/dslr/images/scissors/frame_0014.jpg 26
62 | office31/dslr/images/scissors/frame_0015.jpg 26
63 | office31/dslr/images/scissors/frame_0016.jpg 26
64 | office31/dslr/images/scissors/frame_0017.jpg 26
65 | office31/dslr/images/scissors/frame_0018.jpg 26
66 | office31/dslr/images/laptop_computer/frame_0001.jpg 12
67 | office31/dslr/images/laptop_computer/frame_0002.jpg 12
68 | office31/dslr/images/laptop_computer/frame_0003.jpg 12
69 | office31/dslr/images/laptop_computer/frame_0004.jpg 12
70 | office31/dslr/images/laptop_computer/frame_0005.jpg 12
71 | office31/dslr/images/laptop_computer/frame_0006.jpg 12
72 | office31/dslr/images/laptop_computer/frame_0007.jpg 12
73 | office31/dslr/images/laptop_computer/frame_0008.jpg 12
74 | office31/dslr/images/laptop_computer/frame_0009.jpg 12
75 | office31/dslr/images/laptop_computer/frame_0010.jpg 12
76 | office31/dslr/images/laptop_computer/frame_0011.jpg 12
77 | office31/dslr/images/laptop_computer/frame_0012.jpg 12
78 | office31/dslr/images/laptop_computer/frame_0013.jpg 12
79 | office31/dslr/images/laptop_computer/frame_0014.jpg 12
80 | office31/dslr/images/laptop_computer/frame_0015.jpg 12
81 | office31/dslr/images/laptop_computer/frame_0016.jpg 12
82 | office31/dslr/images/laptop_computer/frame_0017.jpg 12
83 | office31/dslr/images/laptop_computer/frame_0018.jpg 12
84 | office31/dslr/images/laptop_computer/frame_0019.jpg 12
85 | office31/dslr/images/laptop_computer/frame_0020.jpg 12
86 | office31/dslr/images/laptop_computer/frame_0021.jpg 12
87 | office31/dslr/images/laptop_computer/frame_0022.jpg 12
88 | office31/dslr/images/laptop_computer/frame_0023.jpg 12
89 | office31/dslr/images/laptop_computer/frame_0024.jpg 12
90 | office31/dslr/images/mouse/frame_0001.jpg 16
91 | office31/dslr/images/mouse/frame_0002.jpg 16
92 | office31/dslr/images/mouse/frame_0003.jpg 16
93 | office31/dslr/images/mouse/frame_0004.jpg 16
94 | office31/dslr/images/mouse/frame_0005.jpg 16
95 | office31/dslr/images/mouse/frame_0006.jpg 16
96 | office31/dslr/images/mouse/frame_0007.jpg 16
97 | office31/dslr/images/mouse/frame_0008.jpg 16
98 | office31/dslr/images/mouse/frame_0009.jpg 16
99 | office31/dslr/images/mouse/frame_0010.jpg 16
100 | office31/dslr/images/mouse/frame_0011.jpg 16
101 | office31/dslr/images/mouse/frame_0012.jpg 16
102 | office31/dslr/images/monitor/frame_0001.jpg 15
103 | office31/dslr/images/monitor/frame_0002.jpg 15
104 | office31/dslr/images/monitor/frame_0003.jpg 15
105 | office31/dslr/images/monitor/frame_0004.jpg 15
106 | office31/dslr/images/monitor/frame_0005.jpg 15
107 | office31/dslr/images/monitor/frame_0006.jpg 15
108 | office31/dslr/images/monitor/frame_0007.jpg 15
109 | office31/dslr/images/monitor/frame_0008.jpg 15
110 | office31/dslr/images/monitor/frame_0009.jpg 15
111 | office31/dslr/images/monitor/frame_0010.jpg 15
112 | office31/dslr/images/monitor/frame_0011.jpg 15
113 | office31/dslr/images/monitor/frame_0012.jpg 15
114 | office31/dslr/images/monitor/frame_0013.jpg 15
115 | office31/dslr/images/monitor/frame_0014.jpg 15
116 | office31/dslr/images/monitor/frame_0015.jpg 15
117 | office31/dslr/images/monitor/frame_0016.jpg 15
118 | office31/dslr/images/monitor/frame_0017.jpg 15
119 | office31/dslr/images/monitor/frame_0018.jpg 15
120 | office31/dslr/images/monitor/frame_0019.jpg 15
121 | office31/dslr/images/monitor/frame_0020.jpg 15
122 | office31/dslr/images/monitor/frame_0021.jpg 15
123 | office31/dslr/images/monitor/frame_0022.jpg 15
124 | office31/dslr/images/mug/frame_0001.jpg 17
125 | office31/dslr/images/mug/frame_0002.jpg 17
126 | office31/dslr/images/mug/frame_0003.jpg 17
127 | office31/dslr/images/mug/frame_0004.jpg 17
128 | office31/dslr/images/mug/frame_0005.jpg 17
129 | office31/dslr/images/mug/frame_0006.jpg 17
130 | office31/dslr/images/mug/frame_0007.jpg 17
131 | office31/dslr/images/mug/frame_0008.jpg 17
132 | office31/dslr/images/tape_dispenser/frame_0001.jpg 29
133 | office31/dslr/images/tape_dispenser/frame_0002.jpg 29
134 | office31/dslr/images/tape_dispenser/frame_0003.jpg 29
135 | office31/dslr/images/tape_dispenser/frame_0004.jpg 29
136 | office31/dslr/images/tape_dispenser/frame_0005.jpg 29
137 | office31/dslr/images/tape_dispenser/frame_0006.jpg 29
138 | office31/dslr/images/tape_dispenser/frame_0007.jpg 29
139 | office31/dslr/images/tape_dispenser/frame_0008.jpg 29
140 | office31/dslr/images/tape_dispenser/frame_0009.jpg 29
141 | office31/dslr/images/tape_dispenser/frame_0010.jpg 29
142 | office31/dslr/images/tape_dispenser/frame_0011.jpg 29
143 | office31/dslr/images/tape_dispenser/frame_0012.jpg 29
144 | office31/dslr/images/tape_dispenser/frame_0013.jpg 29
145 | office31/dslr/images/tape_dispenser/frame_0014.jpg 29
146 | office31/dslr/images/tape_dispenser/frame_0015.jpg 29
147 | office31/dslr/images/tape_dispenser/frame_0016.jpg 29
148 | office31/dslr/images/tape_dispenser/frame_0017.jpg 29
149 | office31/dslr/images/tape_dispenser/frame_0018.jpg 29
150 | office31/dslr/images/tape_dispenser/frame_0019.jpg 29
151 | office31/dslr/images/tape_dispenser/frame_0020.jpg 29
152 | office31/dslr/images/tape_dispenser/frame_0021.jpg 29
153 | office31/dslr/images/tape_dispenser/frame_0022.jpg 29
154 | office31/dslr/images/pen/frame_0001.jpg 19
155 | office31/dslr/images/pen/frame_0002.jpg 19
156 | office31/dslr/images/pen/frame_0003.jpg 19
157 | office31/dslr/images/pen/frame_0004.jpg 19
158 | office31/dslr/images/pen/frame_0005.jpg 19
159 | office31/dslr/images/pen/frame_0006.jpg 19
160 | office31/dslr/images/pen/frame_0007.jpg 19
161 | office31/dslr/images/pen/frame_0008.jpg 19
162 | office31/dslr/images/pen/frame_0009.jpg 19
163 | office31/dslr/images/pen/frame_0010.jpg 19
164 | office31/dslr/images/bike/frame_0001.jpg 1
165 | office31/dslr/images/bike/frame_0002.jpg 1
166 | office31/dslr/images/bike/frame_0003.jpg 1
167 | office31/dslr/images/bike/frame_0004.jpg 1
168 | office31/dslr/images/bike/frame_0005.jpg 1
169 | office31/dslr/images/bike/frame_0006.jpg 1
170 | office31/dslr/images/bike/frame_0007.jpg 1
171 | office31/dslr/images/bike/frame_0008.jpg 1
172 | office31/dslr/images/bike/frame_0009.jpg 1
173 | office31/dslr/images/bike/frame_0010.jpg 1
174 | office31/dslr/images/bike/frame_0011.jpg 1
175 | office31/dslr/images/bike/frame_0012.jpg 1
176 | office31/dslr/images/bike/frame_0013.jpg 1
177 | office31/dslr/images/bike/frame_0014.jpg 1
178 | office31/dslr/images/bike/frame_0015.jpg 1
179 | office31/dslr/images/bike/frame_0016.jpg 1
180 | office31/dslr/images/bike/frame_0017.jpg 1
181 | office31/dslr/images/bike/frame_0018.jpg 1
182 | office31/dslr/images/bike/frame_0019.jpg 1
183 | office31/dslr/images/bike/frame_0020.jpg 1
184 | office31/dslr/images/bike/frame_0021.jpg 1
185 | office31/dslr/images/punchers/frame_0001.jpg 23
186 | office31/dslr/images/punchers/frame_0002.jpg 23
187 | office31/dslr/images/punchers/frame_0003.jpg 23
188 | office31/dslr/images/punchers/frame_0004.jpg 23
189 | office31/dslr/images/punchers/frame_0005.jpg 23
190 | office31/dslr/images/punchers/frame_0006.jpg 23
191 | office31/dslr/images/punchers/frame_0007.jpg 23
192 | office31/dslr/images/punchers/frame_0008.jpg 23
193 | office31/dslr/images/punchers/frame_0009.jpg 23
194 | office31/dslr/images/punchers/frame_0010.jpg 23
195 | office31/dslr/images/punchers/frame_0011.jpg 23
196 | office31/dslr/images/punchers/frame_0012.jpg 23
197 | office31/dslr/images/punchers/frame_0013.jpg 23
198 | office31/dslr/images/punchers/frame_0014.jpg 23
199 | office31/dslr/images/punchers/frame_0015.jpg 23
200 | office31/dslr/images/punchers/frame_0016.jpg 23
201 | office31/dslr/images/punchers/frame_0017.jpg 23
202 | office31/dslr/images/punchers/frame_0018.jpg 23
203 | office31/dslr/images/back_pack/frame_0001.jpg 0
204 | office31/dslr/images/back_pack/frame_0002.jpg 0
205 | office31/dslr/images/back_pack/frame_0003.jpg 0
206 | office31/dslr/images/back_pack/frame_0004.jpg 0
207 | office31/dslr/images/back_pack/frame_0005.jpg 0
208 | office31/dslr/images/back_pack/frame_0006.jpg 0
209 | office31/dslr/images/back_pack/frame_0007.jpg 0
210 | office31/dslr/images/back_pack/frame_0008.jpg 0
211 | office31/dslr/images/back_pack/frame_0009.jpg 0
212 | office31/dslr/images/back_pack/frame_0010.jpg 0
213 | office31/dslr/images/back_pack/frame_0011.jpg 0
214 | office31/dslr/images/back_pack/frame_0012.jpg 0
215 | office31/dslr/images/desktop_computer/frame_0001.jpg 8
216 | office31/dslr/images/desktop_computer/frame_0002.jpg 8
217 | office31/dslr/images/desktop_computer/frame_0003.jpg 8
218 | office31/dslr/images/desktop_computer/frame_0004.jpg 8
219 | office31/dslr/images/desktop_computer/frame_0005.jpg 8
220 | office31/dslr/images/desktop_computer/frame_0006.jpg 8
221 | office31/dslr/images/desktop_computer/frame_0007.jpg 8
222 | office31/dslr/images/desktop_computer/frame_0008.jpg 8
223 | office31/dslr/images/desktop_computer/frame_0009.jpg 8
224 | office31/dslr/images/desktop_computer/frame_0010.jpg 8
225 | office31/dslr/images/desktop_computer/frame_0011.jpg 8
226 | office31/dslr/images/desktop_computer/frame_0012.jpg 8
227 | office31/dslr/images/desktop_computer/frame_0013.jpg 8
228 | office31/dslr/images/desktop_computer/frame_0014.jpg 8
229 | office31/dslr/images/desktop_computer/frame_0015.jpg 8
230 | office31/dslr/images/speaker/frame_0001.jpg 27
231 | office31/dslr/images/speaker/frame_0002.jpg 27
232 | office31/dslr/images/speaker/frame_0003.jpg 27
233 | office31/dslr/images/speaker/frame_0004.jpg 27
234 | office31/dslr/images/speaker/frame_0005.jpg 27
235 | office31/dslr/images/speaker/frame_0006.jpg 27
236 | office31/dslr/images/speaker/frame_0007.jpg 27
237 | office31/dslr/images/speaker/frame_0008.jpg 27
238 | office31/dslr/images/speaker/frame_0009.jpg 27
239 | office31/dslr/images/speaker/frame_0010.jpg 27
240 | office31/dslr/images/speaker/frame_0011.jpg 27
241 | office31/dslr/images/speaker/frame_0012.jpg 27
242 | office31/dslr/images/speaker/frame_0013.jpg 27
243 | office31/dslr/images/speaker/frame_0014.jpg 27
244 | office31/dslr/images/speaker/frame_0015.jpg 27
245 | office31/dslr/images/speaker/frame_0016.jpg 27
246 | office31/dslr/images/speaker/frame_0017.jpg 27
247 | office31/dslr/images/speaker/frame_0018.jpg 27
248 | office31/dslr/images/speaker/frame_0019.jpg 27
249 | office31/dslr/images/speaker/frame_0020.jpg 27
250 | office31/dslr/images/speaker/frame_0021.jpg 27
251 | office31/dslr/images/speaker/frame_0022.jpg 27
252 | office31/dslr/images/speaker/frame_0023.jpg 27
253 | office31/dslr/images/speaker/frame_0024.jpg 27
254 | office31/dslr/images/speaker/frame_0025.jpg 27
255 | office31/dslr/images/speaker/frame_0026.jpg 27
256 | office31/dslr/images/mobile_phone/frame_0001.jpg 14
257 | office31/dslr/images/mobile_phone/frame_0002.jpg 14
258 | office31/dslr/images/mobile_phone/frame_0003.jpg 14
259 | office31/dslr/images/mobile_phone/frame_0004.jpg 14
260 | office31/dslr/images/mobile_phone/frame_0005.jpg 14
261 | office31/dslr/images/mobile_phone/frame_0006.jpg 14
262 | office31/dslr/images/mobile_phone/frame_0007.jpg 14
263 | office31/dslr/images/mobile_phone/frame_0008.jpg 14
264 | office31/dslr/images/mobile_phone/frame_0009.jpg 14
265 | office31/dslr/images/mobile_phone/frame_0010.jpg 14
266 | office31/dslr/images/mobile_phone/frame_0011.jpg 14
267 | office31/dslr/images/mobile_phone/frame_0012.jpg 14
268 | office31/dslr/images/mobile_phone/frame_0013.jpg 14
269 | office31/dslr/images/mobile_phone/frame_0014.jpg 14
270 | office31/dslr/images/mobile_phone/frame_0015.jpg 14
271 | office31/dslr/images/mobile_phone/frame_0016.jpg 14
272 | office31/dslr/images/mobile_phone/frame_0017.jpg 14
273 | office31/dslr/images/mobile_phone/frame_0018.jpg 14
274 | office31/dslr/images/mobile_phone/frame_0019.jpg 14
275 | office31/dslr/images/mobile_phone/frame_0020.jpg 14
276 | office31/dslr/images/mobile_phone/frame_0021.jpg 14
277 | office31/dslr/images/mobile_phone/frame_0022.jpg 14
278 | office31/dslr/images/mobile_phone/frame_0023.jpg 14
279 | office31/dslr/images/mobile_phone/frame_0024.jpg 14
280 | office31/dslr/images/mobile_phone/frame_0025.jpg 14
281 | office31/dslr/images/mobile_phone/frame_0026.jpg 14
282 | office31/dslr/images/mobile_phone/frame_0027.jpg 14
283 | office31/dslr/images/mobile_phone/frame_0028.jpg 14
284 | office31/dslr/images/mobile_phone/frame_0029.jpg 14
285 | office31/dslr/images/mobile_phone/frame_0030.jpg 14
286 | office31/dslr/images/mobile_phone/frame_0031.jpg 14
287 | office31/dslr/images/paper_notebook/frame_0001.jpg 18
288 | office31/dslr/images/paper_notebook/frame_0002.jpg 18
289 | office31/dslr/images/paper_notebook/frame_0003.jpg 18
290 | office31/dslr/images/paper_notebook/frame_0004.jpg 18
291 | office31/dslr/images/paper_notebook/frame_0005.jpg 18
292 | office31/dslr/images/paper_notebook/frame_0006.jpg 18
293 | office31/dslr/images/paper_notebook/frame_0007.jpg 18
294 | office31/dslr/images/paper_notebook/frame_0008.jpg 18
295 | office31/dslr/images/paper_notebook/frame_0009.jpg 18
296 | office31/dslr/images/paper_notebook/frame_0010.jpg 18
297 | office31/dslr/images/ruler/frame_0001.jpg 25
298 | office31/dslr/images/ruler/frame_0002.jpg 25
299 | office31/dslr/images/ruler/frame_0003.jpg 25
300 | office31/dslr/images/ruler/frame_0004.jpg 25
301 | office31/dslr/images/ruler/frame_0005.jpg 25
302 | office31/dslr/images/ruler/frame_0006.jpg 25
303 | office31/dslr/images/ruler/frame_0007.jpg 25
304 | office31/dslr/images/letter_tray/frame_0001.jpg 13
305 | office31/dslr/images/letter_tray/frame_0002.jpg 13
306 | office31/dslr/images/letter_tray/frame_0003.jpg 13
307 | office31/dslr/images/letter_tray/frame_0004.jpg 13
308 | office31/dslr/images/letter_tray/frame_0005.jpg 13
309 | office31/dslr/images/letter_tray/frame_0006.jpg 13
310 | office31/dslr/images/letter_tray/frame_0007.jpg 13
311 | office31/dslr/images/letter_tray/frame_0008.jpg 13
312 | office31/dslr/images/letter_tray/frame_0009.jpg 13
313 | office31/dslr/images/letter_tray/frame_0010.jpg 13
314 | office31/dslr/images/letter_tray/frame_0011.jpg 13
315 | office31/dslr/images/letter_tray/frame_0012.jpg 13
316 | office31/dslr/images/letter_tray/frame_0013.jpg 13
317 | office31/dslr/images/letter_tray/frame_0014.jpg 13
318 | office31/dslr/images/letter_tray/frame_0015.jpg 13
319 | office31/dslr/images/letter_tray/frame_0016.jpg 13
320 | office31/dslr/images/file_cabinet/frame_0001.jpg 9
321 | office31/dslr/images/file_cabinet/frame_0002.jpg 9
322 | office31/dslr/images/file_cabinet/frame_0003.jpg 9
323 | office31/dslr/images/file_cabinet/frame_0004.jpg 9
324 | office31/dslr/images/file_cabinet/frame_0005.jpg 9
325 | office31/dslr/images/file_cabinet/frame_0006.jpg 9
326 | office31/dslr/images/file_cabinet/frame_0007.jpg 9
327 | office31/dslr/images/file_cabinet/frame_0008.jpg 9
328 | office31/dslr/images/file_cabinet/frame_0009.jpg 9
329 | office31/dslr/images/file_cabinet/frame_0010.jpg 9
330 | office31/dslr/images/file_cabinet/frame_0011.jpg 9
331 | office31/dslr/images/file_cabinet/frame_0012.jpg 9
332 | office31/dslr/images/file_cabinet/frame_0013.jpg 9
333 | office31/dslr/images/file_cabinet/frame_0014.jpg 9
334 | office31/dslr/images/file_cabinet/frame_0015.jpg 9
335 | office31/dslr/images/phone/frame_0001.jpg 20
336 | office31/dslr/images/phone/frame_0002.jpg 20
337 | office31/dslr/images/phone/frame_0003.jpg 20
338 | office31/dslr/images/phone/frame_0004.jpg 20
339 | office31/dslr/images/phone/frame_0005.jpg 20
340 | office31/dslr/images/phone/frame_0006.jpg 20
341 | office31/dslr/images/phone/frame_0007.jpg 20
342 | office31/dslr/images/phone/frame_0008.jpg 20
343 | office31/dslr/images/phone/frame_0009.jpg 20
344 | office31/dslr/images/phone/frame_0010.jpg 20
345 | office31/dslr/images/phone/frame_0011.jpg 20
346 | office31/dslr/images/phone/frame_0012.jpg 20
347 | office31/dslr/images/phone/frame_0013.jpg 20
348 | office31/dslr/images/bookcase/frame_0001.jpg 3
349 | office31/dslr/images/bookcase/frame_0002.jpg 3
350 | office31/dslr/images/bookcase/frame_0003.jpg 3
351 | office31/dslr/images/bookcase/frame_0004.jpg 3
352 | office31/dslr/images/bookcase/frame_0005.jpg 3
353 | office31/dslr/images/bookcase/frame_0006.jpg 3
354 | office31/dslr/images/bookcase/frame_0007.jpg 3
355 | office31/dslr/images/bookcase/frame_0008.jpg 3
356 | office31/dslr/images/bookcase/frame_0009.jpg 3
357 | office31/dslr/images/bookcase/frame_0010.jpg 3
358 | office31/dslr/images/bookcase/frame_0011.jpg 3
359 | office31/dslr/images/bookcase/frame_0012.jpg 3
360 | office31/dslr/images/projector/frame_0001.jpg 22
361 | office31/dslr/images/projector/frame_0002.jpg 22
362 | office31/dslr/images/projector/frame_0003.jpg 22
363 | office31/dslr/images/projector/frame_0004.jpg 22
364 | office31/dslr/images/projector/frame_0005.jpg 22
365 | office31/dslr/images/projector/frame_0006.jpg 22
366 | office31/dslr/images/projector/frame_0007.jpg 22
367 | office31/dslr/images/projector/frame_0008.jpg 22
368 | office31/dslr/images/projector/frame_0009.jpg 22
369 | office31/dslr/images/projector/frame_0010.jpg 22
370 | office31/dslr/images/projector/frame_0011.jpg 22
371 | office31/dslr/images/projector/frame_0012.jpg 22
372 | office31/dslr/images/projector/frame_0013.jpg 22
373 | office31/dslr/images/projector/frame_0014.jpg 22
374 | office31/dslr/images/projector/frame_0015.jpg 22
375 | office31/dslr/images/projector/frame_0016.jpg 22
376 | office31/dslr/images/projector/frame_0017.jpg 22
377 | office31/dslr/images/projector/frame_0018.jpg 22
378 | office31/dslr/images/projector/frame_0019.jpg 22
379 | office31/dslr/images/projector/frame_0020.jpg 22
380 | office31/dslr/images/projector/frame_0021.jpg 22
381 | office31/dslr/images/projector/frame_0022.jpg 22
382 | office31/dslr/images/projector/frame_0023.jpg 22
383 | office31/dslr/images/stapler/frame_0001.jpg 28
384 | office31/dslr/images/stapler/frame_0002.jpg 28
385 | office31/dslr/images/stapler/frame_0003.jpg 28
386 | office31/dslr/images/stapler/frame_0004.jpg 28
387 | office31/dslr/images/stapler/frame_0005.jpg 28
388 | office31/dslr/images/stapler/frame_0006.jpg 28
389 | office31/dslr/images/stapler/frame_0007.jpg 28
390 | office31/dslr/images/stapler/frame_0008.jpg 28
391 | office31/dslr/images/stapler/frame_0009.jpg 28
392 | office31/dslr/images/stapler/frame_0010.jpg 28
393 | office31/dslr/images/stapler/frame_0011.jpg 28
394 | office31/dslr/images/stapler/frame_0012.jpg 28
395 | office31/dslr/images/stapler/frame_0013.jpg 28
396 | office31/dslr/images/stapler/frame_0014.jpg 28
397 | office31/dslr/images/stapler/frame_0015.jpg 28
398 | office31/dslr/images/stapler/frame_0016.jpg 28
399 | office31/dslr/images/stapler/frame_0017.jpg 28
400 | office31/dslr/images/stapler/frame_0018.jpg 28
401 | office31/dslr/images/stapler/frame_0019.jpg 28
402 | office31/dslr/images/stapler/frame_0020.jpg 28
403 | office31/dslr/images/stapler/frame_0021.jpg 28
404 | office31/dslr/images/trash_can/frame_0001.jpg 30
405 | office31/dslr/images/trash_can/frame_0002.jpg 30
406 | office31/dslr/images/trash_can/frame_0003.jpg 30
407 | office31/dslr/images/trash_can/frame_0004.jpg 30
408 | office31/dslr/images/trash_can/frame_0005.jpg 30
409 | office31/dslr/images/trash_can/frame_0006.jpg 30
410 | office31/dslr/images/trash_can/frame_0007.jpg 30
411 | office31/dslr/images/trash_can/frame_0008.jpg 30
412 | office31/dslr/images/trash_can/frame_0009.jpg 30
413 | office31/dslr/images/trash_can/frame_0010.jpg 30
414 | office31/dslr/images/trash_can/frame_0011.jpg 30
415 | office31/dslr/images/trash_can/frame_0012.jpg 30
416 | office31/dslr/images/trash_can/frame_0013.jpg 30
417 | office31/dslr/images/trash_can/frame_0014.jpg 30
418 | office31/dslr/images/trash_can/frame_0015.jpg 30
419 | office31/dslr/images/bike_helmet/frame_0001.jpg 2
420 | office31/dslr/images/bike_helmet/frame_0002.jpg 2
421 | office31/dslr/images/bike_helmet/frame_0003.jpg 2
422 | office31/dslr/images/bike_helmet/frame_0004.jpg 2
423 | office31/dslr/images/bike_helmet/frame_0005.jpg 2
424 | office31/dslr/images/bike_helmet/frame_0006.jpg 2
425 | office31/dslr/images/bike_helmet/frame_0007.jpg 2
426 | office31/dslr/images/bike_helmet/frame_0008.jpg 2
427 | office31/dslr/images/bike_helmet/frame_0009.jpg 2
428 | office31/dslr/images/bike_helmet/frame_0010.jpg 2
429 | office31/dslr/images/bike_helmet/frame_0011.jpg 2
430 | office31/dslr/images/bike_helmet/frame_0012.jpg 2
431 | office31/dslr/images/bike_helmet/frame_0013.jpg 2
432 | office31/dslr/images/bike_helmet/frame_0014.jpg 2
433 | office31/dslr/images/bike_helmet/frame_0015.jpg 2
434 | office31/dslr/images/bike_helmet/frame_0016.jpg 2
435 | office31/dslr/images/bike_helmet/frame_0017.jpg 2
436 | office31/dslr/images/bike_helmet/frame_0018.jpg 2
437 | office31/dslr/images/bike_helmet/frame_0019.jpg 2
438 | office31/dslr/images/bike_helmet/frame_0020.jpg 2
439 | office31/dslr/images/bike_helmet/frame_0021.jpg 2
440 | office31/dslr/images/bike_helmet/frame_0022.jpg 2
441 | office31/dslr/images/bike_helmet/frame_0023.jpg 2
442 | office31/dslr/images/bike_helmet/frame_0024.jpg 2
443 | office31/dslr/images/headphones/frame_0001.jpg 10
444 | office31/dslr/images/headphones/frame_0002.jpg 10
445 | office31/dslr/images/headphones/frame_0003.jpg 10
446 | office31/dslr/images/headphones/frame_0004.jpg 10
447 | office31/dslr/images/headphones/frame_0005.jpg 10
448 | office31/dslr/images/headphones/frame_0006.jpg 10
449 | office31/dslr/images/headphones/frame_0007.jpg 10
450 | office31/dslr/images/headphones/frame_0008.jpg 10
451 | office31/dslr/images/headphones/frame_0009.jpg 10
452 | office31/dslr/images/headphones/frame_0010.jpg 10
453 | office31/dslr/images/headphones/frame_0011.jpg 10
454 | office31/dslr/images/headphones/frame_0012.jpg 10
455 | office31/dslr/images/headphones/frame_0013.jpg 10
456 | office31/dslr/images/desk_lamp/frame_0001.jpg 7
457 | office31/dslr/images/desk_lamp/frame_0002.jpg 7
458 | office31/dslr/images/desk_lamp/frame_0003.jpg 7
459 | office31/dslr/images/desk_lamp/frame_0004.jpg 7
460 | office31/dslr/images/desk_lamp/frame_0005.jpg 7
461 | office31/dslr/images/desk_lamp/frame_0006.jpg 7
462 | office31/dslr/images/desk_lamp/frame_0007.jpg 7
463 | office31/dslr/images/desk_lamp/frame_0008.jpg 7
464 | office31/dslr/images/desk_lamp/frame_0009.jpg 7
465 | office31/dslr/images/desk_lamp/frame_0010.jpg 7
466 | office31/dslr/images/desk_lamp/frame_0011.jpg 7
467 | office31/dslr/images/desk_lamp/frame_0012.jpg 7
468 | office31/dslr/images/desk_lamp/frame_0013.jpg 7
469 | office31/dslr/images/desk_lamp/frame_0014.jpg 7
470 | office31/dslr/images/desk_chair/frame_0001.jpg 6
471 | office31/dslr/images/desk_chair/frame_0002.jpg 6
472 | office31/dslr/images/desk_chair/frame_0003.jpg 6
473 | office31/dslr/images/desk_chair/frame_0004.jpg 6
474 | office31/dslr/images/desk_chair/frame_0005.jpg 6
475 | office31/dslr/images/desk_chair/frame_0006.jpg 6
476 | office31/dslr/images/desk_chair/frame_0007.jpg 6
477 | office31/dslr/images/desk_chair/frame_0008.jpg 6
478 | office31/dslr/images/desk_chair/frame_0009.jpg 6
479 | office31/dslr/images/desk_chair/frame_0010.jpg 6
480 | office31/dslr/images/desk_chair/frame_0011.jpg 6
481 | office31/dslr/images/desk_chair/frame_0012.jpg 6
482 | office31/dslr/images/desk_chair/frame_0013.jpg 6
483 | office31/dslr/images/bottle/frame_0001.jpg 4
484 | office31/dslr/images/bottle/frame_0002.jpg 4
485 | office31/dslr/images/bottle/frame_0003.jpg 4
486 | office31/dslr/images/bottle/frame_0004.jpg 4
487 | office31/dslr/images/bottle/frame_0005.jpg 4
488 | office31/dslr/images/bottle/frame_0006.jpg 4
489 | office31/dslr/images/bottle/frame_0007.jpg 4
490 | office31/dslr/images/bottle/frame_0008.jpg 4
491 | office31/dslr/images/bottle/frame_0009.jpg 4
492 | office31/dslr/images/bottle/frame_0010.jpg 4
493 | office31/dslr/images/bottle/frame_0011.jpg 4
494 | office31/dslr/images/bottle/frame_0012.jpg 4
495 | office31/dslr/images/bottle/frame_0013.jpg 4
496 | office31/dslr/images/bottle/frame_0014.jpg 4
497 | office31/dslr/images/bottle/frame_0015.jpg 4
498 | office31/dslr/images/bottle/frame_0016.jpg 4
499 |
--------------------------------------------------------------------------------
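Note on the format: each non-empty line of these image-list files pairs a dataset-relative image path with an integer class index (0–30 for the 31 Office-31 categories), separated by a single space; the project presumably consumes them through `dataset/image_list.py`. The snippet below is only a minimal, independent sketch of how such a list could be parsed — the `data_root` argument and function name are hypothetical, not the repository's actual loader API.

```python
# Minimal sketch (not the repository's dataset/image_list.py): parse an
# image list where each non-empty line is "<relative_path> <class_index>".
from pathlib import Path
from typing import List, Tuple

def read_image_list(list_file: str, data_root: str = ".") -> List[Tuple[str, int]]:
    samples = []
    for line in Path(list_file).read_text().splitlines():
        line = line.strip()
        if not line:
            continue  # skip trailing blank lines such as the one ending dslr.txt
        rel_path, label = line.rsplit(" ", 1)  # label is the last whitespace-separated token
        samples.append((str(Path(data_root) / rel_path), int(label)))
    return samples

# Hypothetical usage, assuming the images were extracted under ./datasets:
# samples = read_image_list("data/image_list/office31/dslr.txt", data_root="./datasets")
# print(len(samples), samples[0])
```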
/data/image_list/office31/webcam.txt:
--------------------------------------------------------------------------------
1 | office31//webcam/images/calculator/frame_0001.jpg 5
2 | office31//webcam/images/calculator/frame_0002.jpg 5
3 | office31//webcam/images/calculator/frame_0003.jpg 5
4 | office31//webcam/images/calculator/frame_0004.jpg 5
5 | office31//webcam/images/calculator/frame_0005.jpg 5
6 | office31//webcam/images/calculator/frame_0006.jpg 5
7 | office31//webcam/images/calculator/frame_0007.jpg 5
8 | office31//webcam/images/calculator/frame_0008.jpg 5
9 | office31//webcam/images/calculator/frame_0009.jpg 5
10 | office31//webcam/images/calculator/frame_0010.jpg 5
11 | office31//webcam/images/calculator/frame_0011.jpg 5
12 | office31//webcam/images/calculator/frame_0012.jpg 5
13 | office31//webcam/images/calculator/frame_0013.jpg 5
14 | office31//webcam/images/calculator/frame_0014.jpg 5
15 | office31//webcam/images/calculator/frame_0015.jpg 5
16 | office31//webcam/images/calculator/frame_0016.jpg 5
17 | office31//webcam/images/calculator/frame_0017.jpg 5
18 | office31//webcam/images/calculator/frame_0018.jpg 5
19 | office31//webcam/images/calculator/frame_0019.jpg 5
20 | office31//webcam/images/calculator/frame_0020.jpg 5
21 | office31//webcam/images/calculator/frame_0021.jpg 5
22 | office31//webcam/images/calculator/frame_0022.jpg 5
23 | office31//webcam/images/calculator/frame_0023.jpg 5
24 | office31//webcam/images/calculator/frame_0024.jpg 5
25 | office31//webcam/images/calculator/frame_0025.jpg 5
26 | office31//webcam/images/calculator/frame_0026.jpg 5
27 | office31//webcam/images/calculator/frame_0027.jpg 5
28 | office31//webcam/images/calculator/frame_0028.jpg 5
29 | office31//webcam/images/calculator/frame_0029.jpg 5
30 | office31//webcam/images/calculator/frame_0030.jpg 5
31 | office31//webcam/images/calculator/frame_0031.jpg 5
32 | office31//webcam/images/ring_binder/frame_0001.jpg 24
33 | office31//webcam/images/ring_binder/frame_0002.jpg 24
34 | office31//webcam/images/ring_binder/frame_0003.jpg 24
35 | office31//webcam/images/ring_binder/frame_0004.jpg 24
36 | office31//webcam/images/ring_binder/frame_0005.jpg 24
37 | office31//webcam/images/ring_binder/frame_0006.jpg 24
38 | office31//webcam/images/ring_binder/frame_0007.jpg 24
39 | office31//webcam/images/ring_binder/frame_0008.jpg 24
40 | office31//webcam/images/ring_binder/frame_0009.jpg 24
41 | office31//webcam/images/ring_binder/frame_0010.jpg 24
42 | office31//webcam/images/ring_binder/frame_0011.jpg 24
43 | office31//webcam/images/ring_binder/frame_0012.jpg 24
44 | office31//webcam/images/ring_binder/frame_0013.jpg 24
45 | office31//webcam/images/ring_binder/frame_0014.jpg 24
46 | office31//webcam/images/ring_binder/frame_0015.jpg 24
47 | office31//webcam/images/ring_binder/frame_0016.jpg 24
48 | office31//webcam/images/ring_binder/frame_0017.jpg 24
49 | office31//webcam/images/ring_binder/frame_0018.jpg 24
50 | office31//webcam/images/ring_binder/frame_0019.jpg 24
51 | office31//webcam/images/ring_binder/frame_0020.jpg 24
52 | office31//webcam/images/ring_binder/frame_0021.jpg 24
53 | office31//webcam/images/ring_binder/frame_0022.jpg 24
54 | office31//webcam/images/ring_binder/frame_0023.jpg 24
55 | office31//webcam/images/ring_binder/frame_0024.jpg 24
56 | office31//webcam/images/ring_binder/frame_0025.jpg 24
57 | office31//webcam/images/ring_binder/frame_0026.jpg 24
58 | office31//webcam/images/ring_binder/frame_0027.jpg 24
59 | office31//webcam/images/ring_binder/frame_0028.jpg 24
60 | office31//webcam/images/ring_binder/frame_0029.jpg 24
61 | office31//webcam/images/ring_binder/frame_0030.jpg 24
62 | office31//webcam/images/ring_binder/frame_0031.jpg 24
63 | office31//webcam/images/ring_binder/frame_0032.jpg 24
64 | office31//webcam/images/ring_binder/frame_0033.jpg 24
65 | office31//webcam/images/ring_binder/frame_0034.jpg 24
66 | office31//webcam/images/ring_binder/frame_0035.jpg 24
67 | office31//webcam/images/ring_binder/frame_0036.jpg 24
68 | office31//webcam/images/ring_binder/frame_0037.jpg 24
69 | office31//webcam/images/ring_binder/frame_0038.jpg 24
70 | office31//webcam/images/ring_binder/frame_0039.jpg 24
71 | office31//webcam/images/ring_binder/frame_0040.jpg 24
72 | office31//webcam/images/printer/frame_0001.jpg 21
73 | office31//webcam/images/printer/frame_0002.jpg 21
74 | office31//webcam/images/printer/frame_0003.jpg 21
75 | office31//webcam/images/printer/frame_0004.jpg 21
76 | office31//webcam/images/printer/frame_0005.jpg 21
77 | office31//webcam/images/printer/frame_0006.jpg 21
78 | office31//webcam/images/printer/frame_0007.jpg 21
79 | office31//webcam/images/printer/frame_0008.jpg 21
80 | office31//webcam/images/printer/frame_0009.jpg 21
81 | office31//webcam/images/printer/frame_0010.jpg 21
82 | office31//webcam/images/printer/frame_0011.jpg 21
83 | office31//webcam/images/printer/frame_0012.jpg 21
84 | office31//webcam/images/printer/frame_0013.jpg 21
85 | office31//webcam/images/printer/frame_0014.jpg 21
86 | office31//webcam/images/printer/frame_0015.jpg 21
87 | office31//webcam/images/printer/frame_0016.jpg 21
88 | office31//webcam/images/printer/frame_0017.jpg 21
89 | office31//webcam/images/printer/frame_0018.jpg 21
90 | office31//webcam/images/printer/frame_0019.jpg 21
91 | office31//webcam/images/printer/frame_0020.jpg 21
92 | office31//webcam/images/keyboard/frame_0001.jpg 11
93 | office31//webcam/images/keyboard/frame_0002.jpg 11
94 | office31//webcam/images/keyboard/frame_0003.jpg 11
95 | office31//webcam/images/keyboard/frame_0004.jpg 11
96 | office31//webcam/images/keyboard/frame_0005.jpg 11
97 | office31//webcam/images/keyboard/frame_0006.jpg 11
98 | office31//webcam/images/keyboard/frame_0007.jpg 11
99 | office31//webcam/images/keyboard/frame_0008.jpg 11
100 | office31//webcam/images/keyboard/frame_0009.jpg 11
101 | office31//webcam/images/keyboard/frame_0010.jpg 11
102 | office31//webcam/images/keyboard/frame_0011.jpg 11
103 | office31//webcam/images/keyboard/frame_0012.jpg 11
104 | office31//webcam/images/keyboard/frame_0013.jpg 11
105 | office31//webcam/images/keyboard/frame_0014.jpg 11
106 | office31//webcam/images/keyboard/frame_0015.jpg 11
107 | office31//webcam/images/keyboard/frame_0016.jpg 11
108 | office31//webcam/images/keyboard/frame_0017.jpg 11
109 | office31//webcam/images/keyboard/frame_0018.jpg 11
110 | office31//webcam/images/keyboard/frame_0019.jpg 11
111 | office31//webcam/images/keyboard/frame_0020.jpg 11
112 | office31//webcam/images/keyboard/frame_0021.jpg 11
113 | office31//webcam/images/keyboard/frame_0022.jpg 11
114 | office31//webcam/images/keyboard/frame_0023.jpg 11
115 | office31//webcam/images/keyboard/frame_0024.jpg 11
116 | office31//webcam/images/keyboard/frame_0025.jpg 11
117 | office31//webcam/images/keyboard/frame_0026.jpg 11
118 | office31//webcam/images/keyboard/frame_0027.jpg 11
119 | office31//webcam/images/scissors/frame_0001.jpg 26
120 | office31//webcam/images/scissors/frame_0002.jpg 26
121 | office31//webcam/images/scissors/frame_0003.jpg 26
122 | office31//webcam/images/scissors/frame_0004.jpg 26
123 | office31//webcam/images/scissors/frame_0005.jpg 26
124 | office31//webcam/images/scissors/frame_0006.jpg 26
125 | office31//webcam/images/scissors/frame_0007.jpg 26
126 | office31//webcam/images/scissors/frame_0008.jpg 26
127 | office31//webcam/images/scissors/frame_0009.jpg 26
128 | office31//webcam/images/scissors/frame_0010.jpg 26
129 | office31//webcam/images/scissors/frame_0011.jpg 26
130 | office31//webcam/images/scissors/frame_0012.jpg 26
131 | office31//webcam/images/scissors/frame_0013.jpg 26
132 | office31//webcam/images/scissors/frame_0014.jpg 26
133 | office31//webcam/images/scissors/frame_0015.jpg 26
134 | office31//webcam/images/scissors/frame_0016.jpg 26
135 | office31//webcam/images/scissors/frame_0017.jpg 26
136 | office31//webcam/images/scissors/frame_0018.jpg 26
137 | office31//webcam/images/scissors/frame_0019.jpg 26
138 | office31//webcam/images/scissors/frame_0020.jpg 26
139 | office31//webcam/images/scissors/frame_0021.jpg 26
140 | office31//webcam/images/scissors/frame_0022.jpg 26
141 | office31//webcam/images/scissors/frame_0023.jpg 26
142 | office31//webcam/images/scissors/frame_0024.jpg 26
143 | office31//webcam/images/scissors/frame_0025.jpg 26
144 | office31//webcam/images/laptop_computer/frame_0001.jpg 12
145 | office31//webcam/images/laptop_computer/frame_0002.jpg 12
146 | office31//webcam/images/laptop_computer/frame_0003.jpg 12
147 | office31//webcam/images/laptop_computer/frame_0004.jpg 12
148 | office31//webcam/images/laptop_computer/frame_0005.jpg 12
149 | office31//webcam/images/laptop_computer/frame_0006.jpg 12
150 | office31//webcam/images/laptop_computer/frame_0007.jpg 12
151 | office31//webcam/images/laptop_computer/frame_0008.jpg 12
152 | office31//webcam/images/laptop_computer/frame_0009.jpg 12
153 | office31//webcam/images/laptop_computer/frame_0010.jpg 12
154 | office31//webcam/images/laptop_computer/frame_0011.jpg 12
155 | office31//webcam/images/laptop_computer/frame_0012.jpg 12
156 | office31//webcam/images/laptop_computer/frame_0013.jpg 12
157 | office31//webcam/images/laptop_computer/frame_0014.jpg 12
158 | office31//webcam/images/laptop_computer/frame_0015.jpg 12
159 | office31//webcam/images/laptop_computer/frame_0016.jpg 12
160 | office31//webcam/images/laptop_computer/frame_0017.jpg 12
161 | office31//webcam/images/laptop_computer/frame_0018.jpg 12
162 | office31//webcam/images/laptop_computer/frame_0019.jpg 12
163 | office31//webcam/images/laptop_computer/frame_0020.jpg 12
164 | office31//webcam/images/laptop_computer/frame_0021.jpg 12
165 | office31//webcam/images/laptop_computer/frame_0022.jpg 12
166 | office31//webcam/images/laptop_computer/frame_0023.jpg 12
167 | office31//webcam/images/laptop_computer/frame_0024.jpg 12
168 | office31//webcam/images/laptop_computer/frame_0025.jpg 12
169 | office31//webcam/images/laptop_computer/frame_0026.jpg 12
170 | office31//webcam/images/laptop_computer/frame_0027.jpg 12
171 | office31//webcam/images/laptop_computer/frame_0028.jpg 12
172 | office31//webcam/images/laptop_computer/frame_0029.jpg 12
173 | office31//webcam/images/laptop_computer/frame_0030.jpg 12
174 | office31//webcam/images/mouse/frame_0001.jpg 16
175 | office31//webcam/images/mouse/frame_0002.jpg 16
176 | office31//webcam/images/mouse/frame_0003.jpg 16
177 | office31//webcam/images/mouse/frame_0004.jpg 16
178 | office31//webcam/images/mouse/frame_0005.jpg 16
179 | office31//webcam/images/mouse/frame_0006.jpg 16
180 | office31//webcam/images/mouse/frame_0007.jpg 16
181 | office31//webcam/images/mouse/frame_0008.jpg 16
182 | office31//webcam/images/mouse/frame_0009.jpg 16
183 | office31//webcam/images/mouse/frame_0010.jpg 16
184 | office31//webcam/images/mouse/frame_0011.jpg 16
185 | office31//webcam/images/mouse/frame_0012.jpg 16
186 | office31//webcam/images/mouse/frame_0013.jpg 16
187 | office31//webcam/images/mouse/frame_0014.jpg 16
188 | office31//webcam/images/mouse/frame_0015.jpg 16
189 | office31//webcam/images/mouse/frame_0016.jpg 16
190 | office31//webcam/images/mouse/frame_0017.jpg 16
191 | office31//webcam/images/mouse/frame_0018.jpg 16
192 | office31//webcam/images/mouse/frame_0019.jpg 16
193 | office31//webcam/images/mouse/frame_0020.jpg 16
194 | office31//webcam/images/mouse/frame_0021.jpg 16
195 | office31//webcam/images/mouse/frame_0022.jpg 16
196 | office31//webcam/images/mouse/frame_0023.jpg 16
197 | office31//webcam/images/mouse/frame_0024.jpg 16
198 | office31//webcam/images/mouse/frame_0025.jpg 16
199 | office31//webcam/images/mouse/frame_0026.jpg 16
200 | office31//webcam/images/mouse/frame_0027.jpg 16
201 | office31//webcam/images/mouse/frame_0028.jpg 16
202 | office31//webcam/images/mouse/frame_0029.jpg 16
203 | office31//webcam/images/mouse/frame_0030.jpg 16
204 | office31//webcam/images/monitor/frame_0001.jpg 15
205 | office31//webcam/images/monitor/frame_0002.jpg 15
206 | office31//webcam/images/monitor/frame_0003.jpg 15
207 | office31//webcam/images/monitor/frame_0004.jpg 15
208 | office31//webcam/images/monitor/frame_0005.jpg 15
209 | office31//webcam/images/monitor/frame_0006.jpg 15
210 | office31//webcam/images/monitor/frame_0007.jpg 15
211 | office31//webcam/images/monitor/frame_0008.jpg 15
212 | office31//webcam/images/monitor/frame_0009.jpg 15
213 | office31//webcam/images/monitor/frame_0010.jpg 15
214 | office31//webcam/images/monitor/frame_0011.jpg 15
215 | office31//webcam/images/monitor/frame_0012.jpg 15
216 | office31//webcam/images/monitor/frame_0013.jpg 15
217 | office31//webcam/images/monitor/frame_0014.jpg 15
218 | office31//webcam/images/monitor/frame_0015.jpg 15
219 | office31//webcam/images/monitor/frame_0016.jpg 15
220 | office31//webcam/images/monitor/frame_0017.jpg 15
221 | office31//webcam/images/monitor/frame_0018.jpg 15
222 | office31//webcam/images/monitor/frame_0019.jpg 15
223 | office31//webcam/images/monitor/frame_0020.jpg 15
224 | office31//webcam/images/monitor/frame_0021.jpg 15
225 | office31//webcam/images/monitor/frame_0022.jpg 15
226 | office31//webcam/images/monitor/frame_0023.jpg 15
227 | office31//webcam/images/monitor/frame_0024.jpg 15
228 | office31//webcam/images/monitor/frame_0025.jpg 15
229 | office31//webcam/images/monitor/frame_0026.jpg 15
230 | office31//webcam/images/monitor/frame_0027.jpg 15
231 | office31//webcam/images/monitor/frame_0028.jpg 15
232 | office31//webcam/images/monitor/frame_0029.jpg 15
233 | office31//webcam/images/monitor/frame_0030.jpg 15
234 | office31//webcam/images/monitor/frame_0031.jpg 15
235 | office31//webcam/images/monitor/frame_0032.jpg 15
236 | office31//webcam/images/monitor/frame_0033.jpg 15
237 | office31//webcam/images/monitor/frame_0034.jpg 15
238 | office31//webcam/images/monitor/frame_0035.jpg 15
239 | office31//webcam/images/monitor/frame_0036.jpg 15
240 | office31//webcam/images/monitor/frame_0037.jpg 15
241 | office31//webcam/images/monitor/frame_0038.jpg 15
242 | office31//webcam/images/monitor/frame_0039.jpg 15
243 | office31//webcam/images/monitor/frame_0040.jpg 15
244 | office31//webcam/images/monitor/frame_0041.jpg 15
245 | office31//webcam/images/monitor/frame_0042.jpg 15
246 | office31//webcam/images/monitor/frame_0043.jpg 15
247 | office31//webcam/images/mug/frame_0001.jpg 17
248 | office31//webcam/images/mug/frame_0002.jpg 17
249 | office31//webcam/images/mug/frame_0003.jpg 17
250 | office31//webcam/images/mug/frame_0004.jpg 17
251 | office31//webcam/images/mug/frame_0005.jpg 17
252 | office31//webcam/images/mug/frame_0006.jpg 17
253 | office31//webcam/images/mug/frame_0007.jpg 17
254 | office31//webcam/images/mug/frame_0008.jpg 17
255 | office31//webcam/images/mug/frame_0009.jpg 17
256 | office31//webcam/images/mug/frame_0010.jpg 17
257 | office31//webcam/images/mug/frame_0011.jpg 17
258 | office31//webcam/images/mug/frame_0012.jpg 17
259 | office31//webcam/images/mug/frame_0013.jpg 17
260 | office31//webcam/images/mug/frame_0014.jpg 17
261 | office31//webcam/images/mug/frame_0015.jpg 17
262 | office31//webcam/images/mug/frame_0016.jpg 17
263 | office31//webcam/images/mug/frame_0017.jpg 17
264 | office31//webcam/images/mug/frame_0018.jpg 17
265 | office31//webcam/images/mug/frame_0019.jpg 17
266 | office31//webcam/images/mug/frame_0020.jpg 17
267 | office31//webcam/images/mug/frame_0021.jpg 17
268 | office31//webcam/images/mug/frame_0022.jpg 17
269 | office31//webcam/images/mug/frame_0023.jpg 17
270 | office31//webcam/images/mug/frame_0024.jpg 17
271 | office31//webcam/images/mug/frame_0025.jpg 17
272 | office31//webcam/images/mug/frame_0026.jpg 17
273 | office31//webcam/images/mug/frame_0027.jpg 17
274 | office31//webcam/images/tape_dispenser/frame_0001.jpg 29
275 | office31//webcam/images/tape_dispenser/frame_0002.jpg 29
276 | office31//webcam/images/tape_dispenser/frame_0003.jpg 29
277 | office31//webcam/images/tape_dispenser/frame_0004.jpg 29
278 | office31//webcam/images/tape_dispenser/frame_0005.jpg 29
279 | office31//webcam/images/tape_dispenser/frame_0006.jpg 29
280 | office31//webcam/images/tape_dispenser/frame_0007.jpg 29
281 | office31//webcam/images/tape_dispenser/frame_0008.jpg 29
282 | office31//webcam/images/tape_dispenser/frame_0009.jpg 29
283 | office31//webcam/images/tape_dispenser/frame_0010.jpg 29
284 | office31//webcam/images/tape_dispenser/frame_0011.jpg 29
285 | office31//webcam/images/tape_dispenser/frame_0012.jpg 29
286 | office31//webcam/images/tape_dispenser/frame_0013.jpg 29
287 | office31//webcam/images/tape_dispenser/frame_0014.jpg 29
288 | office31//webcam/images/tape_dispenser/frame_0015.jpg 29
289 | office31//webcam/images/tape_dispenser/frame_0016.jpg 29
290 | office31//webcam/images/tape_dispenser/frame_0017.jpg 29
291 | office31//webcam/images/tape_dispenser/frame_0018.jpg 29
292 | office31//webcam/images/tape_dispenser/frame_0019.jpg 29
293 | office31//webcam/images/tape_dispenser/frame_0020.jpg 29
294 | office31//webcam/images/tape_dispenser/frame_0021.jpg 29
295 | office31//webcam/images/tape_dispenser/frame_0022.jpg 29
296 | office31//webcam/images/tape_dispenser/frame_0023.jpg 29
297 | office31//webcam/images/pen/frame_0001.jpg 19
298 | office31//webcam/images/pen/frame_0002.jpg 19
299 | office31//webcam/images/pen/frame_0003.jpg 19
300 | office31//webcam/images/pen/frame_0004.jpg 19
301 | office31//webcam/images/pen/frame_0005.jpg 19
302 | office31//webcam/images/pen/frame_0006.jpg 19
303 | office31//webcam/images/pen/frame_0007.jpg 19
304 | office31//webcam/images/pen/frame_0008.jpg 19
305 | office31//webcam/images/pen/frame_0009.jpg 19
306 | office31//webcam/images/pen/frame_0010.jpg 19
307 | office31//webcam/images/pen/frame_0011.jpg 19
308 | office31//webcam/images/pen/frame_0012.jpg 19
309 | office31//webcam/images/pen/frame_0013.jpg 19
310 | office31//webcam/images/pen/frame_0014.jpg 19
311 | office31//webcam/images/pen/frame_0015.jpg 19
312 | office31//webcam/images/pen/frame_0016.jpg 19
313 | office31//webcam/images/pen/frame_0017.jpg 19
314 | office31//webcam/images/pen/frame_0018.jpg 19
315 | office31//webcam/images/pen/frame_0019.jpg 19
316 | office31//webcam/images/pen/frame_0020.jpg 19
317 | office31//webcam/images/pen/frame_0021.jpg 19
318 | office31//webcam/images/pen/frame_0022.jpg 19
319 | office31//webcam/images/pen/frame_0023.jpg 19
320 | office31//webcam/images/pen/frame_0024.jpg 19
321 | office31//webcam/images/pen/frame_0025.jpg 19
322 | office31//webcam/images/pen/frame_0026.jpg 19
323 | office31//webcam/images/pen/frame_0027.jpg 19
324 | office31//webcam/images/pen/frame_0028.jpg 19
325 | office31//webcam/images/pen/frame_0029.jpg 19
326 | office31//webcam/images/pen/frame_0030.jpg 19
327 | office31//webcam/images/pen/frame_0031.jpg 19
328 | office31//webcam/images/pen/frame_0032.jpg 19
329 | office31//webcam/images/bike/frame_0001.jpg 1
330 | office31//webcam/images/bike/frame_0002.jpg 1
331 | office31//webcam/images/bike/frame_0003.jpg 1
332 | office31//webcam/images/bike/frame_0004.jpg 1
333 | office31//webcam/images/bike/frame_0005.jpg 1
334 | office31//webcam/images/bike/frame_0006.jpg 1
335 | office31//webcam/images/bike/frame_0007.jpg 1
336 | office31//webcam/images/bike/frame_0008.jpg 1
337 | office31//webcam/images/bike/frame_0009.jpg 1
338 | office31//webcam/images/bike/frame_0010.jpg 1
339 | office31//webcam/images/bike/frame_0011.jpg 1
340 | office31//webcam/images/bike/frame_0012.jpg 1
341 | office31//webcam/images/bike/frame_0013.jpg 1
342 | office31//webcam/images/bike/frame_0014.jpg 1
343 | office31//webcam/images/bike/frame_0015.jpg 1
344 | office31//webcam/images/bike/frame_0016.jpg 1
345 | office31//webcam/images/bike/frame_0017.jpg 1
346 | office31//webcam/images/bike/frame_0018.jpg 1
347 | office31//webcam/images/bike/frame_0019.jpg 1
348 | office31//webcam/images/bike/frame_0020.jpg 1
349 | office31//webcam/images/bike/frame_0021.jpg 1
350 | office31//webcam/images/punchers/frame_0001.jpg 23
351 | office31//webcam/images/punchers/frame_0002.jpg 23
352 | office31//webcam/images/punchers/frame_0003.jpg 23
353 | office31//webcam/images/punchers/frame_0004.jpg 23
354 | office31//webcam/images/punchers/frame_0005.jpg 23
355 | office31//webcam/images/punchers/frame_0006.jpg 23
356 | office31//webcam/images/punchers/frame_0007.jpg 23
357 | office31//webcam/images/punchers/frame_0008.jpg 23
358 | office31//webcam/images/punchers/frame_0009.jpg 23
359 | office31//webcam/images/punchers/frame_0010.jpg 23
360 | office31//webcam/images/punchers/frame_0011.jpg 23
361 | office31//webcam/images/punchers/frame_0012.jpg 23
362 | office31//webcam/images/punchers/frame_0013.jpg 23
363 | office31//webcam/images/punchers/frame_0014.jpg 23
364 | office31//webcam/images/punchers/frame_0015.jpg 23
365 | office31//webcam/images/punchers/frame_0016.jpg 23
366 | office31//webcam/images/punchers/frame_0017.jpg 23
367 | office31//webcam/images/punchers/frame_0018.jpg 23
368 | office31//webcam/images/punchers/frame_0019.jpg 23
369 | office31//webcam/images/punchers/frame_0020.jpg 23
370 | office31//webcam/images/punchers/frame_0021.jpg 23
371 | office31//webcam/images/punchers/frame_0022.jpg 23
372 | office31//webcam/images/punchers/frame_0023.jpg 23
373 | office31//webcam/images/punchers/frame_0024.jpg 23
374 | office31//webcam/images/punchers/frame_0025.jpg 23
375 | office31//webcam/images/punchers/frame_0026.jpg 23
376 | office31//webcam/images/punchers/frame_0027.jpg 23
377 | office31//webcam/images/back_pack/frame_0001.jpg 0
378 | office31//webcam/images/back_pack/frame_0002.jpg 0
379 | office31//webcam/images/back_pack/frame_0003.jpg 0
380 | office31//webcam/images/back_pack/frame_0004.jpg 0
381 | office31//webcam/images/back_pack/frame_0005.jpg 0
382 | office31//webcam/images/back_pack/frame_0006.jpg 0
383 | office31//webcam/images/back_pack/frame_0007.jpg 0
384 | office31//webcam/images/back_pack/frame_0008.jpg 0
385 | office31//webcam/images/back_pack/frame_0009.jpg 0
386 | office31//webcam/images/back_pack/frame_0010.jpg 0
387 | office31//webcam/images/back_pack/frame_0011.jpg 0
388 | office31//webcam/images/back_pack/frame_0012.jpg 0
389 | office31//webcam/images/back_pack/frame_0013.jpg 0
390 | office31//webcam/images/back_pack/frame_0014.jpg 0
391 | office31//webcam/images/back_pack/frame_0015.jpg 0
392 | office31//webcam/images/back_pack/frame_0016.jpg 0
393 | office31//webcam/images/back_pack/frame_0017.jpg 0
394 | office31//webcam/images/back_pack/frame_0018.jpg 0
395 | office31//webcam/images/back_pack/frame_0019.jpg 0
396 | office31//webcam/images/back_pack/frame_0020.jpg 0
397 | office31//webcam/images/back_pack/frame_0021.jpg 0
398 | office31//webcam/images/back_pack/frame_0022.jpg 0
399 | office31//webcam/images/back_pack/frame_0023.jpg 0
400 | office31//webcam/images/back_pack/frame_0024.jpg 0
401 | office31//webcam/images/back_pack/frame_0025.jpg 0
402 | office31//webcam/images/back_pack/frame_0026.jpg 0
403 | office31//webcam/images/back_pack/frame_0027.jpg 0
404 | office31//webcam/images/back_pack/frame_0028.jpg 0
405 | office31//webcam/images/back_pack/frame_0029.jpg 0
406 | office31//webcam/images/desktop_computer/frame_0001.jpg 8
407 | office31//webcam/images/desktop_computer/frame_0002.jpg 8
408 | office31//webcam/images/desktop_computer/frame_0003.jpg 8
409 | office31//webcam/images/desktop_computer/frame_0004.jpg 8
410 | office31//webcam/images/desktop_computer/frame_0005.jpg 8
411 | office31//webcam/images/desktop_computer/frame_0006.jpg 8
412 | office31//webcam/images/desktop_computer/frame_0007.jpg 8
413 | office31//webcam/images/desktop_computer/frame_0008.jpg 8
414 | office31//webcam/images/desktop_computer/frame_0009.jpg 8
415 | office31//webcam/images/desktop_computer/frame_0010.jpg 8
416 | office31//webcam/images/desktop_computer/frame_0011.jpg 8
417 | office31//webcam/images/desktop_computer/frame_0012.jpg 8
418 | office31//webcam/images/desktop_computer/frame_0013.jpg 8
419 | office31//webcam/images/desktop_computer/frame_0014.jpg 8
420 | office31//webcam/images/desktop_computer/frame_0015.jpg 8
421 | office31//webcam/images/desktop_computer/frame_0016.jpg 8
422 | office31//webcam/images/desktop_computer/frame_0017.jpg 8
423 | office31//webcam/images/desktop_computer/frame_0018.jpg 8
424 | office31//webcam/images/desktop_computer/frame_0019.jpg 8
425 | office31//webcam/images/desktop_computer/frame_0020.jpg 8
426 | office31//webcam/images/desktop_computer/frame_0021.jpg 8
427 | office31//webcam/images/speaker/frame_0001.jpg 27
428 | office31//webcam/images/speaker/frame_0002.jpg 27
429 | office31//webcam/images/speaker/frame_0003.jpg 27
430 | office31//webcam/images/speaker/frame_0004.jpg 27
431 | office31//webcam/images/speaker/frame_0005.jpg 27
432 | office31//webcam/images/speaker/frame_0006.jpg 27
433 | office31//webcam/images/speaker/frame_0007.jpg 27
434 | office31//webcam/images/speaker/frame_0008.jpg 27
435 | office31//webcam/images/speaker/frame_0009.jpg 27
436 | office31//webcam/images/speaker/frame_0010.jpg 27
437 | office31//webcam/images/speaker/frame_0011.jpg 27
438 | office31//webcam/images/speaker/frame_0012.jpg 27
439 | office31//webcam/images/speaker/frame_0013.jpg 27
440 | office31//webcam/images/speaker/frame_0014.jpg 27
441 | office31//webcam/images/speaker/frame_0015.jpg 27
442 | office31//webcam/images/speaker/frame_0016.jpg 27
443 | office31//webcam/images/speaker/frame_0017.jpg 27
444 | office31//webcam/images/speaker/frame_0018.jpg 27
445 | office31//webcam/images/speaker/frame_0019.jpg 27
446 | office31//webcam/images/speaker/frame_0020.jpg 27
447 | office31//webcam/images/speaker/frame_0021.jpg 27
448 | office31//webcam/images/speaker/frame_0022.jpg 27
449 | office31//webcam/images/speaker/frame_0023.jpg 27
450 | office31//webcam/images/speaker/frame_0024.jpg 27
451 | office31//webcam/images/speaker/frame_0025.jpg 27
452 | office31//webcam/images/speaker/frame_0026.jpg 27
453 | office31//webcam/images/speaker/frame_0027.jpg 27
454 | office31//webcam/images/speaker/frame_0028.jpg 27
455 | office31//webcam/images/speaker/frame_0029.jpg 27
456 | office31//webcam/images/speaker/frame_0030.jpg 27
457 | office31//webcam/images/mobile_phone/frame_0001.jpg 14
458 | office31//webcam/images/mobile_phone/frame_0002.jpg 14
459 | office31//webcam/images/mobile_phone/frame_0003.jpg 14
460 | office31//webcam/images/mobile_phone/frame_0004.jpg 14
461 | office31//webcam/images/mobile_phone/frame_0005.jpg 14
462 | office31//webcam/images/mobile_phone/frame_0006.jpg 14
463 | office31//webcam/images/mobile_phone/frame_0007.jpg 14
464 | office31//webcam/images/mobile_phone/frame_0008.jpg 14
465 | office31//webcam/images/mobile_phone/frame_0009.jpg 14
466 | office31//webcam/images/mobile_phone/frame_0010.jpg 14
467 | office31//webcam/images/mobile_phone/frame_0011.jpg 14
468 | office31//webcam/images/mobile_phone/frame_0012.jpg 14
469 | office31//webcam/images/mobile_phone/frame_0013.jpg 14
470 | office31//webcam/images/mobile_phone/frame_0014.jpg 14
471 | office31//webcam/images/mobile_phone/frame_0015.jpg 14
472 | office31//webcam/images/mobile_phone/frame_0016.jpg 14
473 | office31//webcam/images/mobile_phone/frame_0017.jpg 14
474 | office31//webcam/images/mobile_phone/frame_0018.jpg 14
475 | office31//webcam/images/mobile_phone/frame_0019.jpg 14
476 | office31//webcam/images/mobile_phone/frame_0020.jpg 14
477 | office31//webcam/images/mobile_phone/frame_0021.jpg 14
478 | office31//webcam/images/mobile_phone/frame_0022.jpg 14
479 | office31//webcam/images/mobile_phone/frame_0023.jpg 14
480 | office31//webcam/images/mobile_phone/frame_0024.jpg 14
481 | office31//webcam/images/mobile_phone/frame_0025.jpg 14
482 | office31//webcam/images/mobile_phone/frame_0026.jpg 14
483 | office31//webcam/images/mobile_phone/frame_0027.jpg 14
484 | office31//webcam/images/mobile_phone/frame_0028.jpg 14
485 | office31//webcam/images/mobile_phone/frame_0029.jpg 14
486 | office31//webcam/images/mobile_phone/frame_0030.jpg 14
487 | office31//webcam/images/paper_notebook/frame_0001.jpg 18
488 | office31//webcam/images/paper_notebook/frame_0002.jpg 18
489 | office31//webcam/images/paper_notebook/frame_0003.jpg 18
490 | office31//webcam/images/paper_notebook/frame_0004.jpg 18
491 | office31//webcam/images/paper_notebook/frame_0005.jpg 18
492 | office31//webcam/images/paper_notebook/frame_0006.jpg 18
493 | office31//webcam/images/paper_notebook/frame_0007.jpg 18
494 | office31//webcam/images/paper_notebook/frame_0008.jpg 18
495 | office31//webcam/images/paper_notebook/frame_0009.jpg 18
496 | office31//webcam/images/paper_notebook/frame_0010.jpg 18
497 | office31//webcam/images/paper_notebook/frame_0011.jpg 18
498 | office31//webcam/images/paper_notebook/frame_0012.jpg 18
499 | office31//webcam/images/paper_notebook/frame_0013.jpg 18
500 | office31//webcam/images/paper_notebook/frame_0014.jpg 18
501 | office31//webcam/images/paper_notebook/frame_0015.jpg 18
502 | office31//webcam/images/paper_notebook/frame_0016.jpg 18
503 | office31//webcam/images/paper_notebook/frame_0017.jpg 18
504 | office31//webcam/images/paper_notebook/frame_0018.jpg 18
505 | office31//webcam/images/paper_notebook/frame_0019.jpg 18
506 | office31//webcam/images/paper_notebook/frame_0020.jpg 18
507 | office31//webcam/images/paper_notebook/frame_0021.jpg 18
508 | office31//webcam/images/paper_notebook/frame_0022.jpg 18
509 | office31//webcam/images/paper_notebook/frame_0023.jpg 18
510 | office31//webcam/images/paper_notebook/frame_0024.jpg 18
511 | office31//webcam/images/paper_notebook/frame_0025.jpg 18
512 | office31//webcam/images/paper_notebook/frame_0026.jpg 18
513 | office31//webcam/images/paper_notebook/frame_0027.jpg 18
514 | office31//webcam/images/paper_notebook/frame_0028.jpg 18
515 | office31//webcam/images/ruler/frame_0001.jpg 25
516 | office31//webcam/images/ruler/frame_0002.jpg 25
517 | office31//webcam/images/ruler/frame_0003.jpg 25
518 | office31//webcam/images/ruler/frame_0004.jpg 25
519 | office31//webcam/images/ruler/frame_0005.jpg 25
520 | office31//webcam/images/ruler/frame_0006.jpg 25
521 | office31//webcam/images/ruler/frame_0007.jpg 25
522 | office31//webcam/images/ruler/frame_0008.jpg 25
523 | office31//webcam/images/ruler/frame_0009.jpg 25
524 | office31//webcam/images/ruler/frame_0010.jpg 25
525 | office31//webcam/images/ruler/frame_0011.jpg 25
526 | office31//webcam/images/letter_tray/frame_0001.jpg 13
527 | office31//webcam/images/letter_tray/frame_0002.jpg 13
528 | office31//webcam/images/letter_tray/frame_0003.jpg 13
529 | office31//webcam/images/letter_tray/frame_0004.jpg 13
530 | office31//webcam/images/letter_tray/frame_0005.jpg 13
531 | office31//webcam/images/letter_tray/frame_0006.jpg 13
532 | office31//webcam/images/letter_tray/frame_0007.jpg 13
533 | office31//webcam/images/letter_tray/frame_0008.jpg 13
534 | office31//webcam/images/letter_tray/frame_0009.jpg 13
535 | office31//webcam/images/letter_tray/frame_0010.jpg 13
536 | office31//webcam/images/letter_tray/frame_0011.jpg 13
537 | office31//webcam/images/letter_tray/frame_0012.jpg 13
538 | office31//webcam/images/letter_tray/frame_0013.jpg 13
539 | office31//webcam/images/letter_tray/frame_0014.jpg 13
540 | office31//webcam/images/letter_tray/frame_0015.jpg 13
541 | office31//webcam/images/letter_tray/frame_0016.jpg 13
542 | office31//webcam/images/letter_tray/frame_0017.jpg 13
543 | office31//webcam/images/letter_tray/frame_0018.jpg 13
544 | office31//webcam/images/letter_tray/frame_0019.jpg 13
545 | office31//webcam/images/file_cabinet/frame_0001.jpg 9
546 | office31//webcam/images/file_cabinet/frame_0002.jpg 9
547 | office31//webcam/images/file_cabinet/frame_0003.jpg 9
548 | office31//webcam/images/file_cabinet/frame_0004.jpg 9
549 | office31//webcam/images/file_cabinet/frame_0005.jpg 9
550 | office31//webcam/images/file_cabinet/frame_0006.jpg 9
551 | office31//webcam/images/file_cabinet/frame_0007.jpg 9
552 | office31//webcam/images/file_cabinet/frame_0008.jpg 9
553 | office31//webcam/images/file_cabinet/frame_0009.jpg 9
554 | office31//webcam/images/file_cabinet/frame_0010.jpg 9
555 | office31//webcam/images/file_cabinet/frame_0011.jpg 9
556 | office31//webcam/images/file_cabinet/frame_0012.jpg 9
557 | office31//webcam/images/file_cabinet/frame_0013.jpg 9
558 | office31//webcam/images/file_cabinet/frame_0014.jpg 9
559 | office31//webcam/images/file_cabinet/frame_0015.jpg 9
560 | office31//webcam/images/file_cabinet/frame_0016.jpg 9
561 | office31//webcam/images/file_cabinet/frame_0017.jpg 9
562 | office31//webcam/images/file_cabinet/frame_0018.jpg 9
563 | office31//webcam/images/file_cabinet/frame_0019.jpg 9
564 | office31//webcam/images/phone/frame_0001.jpg 20
565 | office31//webcam/images/phone/frame_0002.jpg 20
566 | office31//webcam/images/phone/frame_0003.jpg 20
567 | office31//webcam/images/phone/frame_0004.jpg 20
568 | office31//webcam/images/phone/frame_0005.jpg 20
569 | office31//webcam/images/phone/frame_0006.jpg 20
570 | office31//webcam/images/phone/frame_0007.jpg 20
571 | office31//webcam/images/phone/frame_0008.jpg 20
572 | office31//webcam/images/phone/frame_0009.jpg 20
573 | office31//webcam/images/phone/frame_0010.jpg 20
574 | office31//webcam/images/phone/frame_0011.jpg 20
575 | office31//webcam/images/phone/frame_0012.jpg 20
576 | office31//webcam/images/phone/frame_0013.jpg 20
577 | office31//webcam/images/phone/frame_0014.jpg 20
578 | office31//webcam/images/phone/frame_0015.jpg 20
579 | office31//webcam/images/phone/frame_0016.jpg 20
580 | office31//webcam/images/bookcase/frame_0001.jpg 3
581 | office31//webcam/images/bookcase/frame_0002.jpg 3
582 | office31//webcam/images/bookcase/frame_0003.jpg 3
583 | office31//webcam/images/bookcase/frame_0004.jpg 3
584 | office31//webcam/images/bookcase/frame_0005.jpg 3
585 | office31//webcam/images/bookcase/frame_0006.jpg 3
586 | office31//webcam/images/bookcase/frame_0007.jpg 3
587 | office31//webcam/images/bookcase/frame_0008.jpg 3
588 | office31//webcam/images/bookcase/frame_0009.jpg 3
589 | office31//webcam/images/bookcase/frame_0010.jpg 3
590 | office31//webcam/images/bookcase/frame_0011.jpg 3
591 | office31//webcam/images/bookcase/frame_0012.jpg 3
592 | office31//webcam/images/projector/frame_0001.jpg 22
593 | office31//webcam/images/projector/frame_0002.jpg 22
594 | office31//webcam/images/projector/frame_0003.jpg 22
595 | office31//webcam/images/projector/frame_0004.jpg 22
596 | office31//webcam/images/projector/frame_0005.jpg 22
597 | office31//webcam/images/projector/frame_0006.jpg 22
598 | office31//webcam/images/projector/frame_0007.jpg 22
599 | office31//webcam/images/projector/frame_0008.jpg 22
600 | office31//webcam/images/projector/frame_0009.jpg 22
601 | office31//webcam/images/projector/frame_0010.jpg 22
602 | office31//webcam/images/projector/frame_0011.jpg 22
603 | office31//webcam/images/projector/frame_0012.jpg 22
604 | office31//webcam/images/projector/frame_0013.jpg 22
605 | office31//webcam/images/projector/frame_0014.jpg 22
606 | office31//webcam/images/projector/frame_0015.jpg 22
607 | office31//webcam/images/projector/frame_0016.jpg 22
608 | office31//webcam/images/projector/frame_0017.jpg 22
609 | office31//webcam/images/projector/frame_0018.jpg 22
610 | office31//webcam/images/projector/frame_0019.jpg 22
611 | office31//webcam/images/projector/frame_0020.jpg 22
612 | office31//webcam/images/projector/frame_0021.jpg 22
613 | office31//webcam/images/projector/frame_0022.jpg 22
614 | office31//webcam/images/projector/frame_0023.jpg 22
615 | office31//webcam/images/projector/frame_0024.jpg 22
616 | office31//webcam/images/projector/frame_0025.jpg 22
617 | office31//webcam/images/projector/frame_0026.jpg 22
618 | office31//webcam/images/projector/frame_0027.jpg 22
619 | office31//webcam/images/projector/frame_0028.jpg 22
620 | office31//webcam/images/projector/frame_0029.jpg 22
621 | office31//webcam/images/projector/frame_0030.jpg 22
622 | office31//webcam/images/stapler/frame_0001.jpg 28
623 | office31//webcam/images/stapler/frame_0002.jpg 28
624 | office31//webcam/images/stapler/frame_0003.jpg 28
625 | office31//webcam/images/stapler/frame_0004.jpg 28
626 | office31//webcam/images/stapler/frame_0005.jpg 28
627 | office31//webcam/images/stapler/frame_0006.jpg 28
628 | office31//webcam/images/stapler/frame_0007.jpg 28
629 | office31//webcam/images/stapler/frame_0008.jpg 28
630 | office31//webcam/images/stapler/frame_0009.jpg 28
631 | office31//webcam/images/stapler/frame_0010.jpg 28
632 | office31//webcam/images/stapler/frame_0011.jpg 28
633 | office31//webcam/images/stapler/frame_0012.jpg 28
634 | office31//webcam/images/stapler/frame_0013.jpg 28
635 | office31//webcam/images/stapler/frame_0014.jpg 28
636 | office31//webcam/images/stapler/frame_0015.jpg 28
637 | office31//webcam/images/stapler/frame_0016.jpg 28
638 | office31//webcam/images/stapler/frame_0017.jpg 28
639 | office31//webcam/images/stapler/frame_0018.jpg 28
640 | office31//webcam/images/stapler/frame_0019.jpg 28
641 | office31//webcam/images/stapler/frame_0020.jpg 28
642 | office31//webcam/images/stapler/frame_0021.jpg 28
643 | office31//webcam/images/stapler/frame_0022.jpg 28
644 | office31//webcam/images/stapler/frame_0023.jpg 28
645 | office31//webcam/images/stapler/frame_0024.jpg 28
646 | office31//webcam/images/trash_can/frame_0001.jpg 30
647 | office31//webcam/images/trash_can/frame_0002.jpg 30
648 | office31//webcam/images/trash_can/frame_0003.jpg 30
649 | office31//webcam/images/trash_can/frame_0004.jpg 30
650 | office31//webcam/images/trash_can/frame_0005.jpg 30
651 | office31//webcam/images/trash_can/frame_0006.jpg 30
652 | office31//webcam/images/trash_can/frame_0007.jpg 30
653 | office31//webcam/images/trash_can/frame_0008.jpg 30
654 | office31//webcam/images/trash_can/frame_0009.jpg 30
655 | office31//webcam/images/trash_can/frame_0010.jpg 30
656 | office31//webcam/images/trash_can/frame_0011.jpg 30
657 | office31//webcam/images/trash_can/frame_0012.jpg 30
658 | office31//webcam/images/trash_can/frame_0013.jpg 30
659 | office31//webcam/images/trash_can/frame_0014.jpg 30
660 | office31//webcam/images/trash_can/frame_0015.jpg 30
661 | office31//webcam/images/trash_can/frame_0016.jpg 30
662 | office31//webcam/images/trash_can/frame_0017.jpg 30
663 | office31//webcam/images/trash_can/frame_0018.jpg 30
664 | office31//webcam/images/trash_can/frame_0019.jpg 30
665 | office31//webcam/images/trash_can/frame_0020.jpg 30
666 | office31//webcam/images/trash_can/frame_0021.jpg 30
667 | office31//webcam/images/bike_helmet/frame_0001.jpg 2
668 | office31//webcam/images/bike_helmet/frame_0002.jpg 2
669 | office31//webcam/images/bike_helmet/frame_0003.jpg 2
670 | office31//webcam/images/bike_helmet/frame_0004.jpg 2
671 | office31//webcam/images/bike_helmet/frame_0005.jpg 2
672 | office31//webcam/images/bike_helmet/frame_0006.jpg 2
673 | office31//webcam/images/bike_helmet/frame_0007.jpg 2
674 | office31//webcam/images/bike_helmet/frame_0008.jpg 2
675 | office31//webcam/images/bike_helmet/frame_0009.jpg 2
676 | office31//webcam/images/bike_helmet/frame_0010.jpg 2
677 | office31//webcam/images/bike_helmet/frame_0011.jpg 2
678 | office31//webcam/images/bike_helmet/frame_0012.jpg 2
679 | office31//webcam/images/bike_helmet/frame_0013.jpg 2
680 | office31//webcam/images/bike_helmet/frame_0014.jpg 2
681 | office31//webcam/images/bike_helmet/frame_0015.jpg 2
682 | office31//webcam/images/bike_helmet/frame_0016.jpg 2
683 | office31//webcam/images/bike_helmet/frame_0017.jpg 2
684 | office31//webcam/images/bike_helmet/frame_0018.jpg 2
685 | office31//webcam/images/bike_helmet/frame_0019.jpg 2
686 | office31//webcam/images/bike_helmet/frame_0020.jpg 2
687 | office31//webcam/images/bike_helmet/frame_0021.jpg 2
688 | office31//webcam/images/bike_helmet/frame_0022.jpg 2
689 | office31//webcam/images/bike_helmet/frame_0023.jpg 2
690 | office31//webcam/images/bike_helmet/frame_0024.jpg 2
691 | office31//webcam/images/bike_helmet/frame_0025.jpg 2
692 | office31//webcam/images/bike_helmet/frame_0026.jpg 2
693 | office31//webcam/images/bike_helmet/frame_0027.jpg 2
694 | office31//webcam/images/bike_helmet/frame_0028.jpg 2
695 | office31//webcam/images/headphones/frame_0001.jpg 10
696 | office31//webcam/images/headphones/frame_0002.jpg 10
697 | office31//webcam/images/headphones/frame_0003.jpg 10
698 | office31//webcam/images/headphones/frame_0004.jpg 10
699 | office31//webcam/images/headphones/frame_0005.jpg 10
700 | office31//webcam/images/headphones/frame_0006.jpg 10
701 | office31//webcam/images/headphones/frame_0007.jpg 10
702 | office31//webcam/images/headphones/frame_0008.jpg 10
703 | office31//webcam/images/headphones/frame_0009.jpg 10
704 | office31//webcam/images/headphones/frame_0010.jpg 10
705 | office31//webcam/images/headphones/frame_0011.jpg 10
706 | office31//webcam/images/headphones/frame_0012.jpg 10
707 | office31//webcam/images/headphones/frame_0013.jpg 10
708 | office31//webcam/images/headphones/frame_0014.jpg 10
709 | office31//webcam/images/headphones/frame_0015.jpg 10
710 | office31//webcam/images/headphones/frame_0016.jpg 10
711 | office31//webcam/images/headphones/frame_0017.jpg 10
712 | office31//webcam/images/headphones/frame_0018.jpg 10
713 | office31//webcam/images/headphones/frame_0019.jpg 10
714 | office31//webcam/images/headphones/frame_0020.jpg 10
715 | office31//webcam/images/headphones/frame_0021.jpg 10
716 | office31//webcam/images/headphones/frame_0022.jpg 10
717 | office31//webcam/images/headphones/frame_0023.jpg 10
718 | office31//webcam/images/headphones/frame_0024.jpg 10
719 | office31//webcam/images/headphones/frame_0025.jpg 10
720 | office31//webcam/images/headphones/frame_0026.jpg 10
721 | office31//webcam/images/headphones/frame_0027.jpg 10
722 | office31//webcam/images/desk_lamp/frame_0001.jpg 7
723 | office31//webcam/images/desk_lamp/frame_0002.jpg 7
724 | office31//webcam/images/desk_lamp/frame_0003.jpg 7
725 | office31//webcam/images/desk_lamp/frame_0004.jpg 7
726 | office31//webcam/images/desk_lamp/frame_0005.jpg 7
727 | office31//webcam/images/desk_lamp/frame_0006.jpg 7
728 | office31//webcam/images/desk_lamp/frame_0007.jpg 7
729 | office31//webcam/images/desk_lamp/frame_0008.jpg 7
730 | office31//webcam/images/desk_lamp/frame_0009.jpg 7
731 | office31//webcam/images/desk_lamp/frame_0010.jpg 7
732 | office31//webcam/images/desk_lamp/frame_0011.jpg 7
733 | office31//webcam/images/desk_lamp/frame_0012.jpg 7
734 | office31//webcam/images/desk_lamp/frame_0013.jpg 7
735 | office31//webcam/images/desk_lamp/frame_0014.jpg 7
736 | office31//webcam/images/desk_lamp/frame_0015.jpg 7
737 | office31//webcam/images/desk_lamp/frame_0016.jpg 7
738 | office31//webcam/images/desk_lamp/frame_0017.jpg 7
739 | office31//webcam/images/desk_lamp/frame_0018.jpg 7
740 | office31//webcam/images/desk_chair/frame_0001.jpg 6
741 | office31//webcam/images/desk_chair/frame_0002.jpg 6
742 | office31//webcam/images/desk_chair/frame_0003.jpg 6
743 | office31//webcam/images/desk_chair/frame_0004.jpg 6
744 | office31//webcam/images/desk_chair/frame_0005.jpg 6
745 | office31//webcam/images/desk_chair/frame_0006.jpg 6
746 | office31//webcam/images/desk_chair/frame_0007.jpg 6
747 | office31//webcam/images/desk_chair/frame_0008.jpg 6
748 | office31//webcam/images/desk_chair/frame_0009.jpg 6
749 | office31//webcam/images/desk_chair/frame_0010.jpg 6
750 | office31//webcam/images/desk_chair/frame_0011.jpg 6
751 | office31//webcam/images/desk_chair/frame_0012.jpg 6
752 | office31//webcam/images/desk_chair/frame_0013.jpg 6
753 | office31//webcam/images/desk_chair/frame_0014.jpg 6
754 | office31//webcam/images/desk_chair/frame_0015.jpg 6
755 | office31//webcam/images/desk_chair/frame_0016.jpg 6
756 | office31//webcam/images/desk_chair/frame_0017.jpg 6
757 | office31//webcam/images/desk_chair/frame_0018.jpg 6
758 | office31//webcam/images/desk_chair/frame_0019.jpg 6
759 | office31//webcam/images/desk_chair/frame_0020.jpg 6
760 | office31//webcam/images/desk_chair/frame_0021.jpg 6
761 | office31//webcam/images/desk_chair/frame_0022.jpg 6
762 | office31//webcam/images/desk_chair/frame_0023.jpg 6
763 | office31//webcam/images/desk_chair/frame_0024.jpg 6
764 | office31//webcam/images/desk_chair/frame_0025.jpg 6
765 | office31//webcam/images/desk_chair/frame_0026.jpg 6
766 | office31//webcam/images/desk_chair/frame_0027.jpg 6
767 | office31//webcam/images/desk_chair/frame_0028.jpg 6
768 | office31//webcam/images/desk_chair/frame_0029.jpg 6
769 | office31//webcam/images/desk_chair/frame_0030.jpg 6
770 | office31//webcam/images/desk_chair/frame_0031.jpg 6
771 | office31//webcam/images/desk_chair/frame_0032.jpg 6
772 | office31//webcam/images/desk_chair/frame_0033.jpg 6
773 | office31//webcam/images/desk_chair/frame_0034.jpg 6
774 | office31//webcam/images/desk_chair/frame_0035.jpg 6
775 | office31//webcam/images/desk_chair/frame_0036.jpg 6
776 | office31//webcam/images/desk_chair/frame_0037.jpg 6
777 | office31//webcam/images/desk_chair/frame_0038.jpg 6
778 | office31//webcam/images/desk_chair/frame_0039.jpg 6
779 | office31//webcam/images/desk_chair/frame_0040.jpg 6
780 | office31//webcam/images/bottle/frame_0001.jpg 4
781 | office31//webcam/images/bottle/frame_0002.jpg 4
782 | office31//webcam/images/bottle/frame_0003.jpg 4
783 | office31//webcam/images/bottle/frame_0004.jpg 4
784 | office31//webcam/images/bottle/frame_0005.jpg 4
785 | office31//webcam/images/bottle/frame_0006.jpg 4
786 | office31//webcam/images/bottle/frame_0007.jpg 4
787 | office31//webcam/images/bottle/frame_0008.jpg 4
788 | office31//webcam/images/bottle/frame_0009.jpg 4
789 | office31//webcam/images/bottle/frame_0010.jpg 4
790 | office31//webcam/images/bottle/frame_0011.jpg 4
791 | office31//webcam/images/bottle/frame_0012.jpg 4
792 | office31//webcam/images/bottle/frame_0013.jpg 4
793 | office31//webcam/images/bottle/frame_0014.jpg 4
794 | office31//webcam/images/bottle/frame_0015.jpg 4
795 | office31//webcam/images/bottle/frame_0016.jpg 4
796 |
--------------------------------------------------------------------------------
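
The list files above share a simple two-column format: an image path relative to the data root, then an integer class index. A minimal parsing sketch (it mirrors what `make_dataset` in `dataset/image_list.py` below does; the file path points at one of the lists shipped in this repo):

    # Sketch: read an image-list file into (path, label) pairs.
    def load_image_list(list_file):
        samples = []
        with open(list_file) as f:
            for line in f:
                parts = line.split()
                if len(parts) >= 2:
                    samples.append((parts[0], int(parts[1])))
        return samples

    samples = load_image_list("data/image_list/office31/webcam.txt")
    print(len(samples), samples[0])   # (relative image path, class index)
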
/dataset/ASDADataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | from torch.utils.data.sampler import SubsetRandomSampler
5 | from .image_list import ImageList
6 |
7 | class ASDADataset:
8 | # Active Semi-supervised DA Dataset class
9 | def __init__(self, dataset, domain, data_dir, num_classes, batch_size=128, num_workers=4, transforms=None):
10 | self.dataset = dataset
11 | self.domain = domain
12 | self.data_dir = data_dir
13 | self.batch_size = batch_size
14 | self.num_workers = num_workers
15 | self.num_classes = num_classes
16 | self.train_loader, self.valid_loader, self.test_loader = None, None, None
17 |
18 | self.build_dsets(transforms)
19 |
20 | def get_num_classes(self):
21 | return self.num_classes
22 |
23 | def get_dsets(self):
24 | return self.train_dataset, self.query_dataset, self.valid_dataset, self.test_dataset
25 |
26 | def build_dsets(self, transforms=None):
27 | assert transforms is not None
28 |
29 | if self.dataset == "domainnet":
30 | train_list = open(os.path.join(self.data_dir, "image_list", self.dataset, self.domain+"_train.txt")).readlines()
31 | test_list = open(os.path.join(self.data_dir, "image_list", self.dataset, self.domain+"_test.txt")).readlines()
32 | valid_list = train_list.copy()
33 | else:
34 | train_list = open(os.path.join(self.data_dir, "image_list", self.dataset, self.domain+".txt")).readlines()
35 | test_list = train_list.copy()
36 | valid_list = train_list.copy()
37 |
38 | train_dataset = ImageList(train_list, root=self.data_dir, transform=transforms['train'])
39 | query_dataset = ImageList(train_list, root=self.data_dir, transform=transforms['query'])
40 | valid_dataset = ImageList(valid_list, root=self.data_dir, transform=transforms['test'])
41 | test_dataset = ImageList(test_list, root=self.data_dir, transform=transforms['test'])
42 |
43 | self.train_dataset = train_dataset
44 | self.query_dataset = query_dataset
45 | self.valid_dataset = valid_dataset
46 | self.test_dataset = test_dataset
47 |
48 |
49 | def get_loaders(self, valid_type='val', valid_ratio=1.0, rebuilt=False):
50 |
51 | if self.train_loader and self.valid_loader and self.test_loader and not rebuilt:
52 | return self.train_loader, self.valid_loader, self.test_loader
53 |
54 | num_train = len(self.train_dataset)
55 | self.train_size = num_train
56 |
57 | if valid_type == 'split':
58 | indices = list(range(num_train))
59 | split = int(np.floor(valid_ratio * num_train))
60 | np.random.shuffle(indices)
61 | train_idx, valid_idx = indices[split:], indices[:split]
62 |
63 | elif valid_type == 'val':
64 | train_idx = np.arange(len(self.train_dataset))
65 | if valid_ratio == 1.0:
66 | valid_idx = np.arange(len(self.valid_dataset))
67 | else:
68 | indices = list(range(len(self.valid_dataset)))
69 | split = int(np.floor(valid_ratio * num_train))
70 | np.random.shuffle(indices)
71 | valid_idx = indices[:split]
72 | else:
73 | raise NotImplementedError
74 |
75 | train_sampler = SubsetRandomSampler(train_idx)
76 | valid_sampler = SubsetRandomSampler(valid_idx)
77 |
78 | self.train_loader = torch.utils.data.DataLoader(self.train_dataset, sampler=train_sampler, \
79 | batch_size=self.batch_size, num_workers=self.num_workers)
80 | self.valid_loader = torch.utils.data.DataLoader(self.valid_dataset, sampler=valid_sampler, batch_size=self.batch_size,
81 | num_workers=self.num_workers)
82 | self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)
83 |
84 | self.train_idx = train_idx
85 | self.valid_idx = valid_idx
86 |
87 | return self.train_loader, self.valid_loader, self.test_loader
--------------------------------------------------------------------------------
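
A minimal usage sketch for `ASDADataset`, assuming the Office-31 images have been downloaded under `data/office31/`; the transform dict here is hand-built for illustration, whereas the actual pipeline obtains it from `build_transforms` in `dataset/transform.py`:

    from torchvision import transforms as T
    from dataset.ASDADataset import ASDADataset

    tf = T.Compose([T.Resize((256, 256)), T.CenterCrop(224), T.ToTensor()])
    dset = ASDADataset("office31", "webcam", data_dir="data", num_classes=31,
                       batch_size=32, transforms={"train": tf, "query": tf, "test": tf})

    train_loader, valid_loader, test_loader = dset.get_loaders()
    images, labels, indices = next(iter(train_loader))   # ImageList returns (sample, target, index)
    print(images.shape, labels.shape)
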
/dataset/image_list.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import os.path as osp
3 | import numpy as np
4 |
5 | def make_dataset(image_list, labels):
6 | if labels:
7 | len_ = len(image_list)
8 | images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
9 | else:
10 | if type(image_list[0]) is tuple:
11 | return image_list
12 |
13 | if len(image_list[0].split()) > 2:
14 | images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
15 | else:
16 | images = [(val.split()[0], int(val.split()[1])) for val in image_list]
17 | return images
18 |
19 |
20 | def pil_loader(root, path):
21 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
22 | with open(osp.join(root, path), 'rb') as f:
23 | img = Image.open(f)
24 | return img.convert('RGB')
25 |
26 |
27 | class ImageList(object):
28 | """A generic data loader where the images are arranged in this way: ::
29 | root/dog/xxx.png
30 | root/dog/xxy.png
31 | root/dog/xxz.png
32 | root/cat/123.png
33 | root/cat/nsdf3.png
34 | root/cat/asd932_.png
35 | Args:
36 | root (string): Root directory path.
37 | transform (callable, optional): A function/transform that takes in a PIL image
38 | and returns a transformed version. E.g., ``transforms.RandomCrop``
39 | target_transform (callable, optional): A function/transform that takes in the
40 | target and transforms it.
41 | loader (callable, optional): A function to load an image given its path.
42 | Attributes:
43 | classes (list): List of the class names.
44 | class_to_idx (dict): Dict with items (class_name, class_index).
45 | imgs (list): List of (image path, class_index) tuples
46 | """
47 |
48 | def __init__(self, image_list, labels=None, root='data', transform=None, target_transform=None, rand_transform=None):
49 | samples = make_dataset(image_list, labels)
50 | self.samples = samples
51 | self.transform = transform
52 | self.target_transform = target_transform
53 | self.rand_transform = rand_transform
54 | self.rand_num = 0
55 | self.loader = pil_loader
56 | self.root = root
57 |
58 | def __getitem__(self, index):
59 |
60 | path, target = self.samples[index]
61 | target = int(target)
62 |
63 | sample_ = self.loader(self.root, path)
64 |
65 | if self.transform is not None:
66 | sample = self.transform(sample_)
67 | if self.target_transform is not None:
68 | target = self.target_transform(target)
69 | if self.rand_transform is not None:
70 | rand_sample = []
71 | for i in range(self.rand_num):
72 | rand_sample.append(self.rand_transform(sample_))
73 |
74 | return sample, target, index, *rand_sample
75 | else:
76 | return sample, target, index
77 |
78 | def __len__(self):
79 | return len(self.samples)
80 |
81 | def add_item(self, addition):
82 | # self.samples = np.concatenate((self.samples, addition), axis=0)
83 | self.samples.extend(addition)
84 | return self.samples
85 |
86 | def remove_item(self, reduced):
87 | self.samples = np.delete(self.samples, reduced, axis=0)
88 | return self.samples
89 |
--------------------------------------------------------------------------------
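
`ImageList` can also be used on its own; a small sketch (again assuming the images are available under `data/`), showing the extra augmented views returned when `rand_transform` is set:

    from torchvision import transforms as T
    from dataset.image_list import ImageList
    from dataset.transform import rand_transform

    lines = open("data/image_list/office31/webcam.txt").readlines()
    ds = ImageList(lines, root="data",
                   transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]),
                   rand_transform=rand_transform)
    ds.rand_num = 1                       # number of randomly augmented views per sample
    img, label, idx, rand_view = ds[0]    # extra views are appended to the returned tuple
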
/dataset/randaugment.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 | import random
4 |
5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 |
10 |
11 | def ShearX(img, v): # [-0.3, 0.3]
12 | assert -0.3 <= v <= 0.3
13 | if random.random() > 0.5:
14 | v = -v
15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
16 |
17 |
18 | def ShearY(img, v): # [-0.3, 0.3]
19 | assert -0.3 <= v <= 0.3
20 | if random.random() > 0.5:
21 | v = -v
22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
23 |
24 |
25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
26 | assert -0.45 <= v <= 0.45
27 | if random.random() > 0.5:
28 | v = -v
29 | v = v * img.size[0]
30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
31 |
32 |
33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
34 | assert 0 <= v
35 | if random.random() > 0.5:
36 | v = -v
37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
38 |
39 |
40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
41 | assert -0.45 <= v <= 0.45
42 | if random.random() > 0.5:
43 | v = -v
44 | v = v * img.size[1]
45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
46 |
47 |
48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
49 | assert 0 <= v
50 | if random.random() > 0.5:
51 | v = -v
52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
53 |
54 |
55 | def Rotate(img, v): # [-30, 30]
56 | assert -30 <= v <= 30
57 | if random.random() > 0.5:
58 | v = -v
59 | return img.rotate(v)
60 |
61 |
62 | def AutoContrast(img, _):
63 | return PIL.ImageOps.autocontrast(img)
64 |
65 |
66 | def Invert(img, _):
67 | return PIL.ImageOps.invert(img)
68 |
69 |
70 | def Equalize(img, _):
71 | return PIL.ImageOps.equalize(img)
72 |
73 |
74 | def Flip(img, _): # not from the paper
75 | return PIL.ImageOps.mirror(img)
76 |
77 |
78 | def Solarize(img, v): # [0, 256]
79 | assert 0 <= v <= 256
80 | return PIL.ImageOps.solarize(img, v)
81 |
82 |
83 | def SolarizeAdd(img, addition=0, threshold=128):
84 | img_np = np.array(img).astype(int)
85 | img_np = img_np + addition
86 | img_np = np.clip(img_np, 0, 255)
87 | img_np = img_np.astype(np.uint8)
88 | img = Image.fromarray(img_np)
89 | return PIL.ImageOps.solarize(img, threshold)
90 |
91 |
92 | def Posterize(img, v): # [4, 8]
93 | v = int(v)
94 | v = max(1, v)
95 | return PIL.ImageOps.posterize(img, v)
96 |
97 |
98 | def Contrast(img, v): # [0.1,1.9]
99 | assert 0.1 <= v <= 1.9
100 | return PIL.ImageEnhance.Contrast(img).enhance(v)
101 |
102 |
103 | def Color(img, v): # [0.1,1.9]
104 | assert 0.1 <= v <= 1.9
105 | return PIL.ImageEnhance.Color(img).enhance(v)
106 |
107 |
108 | def Brightness(img, v): # [0.1,1.9]
109 | assert 0.1 <= v <= 1.9
110 | return PIL.ImageEnhance.Brightness(img).enhance(v)
111 |
112 |
113 | def Sharpness(img, v): # [0.1,1.9]
114 | assert 0.1 <= v <= 1.9
115 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
116 |
117 |
118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
119 | assert 0.0 <= v <= 0.2
120 | if v <= 0.:
121 | return img
122 |
123 | v = v * img.size[0]
124 | return CutoutAbs(img, v)
125 |
126 |
127 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
128 | # assert 0 <= v <= 20
129 | if v < 0:
130 | return img
131 | w, h = img.size
132 | x0 = np.random.uniform(w)
133 | y0 = np.random.uniform(h)
134 |
135 | x0 = int(max(0, x0 - v / 2.))
136 | y0 = int(max(0, y0 - v / 2.))
137 | x1 = min(w, x0 + v)
138 | y1 = min(h, y0 + v)
139 |
140 | xy = (x0, y0, x1, y1)
141 | color = (125, 123, 114)
142 | # color = (0, 0, 0)
143 | img = img.copy()
144 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
145 | return img
146 |
147 |
148 | def SamplePairing(imgs): # [0, 0.4]
149 | def f(img1, v):
150 | i = np.random.choice(len(imgs))
151 | img2 = PIL.Image.fromarray(imgs[i])
152 | return PIL.Image.blend(img1, img2, v)
153 |
154 | return f
155 |
156 |
157 | def Identity(img, v):
158 | return img
159 |
160 |
161 | def augment_list(): # 16 operations and their ranges
162 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
163 | # l = [
164 | # (Identity, 0., 1.0),
165 | # (ShearX, 0., 0.3), # 0
166 | # (ShearY, 0., 0.3), # 1
167 | # (TranslateX, 0., 0.33), # 2
168 | # (TranslateY, 0., 0.33), # 3
169 | # (Rotate, 0, 30), # 4
170 | # (AutoContrast, 0, 1), # 5
171 | # (Invert, 0, 1), # 6
172 | # (Equalize, 0, 1), # 7
173 | # (Solarize, 0, 110), # 8
174 | # (Posterize, 4, 8), # 9
175 | # # (Contrast, 0.1, 1.9), # 10
176 | # (Color, 0.1, 1.9), # 11
177 | # (Brightness, 0.1, 1.9), # 12
178 | # (Sharpness, 0.1, 1.9), # 13
179 | # # (Cutout, 0, 0.2), # 14
180 | # # (SamplePairing(imgs), 0, 0.4), # 15
181 | # ]
182 |
183 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
184 | l = [
185 | (AutoContrast, 0, 1),
186 | (Equalize, 0, 1),
187 | (Invert, 0, 1),
188 | (Rotate, 0, 30),
189 | (Posterize, 0, 4),
190 | (Solarize, 0, 256),
191 | (SolarizeAdd, 0, 110),
192 | (Color, 0.1, 1.9),
193 | (Contrast, 0.1, 1.9),
194 | (Brightness, 0.1, 1.9),
195 | (Sharpness, 0.1, 1.9),
196 | (ShearX, 0., 0.3),
197 | (ShearY, 0., 0.3),
198 | (CutoutAbs, 0, 40),
199 | (TranslateXabs, 0., 100),
200 | (TranslateYabs, 0., 100),
201 | ]
202 |
203 | return l
204 |
205 |
206 | # class Lighting(object):
207 | # """Lighting noise(AlexNet - style PCA - based noise)"""
208 | #
209 | # def __init__(self, alphastd, eigval, eigvec):
210 | # self.alphastd = alphastd
211 | # self.eigval = torch.Tensor(eigval)
212 | # self.eigvec = torch.Tensor(eigvec)
213 | #
214 | # def __call__(self, img):
215 | # if self.alphastd == 0:
216 | # return img
217 | #
218 | # alpha = img.new().resize_(3).normal_(0, self.alphastd)
219 | # rgb = self.eigvec.type_as(img).clone() \
220 | # .mul(alpha.view(1, 3).expand(3, 3)) \
221 | # .mul(self.eigval.view(1, 3).expand(3, 3)) \
222 | # .sum(1).squeeze()
223 | #
224 | # return img.add(rgb.view(3, 1, 1).expand_as(img))
225 |
226 |
227 | # class CutoutDefault(object):
228 | # """
229 | # Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
230 | # """
231 | # def __init__(self, length):
232 | # self.length = length
233 | #
234 | # def __call__(self, img):
235 | # h, w = img.size(1), img.size(2)
236 | # mask = np.ones((h, w), np.float32)
237 | # y = np.random.randint(h)
238 | # x = np.random.randint(w)
239 | #
240 | # y1 = np.clip(y - self.length // 2, 0, h)
241 | # y2 = np.clip(y + self.length // 2, 0, h)
242 | # x1 = np.clip(x - self.length // 2, 0, w)
243 | # x2 = np.clip(x + self.length // 2, 0, w)
244 | #
245 | # mask[y1: y2, x1: x2] = 0.
246 | # mask = torch.from_numpy(mask)
247 | # mask = mask.expand_as(img)
248 | # img *= mask
249 | # return img
250 |
251 |
252 | class RandAugment:
253 | def __init__(self, n, m):
254 | self.n = n
255 | self.m = m # [0, 30]
256 | self.augment_list = augment_list()
257 |
258 | def __call__(self, img):
259 |
260 | if self.n == 0:
261 | return img
262 |
263 | ops = random.choices(self.augment_list, k=self.n)
264 | for op, minval, maxval in ops:
265 | val = (float(self.m) / 30) * float(maxval - minval) + minval
266 | img = op(img, val)
267 |
268 | return img
--------------------------------------------------------------------------------
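
A short sketch of applying `RandAugment` directly to a PIL image (the image path is illustrative); `n` is the number of randomly chosen operations and `m` is the magnitude on a 0-30 scale:

    from PIL import Image
    from dataset.randaugment import RandAugment

    augment = RandAugment(n=2, m=10)                  # 2 random ops at magnitude 10/30
    img = Image.open("example.jpg").convert("RGB")    # any RGB image
    aug_img = augment(img)                            # transformed PIL image
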
/dataset/randaugmentMC.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from
2 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
3 | # https://github.com/google-research/fixmatch/blob/master/third_party/auto_augment/augmentations.py
4 | # https://github.com/google-research/fixmatch/blob/master/libml/ctaugment.py
5 | import logging
6 | import random
7 |
8 | import numpy as np
9 | import PIL
10 | import PIL.ImageOps
11 | import PIL.ImageEnhance
12 | import PIL.ImageDraw
13 | from PIL import Image
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | PARAMETER_MAX = 10
18 |
19 |
20 | def AutoContrast(img, **kwarg):
21 | return PIL.ImageOps.autocontrast(img)
22 |
23 |
24 | def Brightness(img, v, max_v, bias=0):
25 | v = _float_parameter(v, max_v) + bias
26 | return PIL.ImageEnhance.Brightness(img).enhance(v)
27 |
28 |
29 | def Color(img, v, max_v, bias=0):
30 | v = _float_parameter(v, max_v) + bias
31 | return PIL.ImageEnhance.Color(img).enhance(v)
32 |
33 |
34 | def Contrast(img, v, max_v, bias=0):
35 | v = _float_parameter(v, max_v) + bias
36 | return PIL.ImageEnhance.Contrast(img).enhance(v)
37 |
38 |
39 | def Cutout(img, v, max_v, bias=0):
40 | if v == 0:
41 | return img
42 | v = _float_parameter(v, max_v) + bias
43 | v = int(v * min(img.size))
44 | return CutoutAbs(img, v)
45 |
46 |
47 | def CutoutAbs(img, v, **kwarg):
48 | w, h = img.size
49 | x0 = np.random.uniform(0, w)
50 | y0 = np.random.uniform(0, h)
51 | x0 = int(max(0, x0 - v / 2.))
52 | y0 = int(max(0, y0 - v / 2.))
53 | x1 = int(min(w, x0 + v))
54 | y1 = int(min(h, y0 + v))
55 | xy = (x0, y0, x1, y1)
56 | # gray
57 | color = (127, 127, 127)
58 | img = img.copy()
59 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
60 | return img
61 |
62 |
63 | def Equalize(img, **kwarg):
64 | return PIL.ImageOps.equalize(img)
65 |
66 |
67 | def Identity(img, **kwarg):
68 | return img
69 |
70 |
71 | def Invert(img, **kwarg):
72 | return PIL.ImageOps.invert(img)
73 |
74 |
75 | def Posterize(img, v, max_v, bias=0):
76 | v = _int_parameter(v, max_v) + bias
77 | return PIL.ImageOps.posterize(img, v)
78 |
79 |
80 | def Rotate(img, v, max_v, bias=0):
81 | v = _int_parameter(v, max_v) + bias
82 | if random.random() < 0.5:
83 | v = -v
84 | return img.rotate(v)
85 |
86 |
87 | def Sharpness(img, v, max_v, bias=0):
88 | v = _float_parameter(v, max_v) + bias
89 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
90 |
91 |
92 | def ShearX(img, v, max_v, bias=0):
93 | v = _float_parameter(v, max_v) + bias
94 | if random.random() < 0.5:
95 | v = -v
96 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
97 |
98 |
99 | def ShearY(img, v, max_v, bias=0):
100 | v = _float_parameter(v, max_v) + bias
101 | if random.random() < 0.5:
102 | v = -v
103 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
104 |
105 |
106 | def Solarize(img, v, max_v, bias=0):
107 | v = _int_parameter(v, max_v) + bias
108 | return PIL.ImageOps.solarize(img, 256 - v)
109 |
110 |
111 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
112 | v = _int_parameter(v, max_v) + bias
113 | if random.random() < 0.5:
114 | v = -v
115 | img_np = np.array(img).astype(int)
116 | img_np = img_np + v
117 | img_np = np.clip(img_np, 0, 255)
118 | img_np = img_np.astype(np.uint8)
119 | img = Image.fromarray(img_np)
120 | return PIL.ImageOps.solarize(img, threshold)
121 |
122 |
123 | def TranslateX(img, v, max_v, bias=0):
124 | v = _float_parameter(v, max_v) + bias
125 | if random.random() < 0.5:
126 | v = -v
127 | v = int(v * img.size[0])
128 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
129 |
130 |
131 | def TranslateY(img, v, max_v, bias=0):
132 | v = _float_parameter(v, max_v) + bias
133 | if random.random() < 0.5:
134 | v = -v
135 | v = int(v * img.size[1])
136 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
137 |
138 |
139 | def _float_parameter(v, max_v):
140 | return float(v) * max_v / PARAMETER_MAX
141 |
142 |
143 | def _int_parameter(v, max_v):
144 | return int(v * max_v / PARAMETER_MAX)
145 |
146 |
147 | def fixmatch_augment_pool():
148 | # FixMatch paper
149 | augs = [(AutoContrast, None, None),
150 | (Brightness, 0.9, 0.05),
151 | (Color, 0.9, 0.05),
152 | (Contrast, 0.9, 0.05),
153 | (Equalize, None, None),
154 | (Identity, None, None),
155 | (Posterize, 4, 4),
156 | (Rotate, 30, 0),
157 | (Sharpness, 0.9, 0.05),
158 | (ShearX, 0.3, 0),
159 | (ShearY, 0.3, 0),
160 | (Solarize, 256, 0),
161 | (TranslateX, 0.3, 0),
162 | (TranslateY, 0.3, 0)]
163 | return augs
164 |
165 |
166 | def my_augment_pool():
167 | # Test
168 | augs = [(AutoContrast, None, None),
169 | (Brightness, 1.8, 0.1),
170 | (Color, 1.8, 0.1),
171 | (Contrast, 1.8, 0.1),
172 | (Cutout, 0.2, 0),
173 | (Equalize, None, None),
174 | (Invert, None, None),
175 | (Posterize, 4, 4),
176 | (Rotate, 30, 0),
177 | (Sharpness, 1.8, 0.1),
178 | (ShearX, 0.3, 0),
179 | (ShearY, 0.3, 0),
180 | (Solarize, 256, 0),
181 | (SolarizeAdd, 110, 0),
182 | (TranslateX, 0.45, 0),
183 | (TranslateY, 0.45, 0)]
184 | return augs
185 |
186 |
187 | class RandAugmentPC(object):
188 | def __init__(self, n, m):
189 | assert n >= 1
190 | assert 1 <= m <= 10
191 | self.n = n
192 | self.m = m
193 | self.augment_pool = my_augment_pool()
194 |
195 | def __call__(self, img):
196 | ops = random.choices(self.augment_pool, k=self.n)
197 | for op, max_v, bias in ops:
198 | prob = np.random.uniform(0.2, 0.8)
199 | if random.random() + prob >= 1:
200 | img = op(img, v=self.m, max_v=max_v, bias=bias)
201 | img = CutoutAbs(img, 16)
202 | return img
203 |
204 |
205 | class RandAugmentMC(object):
206 | def __init__(self, n, m):
207 | assert n >= 1
208 | assert 1 <= m <= 10
209 | self.n = n
210 | self.m = m
211 | self.augment_pool = fixmatch_augment_pool()
212 |
213 | def __call__(self, img):
214 | ops = random.choices(self.augment_pool, k=self.n)
215 | for op, max_v, bias in ops:
216 | v = np.random.randint(1, self.m)
217 | if random.random() < 0.5:
218 | img = op(img, v=v, max_v=max_v, bias=bias)
219 | img = CutoutAbs(img, 16)
220 | return img
221 |
--------------------------------------------------------------------------------
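
`RandAugmentMC` follows the FixMatch-style recipe: each sampled operation is applied with probability 0.5 at a random magnitude below `m`, and every call ends with a fixed 16-pixel Cutout. A minimal sketch (illustrative image path):

    from PIL import Image
    from dataset.randaugmentMC import RandAugmentMC

    strong_aug = RandAugmentMC(n=2, m=10)
    img = Image.open("example.jpg").convert("RGB")
    strong_img = strong_aug(img)    # strongly augmented view used for consistency training
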
/dataset/transform.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from PIL import Image
3 | from .randaugment import RandAugment
4 | from torchvision.transforms import (Resize, Compose, ToTensor, Normalize, CenterCrop, RandomCrop,
5 | RandomResizedCrop, RandomHorizontalFlip)
6 | from .randaugmentMC import RandAugmentMC
7 |
8 | class ResizeImage():
9 | def __init__(self, size):
10 | if isinstance(size, int):
11 | self.size = (int(size), int(size))
12 | else:
13 | self.size = size
14 |
15 | def __call__(self, img):
16 | th, tw = self.size
17 | return img.resize((th, tw))
18 |
19 |
20 | def build_transforms(cfg, domain='source'):
21 | if cfg.DATASET.NAME in ["mnist", "svhn"]:
22 | transforms = None
23 | else:
24 | # train
25 | choices = cfg.DATASET.SOURCE_TRANSFORMS if domain=='source' else cfg.DATASET.TARGET_TRANSFORMS
26 | train_transform = build_transform(choices)
27 | # query
28 | choices = cfg.DATASET.QUERY_TRANSFORMS
29 | query_transform = build_transform(choices)
30 | # test
31 | choices = cfg.DATASET.TEST_TRANSFORMS
32 | test_transform = build_transform(choices)
33 |
34 | transforms = {'train':train_transform, 'query':query_transform, 'test':test_transform}
35 |
36 | return transforms
37 |
38 | def build_transform(choices):
39 | transform = []
40 | if 'Resize' in choices:
41 | transform += [Resize((256, 256))] # make sure resize to equal length and width
42 |
43 | if 'ResizeImage' in choices:
44 | transform += [ResizeImage(256)]
45 |
46 | if 'RandomHorizontalFlip' in choices:
47 | transform += [RandomHorizontalFlip(p=0.5)]
48 |
49 | if 'RandomCrop' in choices:
50 | transform += [RandomCrop(224)]
51 |
52 | if 'RandomResizedCrop' in choices:
53 | transform += [RandomResizedCrop(224)]
54 |
55 | if 'CenterCrop' in choices:
56 | transform += [CenterCrop(224)]
57 |
58 |
59 | transform += [ToTensor()]
60 |
61 | if 'Normalize' in choices:
62 | normalize = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
63 |
64 | transform += [normalize]
65 |
66 | return Compose(transform)
67 |
68 |
69 | rand_transform = transforms.Compose([
70 | RandAugment(1, 2.0),
71 | transforms.Resize(256),
72 | transforms.RandomHorizontalFlip(0.5),
73 | transforms.RandomCrop(224),
74 | transforms.ToTensor(),
75 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
76 | std=[0.229, 0.224, 0.225]),
77 | ])
78 |
79 | rand_transform2 = transforms.Compose([
80 | ResizeImage(256),
81 | transforms.RandomHorizontalFlip(),
82 | transforms.RandomCrop(224),
83 | RandAugmentMC(n=2, m=10),
84 | transforms.ToTensor(),
85 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
86 | ])
--------------------------------------------------------------------------------
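
`build_transform` assembles a torchvision pipeline from a list of option names; a quick sketch of calling it directly, using option names taken from the code above:

    from dataset.transform import build_transform

    # Train-style pipeline: resize -> random flip -> random crop -> tensor -> normalize
    train_tf = build_transform(['Resize', 'RandomHorizontalFlip', 'RandomCrop', 'Normalize'])
    # Test-style pipeline: resize -> center crop -> tensor -> normalize
    test_tf = build_transform(['Resize', 'CenterCrop', 'Normalize'])
    print(train_tf)
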
/fig/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tsun/LADA/5de718ecd0b1eccc337ff54d4cd854ffb580d7e2/fig/framework.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import logging
4 | import os
5 | import time
6 | import copy
7 | import pprint as pp
8 | from tqdm import tqdm, trange
9 | from collections import defaultdict
10 | import numpy as np
11 | import shutil
12 | import socket
13 | hostName = socket.gethostname()
14 | pid = os.getpid()
15 |
16 | from config.defaults import _C as cfg
17 | from utils.logger import init_logger
18 | from utils.utils import resetRNGseed
19 | from dataset.ASDADataset import ASDADataset
20 | from dataset.transform import build_transforms
21 |
22 | from model import get_model
23 | import utils.utils as utils
24 |
25 | from active.sampler import get_strategy
26 | from active.budget import BudgetAllocator
27 |
28 |
29 | # repeatability
30 | torch.backends.cudnn.deterministic = True
31 | torch.backends.cudnn.benchmark = False
32 |
33 | def run_active_adaptation(cfg, source_model, src_dset, num_classes, device):
34 | source = cfg.DATASET.SOURCE_DOMAIN
35 | target = cfg.DATASET.TARGET_DOMAIN
36 | da_strat = cfg.ADA.DA
37 | al_strat = cfg.ADA.AL
38 |
39 | transforms = build_transforms(cfg, 'target')
40 | tgt_dset = ASDADataset(cfg.DATASET.NAME, cfg.DATASET.TARGET_DOMAIN, data_dir=cfg.DATASET.ROOT,
41 | num_classes=cfg.DATASET.NUM_CLASS, batch_size=cfg.DATALOADER.BATCH_SIZE,
42 | num_workers=cfg.DATALOADER.NUM_WORKERS, transforms=transforms)
43 |
44 | target_test_loader = tgt_dset.get_loaders()[2]
45 |
46 | # Evaluate source model on target test
47 | if cfg.TRAINER.EVAL_ACC:
48 | transfer_perf, _ = utils.test(source_model, device, target_test_loader)
49 | logging.info('{}->{} performance (Before {}): Task={:.2f}'.format(source, target, da_strat, transfer_perf))
50 |
51 |
52 | # Main Active DA loop
53 | logging.info('------------------------------------------------------')
54 | model_init = 'source' if cfg.TRAINER.TRAIN_ON_SOURCE else 'scratch'
55 | logging.info('Running strategy: Init={} AL={} DA={}'.format(model_init, al_strat, da_strat))
56 | logging.info('------------------------------------------------------')
57 |
58 | # Run unsupervised DA at round 0, where applicable
59 | start_perf = 0.
60 |
61 | # Instantiate active sampling strategy
62 | sampling_strategy = get_strategy(al_strat, src_dset, tgt_dset, source_model, device, num_classes, cfg)
63 | del source_model
64 |
65 | if cfg.TRAINER.MAX_UDA_EPOCHS > 0:
66 | target_model = sampling_strategy.train_uda(epochs=cfg.TRAINER.MAX_UDA_EPOCHS)
67 |
68 | # Evaluate adapted source model on target test
69 | if cfg.TRAINER.EVAL_ACC:
70 | start_perf, _ = utils.test(target_model, device, target_test_loader)
71 | logging.info('{}->{} performance (After {}): {:.2f}'.format(source, target, da_strat, start_perf))
72 | logging.info('------------------------------------------------------')
73 |
74 |
75 | # Run Active DA
76 | # Keep track of labeled vs unlabeled data
77 | idxs_lb = np.zeros(len(tgt_dset.train_idx), dtype=bool)
78 |
79 | budget = np.round(len(tgt_dset.train_idx) * cfg.ADA.BUDGET) if cfg.ADA.BUDGET <= 1.0 else np.round(cfg.ADA.BUDGET)
80 | budget_allocator = BudgetAllocator(budget=budget, cfg=cfg)
81 |
82 | tqdm_rat = trange(cfg.TRAINER.MAX_EPOCHS)
83 | target_accs = defaultdict(list)
84 | target_accs[0.0].append(start_perf)
85 |
86 | for epoch in tqdm_rat:
87 |
88 | curr_budget, used_budget = budget_allocator.get_budget(epoch)
89 | tqdm_rat.set_description('# Target labels={} Allowed labels={}'.format(used_budget, curr_budget))
90 | tqdm_rat.refresh()
91 |
92 | # Select instances via AL strategy
93 | if curr_budget > 0:
94 | logging.info('Selecting instances...')
95 | idxs = sampling_strategy.query(curr_budget, epoch)
96 | idxs_lb[idxs] = True
97 | sampling_strategy.update(idxs_lb)
98 | else:
99 | logging.info('No budget for current epoch, skipped...')
100 |
101 | # Update model with new data via DA strategy
102 | target_model = sampling_strategy.train(epoch=epoch)
103 |
104 | # Evaluate on target test and train splits
105 | if cfg.TRAINER.EVAL_ACC:
106 | test_perf, _ = utils.test(target_model, device, target_test_loader)
107 | out_str = '{}->{} Test performance (Epoch {}, # Target labels={:d}): {:.2f}'.format(source, target, epoch,
108 | int(curr_budget+used_budget), test_perf)
109 | logging.info(out_str)
110 | logging.info('------------------------------------------------------')
111 |
112 | target_accs[curr_budget+used_budget].append(test_perf)
113 |
114 |
115 | logging.info("\n{}".format(target_accs))
116 |
117 | return target_accs
118 |
119 |
120 | def ADAtrain(cfg, task):
121 | logging.info("Running task: {}".format(task))
122 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
123 | transforms = build_transforms(cfg, 'source')
124 | src_dset = ASDADataset(cfg.DATASET.NAME, cfg.DATASET.SOURCE_DOMAIN, data_dir=cfg.DATASET.ROOT,
125 | num_classes=cfg.DATASET.NUM_CLASS, batch_size=cfg.DATALOADER.BATCH_SIZE,
126 | num_workers=cfg.DATALOADER.NUM_WORKERS, transforms=transforms)
127 |
128 | src_train_loader, src_valid_loader, src_test_loader = src_dset.get_loaders(valid_type=cfg.DATASET.SOURCE_VALID_TYPE,
129 | valid_ratio=cfg.DATASET.SOURCE_VALID_RATIO)
130 |
131 | # model
132 | source_model = get_model(cfg.MODEL.BACKBONE.NAME, num_cls=cfg.DATASET.NUM_CLASS, normalize=cfg.MODEL.NORMALIZE, temp=cfg.MODEL.TEMP).to(device)
133 | source_file = '{}_{}_source_{}.pth'.format(cfg.DATASET.SOURCE_DOMAIN, cfg.MODEL.BACKBONE.NAME, cfg.TRAINER.MAX_SOURCE_EPOCHS)
134 | source_dir = os.path.join('checkpoints', 'source')
135 | if not os.path.exists(source_dir):
136 | os.makedirs(source_dir)
137 | source_path = os.path.join(source_dir, source_file)
138 | best_source_file = '{}_{}_source_best_{}.pth'.format(cfg.DATASET.SOURCE_DOMAIN, cfg.MODEL.BACKBONE.NAME, cfg.TRAINER.MAX_SOURCE_EPOCHS)
139 | best_source_path = os.path.join(source_dir, best_source_file)
140 |
141 | if cfg.TRAINER.TRAIN_ON_SOURCE and cfg.TRAINER.MAX_SOURCE_EPOCHS>0:
142 | if cfg.TRAINER.LOAD_FROM_CHECKPOINT and os.path.exists(source_path):
143 | logging.info('Loading source checkpoint: {}'.format(source_path))
144 | source_model.load_state_dict(torch.load(source_path, map_location=device), strict=False)
145 | best_source_model = source_model
146 | else:
147 | logging.info('Training {} model...'.format(cfg.DATASET.SOURCE_DOMAIN))
148 | best_val_acc, best_source_model = 0.0, None
149 | source_optimizer = utils.get_optim(cfg.OPTIM.SOURCE_NAME, source_model.parameters(cfg.OPTIM.SOURCE_LR, cfg.OPTIM.BASE_LR_MULT), lr=cfg.OPTIM.SOURCE_LR)
150 |
151 | for epoch in range(cfg.TRAINER.MAX_SOURCE_EPOCHS):
152 | utils.train(source_model, device, src_train_loader, source_optimizer, epoch)
153 |
154 | val_acc, _ = utils.test(source_model, device, src_valid_loader, split="source valid")
155 | logging.info('[Epoch: {}] Valid Accuracy: {:.3f} '.format(epoch, val_acc))
156 |
157 | if (val_acc > best_val_acc):
158 | best_val_acc = val_acc
159 | best_source_model = copy.deepcopy(source_model)
160 | torch.save(best_source_model.state_dict(), best_source_path)
161 |
162 | del source_model
163 | # rename file in case of abnormal exit
164 | shutil.move(best_source_path, source_path)
165 | else:
166 | best_source_model = source_model
167 |
168 | # Evaluate on source test set
169 | if cfg.TRAINER.EVAL_ACC:
170 | test_acc, _ = utils.test(best_source_model, device, src_test_loader, split="source test")
171 | logging.info('{} Test Accuracy: {:.3f} '.format(cfg.DATASET.SOURCE_DOMAIN, test_acc))
172 |
173 | # Run active adaptation experiments
174 | target_accs = run_active_adaptation(cfg, best_source_model, src_dset, cfg.DATASET.NUM_CLASS, device)
175 | pp.pprint(target_accs)
176 |
177 |
178 |
179 | def main():
180 | parser = argparse.ArgumentParser(description='Local Context-Aware Active Domain Adaptation')
181 | parser.add_argument('--cfg', default='', metavar='FILE', help='path to config file', type=str)
182 | parser.add_argument('--timestamp', default=time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()),
183 | type=str, help='timestamp')
184 | parser.add_argument('--gpu', default='0', type=str, help='which gpu to use')
185 | parser.add_argument('--note', default=None, type=str, help='note to experiment')
186 | parser.add_argument('--log', default='./log', type=str, help='logging directory')
187 | parser.add_argument('--nolog', action='store_true', help='whether to disable the logger')
188 | parser.add_argument("opts", help="Modify config options using the command-line",
189 | default=None, nargs=argparse.REMAINDER)
190 |
191 | args = parser.parse_args()
192 |
193 | cfg.merge_from_file(args.cfg)
194 | cfg.merge_from_list(args.opts)
195 |
196 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
197 |
198 | logger = "{}_{}_{}_{}".format(args.timestamp, cfg.DATASET.NAME, cfg.ADA.AL, cfg.ADA.DA) if not args.nolog else None
199 | init_logger(logger, dir=args.log)
200 |
201 | if args.note is not None:
202 | logging.info("Experiment note : {}".format(args.note))
203 | logging.info('------------------------------------------------------')
204 | logging.info("Running on {} gpu={} pid={}".format(hostName, args.gpu, pid))
205 | logging.info(cfg)
206 | logging.info('------------------------------------------------------')
207 |
208 | if type(cfg.SEED) is tuple or type(cfg.SEED) is list:
209 | seeds = cfg.SEED
210 | else:
211 | seeds = [cfg.SEED]
212 |
213 | for seed in seeds:
214 | logging.info("Using random seed: {}".format(seed))
215 | resetRNGseed(seed)
216 |
217 | if cfg.ADA.TASKS is not None:
218 | ada_tasks = cfg.ADA.TASKS
219 | else:
220 | ada_tasks = [[source, target] for source in cfg.DATASET.SOURCE_DOMAINS
221 | for target in cfg.DATASET.TARGET_DOMAINS if source != target]
222 |
223 | for [source, target] in ada_tasks:
224 | cfg.DATASET.SOURCE_DOMAIN = source
225 | cfg.DATASET.TARGET_DOMAIN = target
226 |
227 | cfg.freeze()
228 | ADAtrain(cfg, task=source + '-->' + target)
229 | cfg.defrost()
230 |
231 |
232 | if __name__ == '__main__':
233 | main()
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import get_model
2 | from .adaptor import AdaptNet
3 | from .network import TaskNet
--------------------------------------------------------------------------------
/model/adaptor.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 |
5 | from .models import register_model, get_model
6 |
7 | @register_model('AdaptNet')
8 | class AdaptNet(nn.Module):
9 | "Defines an Adapt Network."
10 |
11 | def __init__(self, num_cls=10, model='LeNet', src_weights_init=None, weights_init=None, weight_sharing='full',
12 | normalize=False, temp=0.05):
13 | super(AdaptNet, self).__init__()
14 | self.name = 'AdaptNet'
15 | self.base_model = model
16 |
17 | self.num_cls = num_cls
18 | self.cls_criterion = nn.CrossEntropyLoss()
19 | self.gan_criterion = nn.CrossEntropyLoss()
20 | self.weight_sharing = weight_sharing
21 | self.normalize = normalize
22 | self.temp = temp
23 | self.setup_net()
24 | if weights_init is not None:
25 | self.load(weights_init)
26 | elif src_weights_init is not None:
27 | self.load_src_net(src_weights_init)
28 | else:
29 | raise Exception('AdaptNet must be initialized with weights.')
30 |
31 | def custom_copy(self, src_net, weight_sharing):
32 | """
33 | Vary degree of weight sharing between source and target CNN's
34 | """
35 | tgt_net = copy.deepcopy(src_net)
36 | if weight_sharing != 'None':
37 | if weight_sharing == 'classifier':
38 | tgt_net.classifier = src_net.classifier
39 | elif weight_sharing == 'full':
40 | tgt_net = src_net
41 | return tgt_net
42 |
43 | def setup_net(self):
44 | """Setup source, target and discriminator networks."""
45 | self.src_net = get_model(self.base_model, num_cls=self.num_cls, normalize=self.normalize, temp=self.temp)
46 | self.tgt_net = self.custom_copy(self.src_net, self.weight_sharing)
47 |
48 | input_dim = self.num_cls
49 | self.discrim = nn.Sequential(
50 | nn.Linear(input_dim, 500),
51 | nn.ReLU(),
52 | nn.Linear(500, 500),
53 | nn.ReLU(),
54 | nn.Linear(500, 2),
55 | )
56 |
57 | self.image_size = self.src_net.image_size
58 | self.num_channels = self.src_net.num_channels
59 |
60 |
61 | def load(self, init_path):
62 | "Loads full src and tgt models."
63 | net_init_dict = torch.load(init_path, map_location=torch.device('cpu'))
64 | self.load_state_dict(net_init_dict, strict=False)
65 |
66 | def load_src_net(self, init_path):
67 | """Initialize source and target with source
68 | weights."""
69 | if type(init_path) is str:
70 | self.src_net.load(init_path)
71 | self.tgt_net.load(init_path)
72 | else:
73 | # initialize from model
74 | self.src_net.load_state_dict(init_path.state_dict())
75 | self.tgt_net.load_state_dict(init_path.state_dict())
76 |
77 | def save(self, out_path):
78 | torch.save(self.state_dict(), out_path)
79 |
80 | def save_tgt_net(self, out_path):
81 | torch.save(self.tgt_net.state_dict(), out_path)
--------------------------------------------------------------------------------
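
A small sketch of constructing `AdaptNet` from an already trained source network; the backbone name and class count are illustrative, and `src_weights_init` may be either a checkpoint path or a model instance, as handled by `load_src_net`:

    from model import get_model, AdaptNet

    source_model = get_model('ResNet50Fc', num_cls=31)    # stands in for a trained source model
    adapt_net = AdaptNet(num_cls=31, model='ResNet50Fc',
                         src_weights_init=source_model,   # copies weights into src_net and tgt_net
                         weight_sharing='full')           # 'full' shares all weights between the two
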
/model/grl.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Any, Tuple
2 | import numpy as np
3 | import torch.nn as nn
4 | from torch.autograd import Function
5 | import torch
6 |
7 |
8 | class GradientReverseFunction(Function):
9 |
10 | @staticmethod
11 | def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
12 | ctx.coeff = coeff
13 | output = input * 1.0
14 | return output
15 |
16 | @staticmethod
17 | def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
18 | return grad_output.neg() * ctx.coeff, None
19 |
20 |
21 | class GradientReverseLayer(nn.Module):
22 | def __init__(self):
23 | super(GradientReverseLayer, self).__init__()
24 |
25 | def forward(self, *input):
26 | return GradientReverseFunction.apply(*input)
27 |
28 |
29 | class WarmStartGradientReverseLayer(nn.Module):
30 | """Gradient Reverse Layer :math:`\mathcal{R}(x)` with warm start
31 |
32 | The forward and backward behaviours are:
33 |
34 | .. math::
35 | \mathcal{R}(x) = x,
36 |
37 | \dfrac{ d\mathcal{R}} {dx} = - \lambda I.
38 |
39 | :math:`\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule:
40 |
41 | .. math::
42 | \lambda = \dfrac{2(hi-lo)}{1+\exp(- α \dfrac{i}{N})} - (hi-lo) + lo
43 |
44 | where :math:`i` is the iteration step.
45 |
46 | Args:
47 | alpha (float, optional): :math:`α`. Default: 1.0
48 | lo (float, optional): Initial value of :math:`\lambda`. Default: 0.0
49 | hi (float, optional): Final value of :math:`\lambda`. Default: 1.0
50 | max_iters (int, optional): :math:`N`. Default: 1000
51 | auto_step (bool, optional): If True, increase :math:`i` each time `forward` is called.
52 | Otherwise use function `step` to increase :math:`i`. Default: False
53 | """
54 |
55 | def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1.,
56 | max_iters: Optional[int] = 1000., auto_step: Optional[bool] = False):
57 | super(WarmStartGradientReverseLayer, self).__init__()
58 | self.alpha = alpha
59 | self.lo = lo
60 | self.hi = hi
61 | self.iter_num = 0
62 | self.max_iters = max_iters
63 | self.auto_step = auto_step
64 |
65 | def forward(self, input: torch.Tensor) -> torch.Tensor:
66 | """"""
67 | coeff = float(2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters))
68 | - (self.hi - self.lo) + self.lo)
69 | if self.auto_step:
70 | self.step()
71 | return GradientReverseFunction.apply(input, coeff)
72 |
73 | def step(self):
74 | """Increase iteration number :math:`i` by 1"""
75 | self.iter_num += 1
76 |
--------------------------------------------------------------------------------
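
A short sketch of the warm-start gradient reversal layer: the forward pass is the identity, while the backward pass multiplies incoming gradients by -λ, where λ ramps from `lo` toward `hi` over `max_iters` calls:

    import torch
    from model.grl import WarmStartGradientReverseLayer

    grl = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=1.0, max_iters=1000, auto_step=True)

    x = torch.randn(4, 8, requires_grad=True)
    y = grl(x)              # identity in the forward pass
    y.sum().backward()
    print(x.grad[0, :3])    # gradients scaled by -lambda (lambda starts at lo=0 and grows per call)
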
/model/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | models = {}
4 | def register_model(name):
5 | def decorator(cls):
6 | models[name] = cls
7 | return cls
8 | return decorator
9 |
10 | def get_model(name, num_cls=10, **args):
11 | net = models[name](num_cls=num_cls, **args)
12 | if torch.cuda.is_available():
13 | net = net.cuda()
14 | return net
15 |
--------------------------------------------------------------------------------
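
The registry above lets backbones be declared with a decorator and constructed by name; a tiny sketch with a hypothetical toy backbone:

    import torch.nn as nn
    from model.models import register_model, get_model

    @register_model('TinyNet')                 # hypothetical example, not part of the repo
    class TinyNet(nn.Module):
        def __init__(self, num_cls=10, **kwargs):
            super().__init__()
            self.classifier = nn.Linear(32, num_cls)

        def forward(self, x):
            return self.classifier(x)

    net = get_model('TinyNet', num_cls=31)     # looked up by name, moved to GPU when available
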
/model/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models
4 | import torch.nn.functional as F
5 | from .grl import GradientReverseFunction
6 | from .models import register_model
7 |
8 |
9 | class TaskNet(nn.Module):
10 | num_channels = 3
11 | image_size = 32
12 | name = 'TaskNet'
13 |
14 | def __init__(self, num_cls=10, normalize=False, temp=0.05):
15 | super(TaskNet, self).__init__()
16 | self.num_cls = num_cls
17 | self.setup_net()
18 | self.criterion = nn.CrossEntropyLoss()
19 | self.normalize = normalize
20 | self.temp = temp
21 |
22 | def forward(self, x, with_emb=False, reverse_grad=False):
23 | x = self.conv_params(x)
24 | x = x.view(x.size(0), -1)
25 | #x = x.clone()
26 | emb = self.fc_params(x)
27 |
28 | if isinstance(self.classifier, nn.Sequential): # LeNet
29 | emb = self.classifier[:-1](emb)
30 | if reverse_grad: emb = GradientReverseFunction.apply(emb)
31 | if self.normalize: emb = F.normalize(emb) / self.temp
32 | score = self.classifier[-1](emb)
33 | else: # ResNet
34 | if reverse_grad: emb = GradientReverseFunction.apply(emb)
35 | if self.normalize: emb = F.normalize(emb) / self.temp
36 | score = self.classifier(emb)
37 |
38 | if with_emb:
39 | return score, emb
40 | else:
41 | return score
42 |
43 | def setup_net(self):
44 | """Method to be implemented in each class."""
45 | pass
46 |
47 | def load(self, init_path):
48 | net_init_dict = torch.load(init_path)
49 | self.load_state_dict(net_init_dict, strict=False)
50 |
51 | def save(self, out_path):
52 | torch.save(self.state_dict(), out_path)
53 |
54 | def parameters(self, lr, lr_scalar=0.1):
55 | parameter_list = [
56 | {'params': self.conv_params.parameters(), 'lr': lr * lr_scalar},
57 | {'params': self.fc_params.parameters(), 'lr': lr},
58 | {'params': self.classifier.parameters(), 'lr': lr},
59 | ]
60 |
61 | return parameter_list
62 |
63 |
64 | @register_model('ResNet34Fc')
65 | class ResNet34Fc(TaskNet):
66 | num_channels = 3
67 | name = 'ResNet34Fc'
68 |
69 | def setup_net(self):
70 | model = torchvision.models.resnet34(pretrained=True)
71 | model.fc = nn.Identity()
72 | self.conv_params = model
73 | self.fc_params = nn.Linear(512, 512)
74 | self.classifier = nn.Linear(512, self.num_cls, bias=False)
75 |
76 |
77 | class BatchNorm1d(nn.Module):
78 | def __init__(self, dim):
79 | super(BatchNorm1d, self).__init__()
80 | self.BatchNorm1d = nn.BatchNorm1d(dim)
81 |
82 | def __call__(self, x):
83 | if x.size(0) == 1:
84 | x = torch.cat((x,x), 0)
85 | x = self.BatchNorm1d(x)[:1]
86 | else:
87 | x = self.BatchNorm1d(x)
88 | return x
89 |
90 |
91 | @register_model('ResNet50Fc')
92 | class ResNet50Fc(TaskNet):
93 | num_channels = 3
94 | name = 'ResNet50Fc'
95 |
96 | def setup_net(self):
97 | model = torchvision.models.resnet50(pretrained=True)
98 | model.fc = nn.Identity()
99 | self.conv_params = model
100 | self.fc_params = nn.Sequential(nn.Linear(2048, 256), BatchNorm1d(256))
101 | self.classifier = nn.Linear(256, self.num_cls)
102 |
103 |
104 |
--------------------------------------------------------------------------------
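
A quick sketch of running the ResNet-50 task network defined above on a random batch (torchvision's pretrained weights are downloaded on first use):

    import torch
    from model import get_model

    net = get_model('ResNet50Fc', num_cls=31)
    x = torch.randn(2, 3, 224, 224)
    if torch.cuda.is_available():
        x = x.cuda()
    scores, emb = net(x, with_emb=True)
    print(scores.shape, emb.shape)    # expected: [2, 31] logits and [2, 256] embeddings
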
/run.sh:
--------------------------------------------------------------------------------
1 | # Table 1
2 | # office-home 5%-budget
3 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA ft
4 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA mme
5 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA RAA
6 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA LAA
7 |
8 | # office-home 10%-budget
9 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA RAA LADA.S_M 5 ADA.BUDGET 0.1
10 | python main.py --cfg configs/officehome.yaml --gpu 0 --log log/oh/LADA ADA.AL LAS ADA.DA LAA LADA.S_M 5 ADA.BUDGET 0.1
11 |
12 | # office-home rsut 10%-budget
13 | python main.py --cfg configs/officehome_RSUT.yaml --gpu 0 --log log/oh_RSUT/LADA ADA.AL LAS ADA.DA RAA LADA.S_M 5 ADA.BUDGET 0.1
14 | python main.py --cfg configs/officehome_RSUT.yaml --gpu 0 --log log/oh_RSUT/LADA ADA.AL LAS ADA.DA LAA LADA.S_M 5 ADA.BUDGET 0.1
15 |
16 | # office-31 5%-budget
17 | python main.py --cfg configs/office31.yaml --gpu 0 --log log/office31/LADA ADA.AL LAS ADA.DA LAA
18 |
19 | # visda 5%-budget
20 | python main.py --cfg configs/visda.yaml --gpu 0 --log log/visda/LADA ADA.AL LAS ADA.DA LAA
21 |
22 |
23 |
--------------------------------------------------------------------------------
/solver/CDACsolver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from dataset.image_list import ImageList
5 | from .solver import BaseSolver, register_solver
6 | from collections import defaultdict
7 | import numpy as np
8 | from dataset.transform import rand_transform2
9 |
10 | def get_losses_unlabeled(net, im_data, im_data_bar, im_data_bar2, target, BCE, w_cons, device):
11 | """ Get losses for unlabeled samples."""
12 |
13 | output, feat = net(im_data, with_emb=True, reverse_grad=True)
14 | output_bar, feat_bar = net(im_data_bar, with_emb=True, reverse_grad=True)
15 | prob, prob_bar = F.softmax(output, dim=1), F.softmax(output_bar, dim=1)
16 |
17 | # loss for adversarial adaptive clustering
18 | aac_loss = advbce_unlabeled(target=target, feat=feat, prob=prob, prob_bar=prob_bar, device=device, bce=BCE)
19 |
20 | output = net.forward_emb(feat)
21 | output_bar = net.forward_emb(feat_bar)
22 | output_bar2 = net(im_data_bar2)
23 |
24 | prob = F.softmax(output, dim=1)
25 | prob_bar = F.softmax(output_bar, dim=1)
26 | prob_bar2 = F.softmax(output_bar2, dim=1)
27 |
28 | max_probs, pseudo_labels = torch.max(prob.detach_(), dim=-1)
29 | mask = max_probs.ge(0.95).float()
30 |
31 | # loss for pseudo labeling
32 | pl_loss = (F.cross_entropy(output_bar2, pseudo_labels, reduction='none') * mask).mean()
33 |
34 | # loss for consistency
35 | con_loss = w_cons * F.mse_loss(prob_bar, prob_bar2)
36 |
37 | return aac_loss, pl_loss, con_loss
38 |
39 |
40 | def advbce_unlabeled(target, feat, prob, prob_bar, device, bce):
41 |     """ Construct adversarial adaptive clustering loss."""
42 | target_ulb = pairwise_target(feat, target, device)
43 | prob_bottleneck_row, _ = PairEnum2D(prob)
44 | _, prob_bottleneck_col = PairEnum2D(prob_bar)
45 | adv_bce_loss = -bce(prob_bottleneck_row, prob_bottleneck_col, target_ulb)
46 | return adv_bce_loss
47 |
48 |
49 | def pairwise_target(feat, target, device, topk=5):
50 | """ Produce pairwise similarity label."""
51 | feat_detach = feat.detach()
52 | # For unlabeled data
53 | if target is None:
54 | rank_feat = feat_detach
55 | rank_idx = torch.argsort(rank_feat, dim=1, descending=True)
56 | rank_idx1, rank_idx2 = PairEnum2D(rank_idx)
57 | rank_idx1, rank_idx2 = rank_idx1[:, :topk], rank_idx2[:, :topk]
58 | rank_idx1, _ = torch.sort(rank_idx1, dim=1)
59 | rank_idx2, _ = torch.sort(rank_idx2, dim=1)
60 | rank_diff = rank_idx1 - rank_idx2
61 | rank_diff = torch.sum(torch.abs(rank_diff), dim=1)
62 | target_ulb = torch.ones_like(rank_diff).float().to(device)
63 | target_ulb[rank_diff > 0] = 0
64 | # For labeled data
65 | elif target is not None:
66 | target_row, target_col = PairEnum1D(target)
67 | target_ulb = torch.zeros(target.size(0) * target.size(0)).float().to(device)
68 | target_ulb[target_row == target_col] = 1
69 | else:
70 | raise ValueError('Please check your target.')
71 | return target_ulb
72 |
73 |
74 | def PairEnum1D(x):
75 | """ Enumerate all pairs of feature in x with 1 dimension."""
76 | assert x.ndimension() == 1, 'Input dimension must be 1'
77 | x1 = x.repeat(x.size(0), )
78 | x2 = x.repeat(x.size(0)).view(-1, x.size(0)).transpose(1, 0).reshape(-1)
79 | return x1, x2
80 |
81 |
82 | def PairEnum2D(x):
83 | """ Enumerate all pairs of feature in x with 2 dimensions."""
84 | assert x.ndimension() == 2, 'Input dimension must be 2'
85 | x1 = x.repeat(x.size(0), 1)
86 | x2 = x.repeat(1, x.size(0)).view(-1, x.size(1))
87 | return x1, x2
88 |
89 |
90 | def sigmoid_rampup(current, rampup_length):
91 | """ Exponential rampup from https://arxiv.org/abs/1610.02242"""
92 | if rampup_length == 0:
93 | return 1.0
94 | else:
95 | current = np.clip(current, 0.0, rampup_length)
96 | phase = 1.0 - current / rampup_length
97 | return float(np.exp(-5.0 * phase * phase))
98 |
99 |
100 | class BCE(nn.Module):
101 | eps = 1e-7
102 |
103 | def forward(self, prob1, prob2, simi):
104 | P = prob1.mul_(prob2)
105 | P = P.sum(1)
106 | P.mul_(simi).add_(simi.eq(-1).type_as(P))
107 | neglogP = -P.add_(BCE.eps).log_()
108 | return neglogP.mean()
109 |
110 |
111 | class BCE_softlabels(nn.Module):
112 | """ Construct binary cross-entropy loss."""
113 | eps = 1e-7
114 |
115 | def forward(self, prob1, prob2, simi):
116 | P = prob1.mul_(prob2)
117 | P = P.sum(1)
118 | neglogP = - (simi * torch.log(P + BCE.eps) + (1. - simi) * torch.log(1. - P + BCE.eps))
119 | return neglogP.mean()
120 |
121 |
122 | def inv_lr_scheduler(param_lr, optimizer, iter_num, gamma=0.0001,
123 | power=0.75, init_lr=0.001):
124 |     """Inverse decay of the learning rate: lr = init_lr * (1 + gamma * iter_num) ** (-power)."""
125 | lr = init_lr * (1 + gamma * iter_num) ** (- power)
126 | i = 0
127 | for param_group in optimizer.param_groups:
128 | param_group['lr'] = lr * param_lr[i]
129 | i += 1
130 | return optimizer
131 |
132 |
133 | def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
134 |     return float(2.0 * (high - low) /
135 |                  (1.0 + np.exp(- alpha * iter_num / max_iter)) -
136 |                  (high - low) + low)
137 |
138 |
139 | @register_solver('CDAC')
140 | class CDACSolver(BaseSolver):
141 | """
142 | Implements Cross-Domain Adaptive Clustering for Semi-Supervised Domain Adaptation: https://arxiv.org/abs/2104.09415
143 | https://github.com/lijichang/CVPR2021-SSDA
144 | """
145 |
146 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt,
147 | ada_stage, device, cfg, **kwargs):
148 | super(CDACSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader,
149 | joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs)
150 |
151 | def solve(self, epoch):
152 | self.net.train()
153 |
154 | self.tgt_unsup_loader.dataset.rand_transform = rand_transform2
155 | self.tgt_unsup_loader.dataset.rand_num = 2
156 |
157 | data_iter_s = iter(self.src_loader)
158 | data_iter_t = iter(self.tgt_sup_loader)
159 | data_iter_t_unl = iter(self.tgt_unsup_loader)
160 |
161 | len_train_source = len(self.src_loader)
162 | len_train_target = len(self.tgt_sup_loader)
163 | len_train_target_semi = len(self.tgt_unsup_loader)
164 |
165 | BCE = BCE_softlabels().to(self.device)
166 | criterion = nn.CrossEntropyLoss().to(self.device)
167 |
168 | iter_per_epoch = len(self.src_loader)
169 | for batch_idx in range(iter_per_epoch):
170 | rampup = sigmoid_rampup(batch_idx+epoch*iter_per_epoch, 20000)
171 | w_cons = 30.0 * rampup
172 |
173 | self.tgt_opt = inv_lr_scheduler([0.1, 1.0, 1.0], self.tgt_opt, batch_idx+epoch*iter_per_epoch,
174 | init_lr=0.01)
175 |
176 | if len(self.tgt_sup_loader) > 0:
177 | if batch_idx % len_train_target == 0:
178 | data_iter_t = iter(self.tgt_sup_loader)
179 | if batch_idx % len_train_target_semi == 0:
180 | data_iter_t_unl = iter(self.tgt_unsup_loader)
181 | if batch_idx % len_train_source == 0:
182 | data_iter_s = iter(self.src_loader)
183 | data_t = next(data_iter_t)
184 | data_t_unl = next(data_iter_t_unl)
185 | data_s = next(data_iter_s)
186 |
187 | # load labeled source data
188 | x_s, target_s = data_s[0], data_s[1]
189 | im_data_s = x_s.to(self.device)
190 | gt_labels_s = target_s.to(self.device)
191 |
192 | # load labeled target data
193 | x_t, target_t = data_t[0], data_t[1]
194 | im_data_t = x_t.to(self.device)
195 | gt_labels_t = target_t.to(self.device)
196 |
197 | # load unlabeled target data
198 | x_tu, x_bar_tu, x_bar2_tu = data_t_unl[0], data_t_unl[3], data_t_unl[4]
199 | im_data_tu = x_tu.to(self.device)
200 | im_data_bar_tu = x_bar_tu.to(self.device)
201 | im_data_bar2_tu = x_bar2_tu.to(self.device)
202 |
203 | self.tgt_opt.zero_grad()
204 | # construct losses for overall labeled data
205 | data = torch.cat((im_data_s, im_data_t), 0)
206 | target = torch.cat((gt_labels_s, gt_labels_t), 0)
207 | out1 = self.net(data)
208 | ce_loss = criterion(out1, target)
209 |
210 | ce_loss.backward(retain_graph=True)
211 | self.tgt_opt.step()
212 | self.tgt_opt.zero_grad()
213 |
214 | # construct losses for unlabeled target data
215 | aac_loss, pl_loss, con_loss = get_losses_unlabeled(self.net, im_data=im_data_tu, im_data_bar=im_data_bar_tu,
216 | im_data_bar2=im_data_bar2_tu, target=None, BCE=BCE,
217 | w_cons=w_cons, device=self.device)
218 | loss = (aac_loss + pl_loss + con_loss) * self.cfg.ADA.UNSUP_WT * 10
219 | else:
220 | if batch_idx % len_train_source == 0:
221 | data_iter_s = iter(self.src_loader)
222 | data_s, label_s, _ = next(data_iter_s)
223 | data_s, label_s = data_s.to(self.device), label_s.to(self.device)
224 |
225 | self.tgt_opt.zero_grad()
226 | output_s = self.net(data_s)
227 | loss = nn.CrossEntropyLoss()(output_s, label_s) * self.cfg.ADA.SRC_SUP_WT
228 |
229 | loss.backward()
230 | self.tgt_opt.step()
231 |
232 |
233 |
234 |
235 |
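To make the pairwise pseudo-labeling above concrete: for unlabeled data, `pairwise_target` marks a pair as positive only when the two samples share the same top-k feature dimensions. A tiny check, assuming the repo root is on PYTHONPATH:

```python
import torch
from solver.CDACsolver import pairwise_target

feat = torch.tensor([[0.9, 0.1, 0.0],
                     [0.8, 0.2, 0.0],
                     [0.0, 0.1, 0.9]])
# With topk=2, samples 0 and 1 both rank dimensions (0, 1) highest, so their pairs
# get label 1; sample 2 ranks (2, 1) highest and is labeled 0 against the others.
print(pairwise_target(feat, target=None, device='cpu', topk=2))
# -> tensor([1., 1., 0., 1., 1., 0., 0., 0., 1.])
```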
--------------------------------------------------------------------------------
/solver/MCCsolver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import copy
5 | import logging
6 | import numpy as np
7 | from .solver import BaseSolver, register_solver
8 |
9 | def Entropy(input_):
10 | bs = input_.size(0)
11 | epsilon = 1e-5
12 | entropy = -input_ * torch.log(input_ + epsilon)
13 | entropy = torch.sum(entropy, dim=1)
14 | return entropy
15 |
16 |
17 | @register_solver('MCC')
18 | class MCCSolver(BaseSolver):
19 | """
20 | Implements MCC from Minimum Class Confusion for Versatile Domain Adaptation: https://arxiv.org/abs/1912.03699
21 | https://github.com/thuml/Versatile-Domain-Adaptation
22 | """
23 |
24 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt,
25 | ada_stage, device, cfg, **kwargs):
26 | super(MCCSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader,
27 | joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs)
28 |
29 | def solve(self, epoch):
30 | src_iter = iter(self.src_loader)
31 | tgt_un_iter = iter(self.tgt_unsup_loader)
32 | tgt_s_iter = iter(self.tgt_sup_loader)
33 | iter_per_epoch = len(self.src_loader)
34 |
35 | self.net.train()
36 |
37 | for batch_idx in range(iter_per_epoch):
38 | if batch_idx % len(self.src_loader) == 0:
39 | src_iter = iter(self.src_loader)
40 |
41 | if batch_idx % len(self.tgt_unsup_loader) == 0:
42 | tgt_un_iter = iter(self.tgt_unsup_loader)
43 |
44 | data_s, label_s, _ = next(src_iter)
45 | data_s, label_s = data_s.to(self.device), label_s.to(self.device)
46 |
47 | self.tgt_opt.zero_grad()
48 | output_s = self.net(data_s)
49 | loss = nn.CrossEntropyLoss()(output_s, label_s) * self.cfg.ADA.SRC_SUP_WT
50 |
51 | if len(self.tgt_sup_loader) > 0:
52 | try:
53 | data_ts, label_ts, idx_ts = next(tgt_s_iter)
54 | except:
55 | tgt_s_iter = iter(self.tgt_sup_loader)
56 | data_ts, label_ts, idx_ts = next(tgt_s_iter)
57 |
58 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)
59 | output_ts = self.net(data_ts)
60 |
61 | loss += nn.CrossEntropyLoss()(output_ts, label_ts)
62 |
63 | data_tu, label_tu, _ = next(tgt_un_iter)
64 | data_tu, label_tu = data_tu.to(self.device), label_tu.to(self.device)
65 | output_tu = self.net(data_tu)
66 |
67 | outputs_target_temp = output_tu / self.cfg.MODEL.TEMP
68 | target_softmax_out_temp = nn.Softmax(dim=1)(outputs_target_temp)
69 | target_entropy_weight = Entropy(target_softmax_out_temp).detach()
70 | target_entropy_weight = 1 + torch.exp(-target_entropy_weight)
71 | target_entropy_weight = self.cfg.DATALOADER.BATCH_SIZE * target_entropy_weight / torch.sum(target_entropy_weight)
72 | cov_matrix_t = target_softmax_out_temp.mul(target_entropy_weight.view(-1, 1)).transpose(1, 0).mm(
73 | target_softmax_out_temp)
74 | cov_matrix_t = cov_matrix_t / (torch.sum(cov_matrix_t, dim=1)+1e-12)
75 | mcc_loss = (torch.sum(cov_matrix_t) - torch.trace(cov_matrix_t)) / self.cfg.DATASET.NUM_CLASS
76 |
77 | loss += mcc_loss
78 |
79 | loss.backward()
80 | self.tgt_opt.step()
81 |
82 |
83 |
84 |
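The class-confusion term in `solve()` above can be read in isolation: predictions are sharpened by a temperature, weighted by confidence (inverse entropy), accumulated into a C x C confusion matrix, and the off-diagonal mass is penalized. A self-contained sketch with illustrative values (the temperature below is not the repo's configured one):

```python
import torch
import torch.nn as nn

def mcc_loss(logits, temperature=2.5):
    probs = nn.Softmax(dim=1)(logits / temperature)            # (B, C)
    ent = -(probs * torch.log(probs + 1e-5)).sum(dim=1)        # per-sample entropy
    w = 1 + torch.exp(-ent)                                    # confident samples weigh more
    w = logits.size(0) * w / w.sum()
    conf = (probs * w.view(-1, 1)).t().mm(probs)               # (C, C) class-confusion matrix
    conf = conf / (conf.sum(dim=1) + 1e-12)
    return (conf.sum() - conf.trace()) / logits.size(1)

print(mcc_loss(torch.randn(8, 5)))
```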
--------------------------------------------------------------------------------
/solver/MMEsolver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .solver import BaseSolver, register_solver
6 |
7 | @register_solver('mme')
8 | class MMESolver(BaseSolver):
9 | """
10 | Implements MME from Semi-supervised Domain Adaptation via Minimax Entropy: https://arxiv.org/abs/1904.06487
11 | """
12 |
13 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs):
14 | super(MMESolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage,
15 | device, cfg, **kwargs)
16 |
17 | def solve(self, epoch):
18 | """
19 |         Semi-supervised adaptation via MME: XE on labeled source + XE on labeled target + \
20 | adversarial ent. minimization on unlabeled target
21 | """
22 | self.net.train()
23 |
24 | if not self.ada_stage:
25 | src_sup_wt, lambda_unsup = 1.0, 0.1
26 | else:
27 | src_sup_wt, lambda_unsup = self.cfg.ADA.SRC_SUP_WT, self.cfg.ADA.UNSUP_WT
28 |
29 | tgt_sup_iter = iter(self.tgt_sup_loader)
30 |
31 | joint_loader = zip(self.src_loader, self.tgt_unsup_loader) # changed to tgt_loader to be consistent with CLUE implementation
32 | for batch_idx, ((data_s, label_s, _), (data_tu, label_tu, _)) in enumerate(joint_loader):
33 | data_s, label_s = data_s.to(self.device), label_s.to(self.device)
34 | data_tu = data_tu.to(self.device)
35 |
36 | if self.ada_stage:
37 | try:
38 | data_ts, label_ts, _ = next(tgt_sup_iter)
39 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)
40 | except:
41 | # no labeled target data
42 | try:
43 | tgt_sup_iter = iter(self.tgt_sup_loader)
44 | data_ts, label_ts, _ = next(tgt_sup_iter)
45 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)
46 | except:
47 | data_ts, label_ts = None, None
48 |
49 | # zero gradients for optimizer
50 | self.tgt_opt.zero_grad()
51 |
52 | # extract features
53 | score_s = self.net(data_s)
54 | xeloss_src = src_sup_wt * nn.CrossEntropyLoss()(score_s, label_s)
55 |
56 | xeloss_tgt = 0
57 | if self.ada_stage and data_ts is not None:
58 | score_ts = self.net(data_ts)
59 | xeloss_tgt = nn.CrossEntropyLoss()(score_ts, label_ts)
60 |
61 | xeloss = xeloss_src + xeloss_tgt
62 |
63 | xeloss.backward()
64 | self.tgt_opt.step()
65 |
66 | # Add adversarial entropy
67 | self.tgt_opt.zero_grad()
68 |
69 | score_tu = self.net(data_tu, reverse_grad=True)
70 | probs_tu = F.softmax(score_tu, dim=1)
71 | loss_adent = lambda_unsup * torch.mean(torch.sum(probs_tu * (torch.log(probs_tu + 1e-5)), 1))
72 | loss_adent.backward()
73 |
74 | self.tgt_opt.step()
75 |
76 |
77 |
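The adversarial entropy step above computes the (negative) entropy of the unlabeled-target predictions; the gradient-reversal layer activated by `reverse_grad=True` flips its gradient for part of the network, which yields the minimax behavior of MME. A standalone sketch of the term itself:

```python
import torch
import torch.nn.functional as F

def adentropy(score_tu, lambda_unsup=0.1):
    # score_tu: classifier logits on unlabeled target samples
    probs = F.softmax(score_tu, dim=1)
    return lambda_unsup * torch.mean(torch.sum(probs * torch.log(probs + 1e-5), dim=1))

print(adentropy(torch.randn(4, 10)))
```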
--------------------------------------------------------------------------------
/solver/PAAsolver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import copy
5 | import numpy as np
6 |
7 | from dataset.image_list import ImageList
8 | from dataset.transform import rand_transform
9 | from .solver import BaseSolver, register_solver
10 |
11 |
12 | @register_solver('RAA')
13 | class RAASolver(BaseSolver):
14 | """
15 |     Implements Random Anchor set Augmentation (RAA)
16 | """
17 |
18 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt,
19 | ada_stage, device, cfg, **kwargs):
20 | super(RAASolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader,
21 | joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs)
22 |
23 | def solve(self, epoch, seq_query_loader):
24 | K = self.cfg.LADA.A_K
25 | th = self.cfg.LADA.A_TH
26 |
27 | # create an anchor set
28 | if len(self.tgt_sup_loader) > 0:
29 | tgt_sup_dataset = self.tgt_sup_loader.dataset
30 | tgt_sup_samples = [tgt_sup_dataset.samples[i] for i in self.tgt_sup_loader.sampler.indices]
31 | seed_dataset = ImageList(tgt_sup_samples, root=tgt_sup_dataset.root, transform=tgt_sup_dataset.transform)
32 | seed_dataset.rand_transform = rand_transform
33 | seed_dataset.rand_num = self.cfg.LADA.A_RAND_NUM
34 | seed_loader = torch.utils.data.DataLoader(seed_dataset, shuffle=True,
35 | batch_size=self.tgt_sup_loader.batch_size, num_workers=self.tgt_sup_loader.num_workers)
36 | seed_idxs = self.tgt_sup_loader.sampler.indices.tolist()
37 | seed_iter = iter(seed_loader)
38 | seed_labels = [seed_dataset.samples[i][1] for i in range(len(seed_dataset))]
39 |
40 | if K > 0:
41 | # build nearest neighbors
42 | self.net.eval()
43 | tgt_idxs = []
44 | tgt_embs = []
45 | tgt_labels = []
46 | tgt_data = []
47 | seq_query_loader = copy.deepcopy(seq_query_loader)
48 | seq_query_loader.dataset.transform = copy.deepcopy(self.tgt_loader.dataset.transform)
49 | with torch.no_grad():
50 | for sample_ in seq_query_loader:
51 | sample = copy.deepcopy(sample_)
52 | del sample_
53 | data, label, idx = sample[0], sample[1], sample[2]
54 | data, label = data.to(self.device), label.to(self.device)
55 | score, emb = self.net(data, with_emb=True)
56 | tgt_embs.append(F.normalize(emb).detach().clone().cpu())
57 | tgt_labels.append(label.cpu())
58 | tgt_idxs.append(idx.cpu())
59 | tgt_data.append(data.cpu())
60 |
61 | tgt_embs = torch.cat(tgt_embs)
62 | tgt_data = torch.cat(tgt_data)
63 | tgt_idxs = torch.cat(tgt_idxs)
64 |
65 | self.net.train()
66 |
67 | src_iter = iter(self.src_loader)
68 | iter_per_epoch = len(self.src_loader)
69 |
70 | for batch_idx in range(iter_per_epoch):
71 | if batch_idx % len(self.src_loader) == 0:
72 | src_iter = iter(self.src_loader)
73 |
74 | data_s, label_s, _ = next(src_iter)
75 | data_s, label_s = data_s.to(self.device), label_s.to(self.device)
76 |
77 | self.tgt_opt.zero_grad()
78 | output_s = self.net(data_s)
79 | loss = nn.CrossEntropyLoss()(output_s, label_s)
80 |
81 | if len(self.tgt_sup_loader) > 0:
82 | try:
83 | data_ts, label_ts, idx_ts, *data_rand_ts = next(seed_iter)
84 | except:
85 | seed_iter = iter(seed_loader)
86 | data_ts, label_ts, idx_ts, *data_rand_ts = next(seed_iter)
87 |
88 |
89 | if len(data_rand_ts)>0:
90 | for i, r_data in enumerate(data_rand_ts):
91 | alpha = 0.2
92 | mask = torch.FloatTensor(np.random.beta(alpha, alpha, size=(data_ts.shape[0], 1, 1, 1)))
93 | data_ts = (data_ts * mask) + (r_data * (1 - mask))
94 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)
95 | output_ts, emb_ts = self.net(data_ts, with_emb=True)
96 | loss += nn.CrossEntropyLoss()(output_ts, label_ts)
97 | else:
98 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)
99 | output_ts, emb_ts = self.net(data_ts, with_emb=True)
100 | loss += nn.CrossEntropyLoss()(output_ts, label_ts)
101 |
102 | loss.backward()
103 | self.tgt_opt.step()
104 |
105 | if len(self.tgt_sup_loader) > 0 and K > 0 and len(seed_idxs) < tgt_embs.shape[0]:
106 | nn_idxs = torch.randint(0, tgt_data.shape[0], (data_ts.shape[0],)).to(self.device)
107 |
108 | data_nn = tgt_data[nn_idxs].to(self.device)
109 |
110 | with torch.no_grad():
111 | output_nn, emb_nn = self.net(data_nn, with_emb=True)
112 | prob_nn = torch.softmax(output_nn, dim=-1)
113 | tgt_embs[nn_idxs] = F.normalize(emb_nn).detach().clone().cpu()
114 |
115 | conf_samples = []
116 | conf_idx = []
117 | conf_pl = []
118 | dist = np.eye(prob_nn.shape[-1])[np.array(seed_labels)].sum(0) + 1
119 | sp = 1 - dist / dist.max() + dist.min() / dist.max()
120 |
121 | for i in range(prob_nn.shape[0]):
122 | idx = tgt_idxs[nn_idxs[i]].item()
123 | pl_i = prob_nn[i].argmax(-1).item()
124 | if np.random.random() <= sp[pl_i] and prob_nn[i].max(-1)[0] >= th and idx not in seed_idxs:
125 | conf_samples.append((self.tgt_loader.dataset.samples[idx][0], pl_i))
126 | conf_idx.append(idx)
127 | conf_pl.append(pl_i)
128 |
129 | seed_dataset.add_item(conf_samples)
130 | seed_idxs.extend(conf_idx)
131 | seed_labels.extend(conf_pl)
132 |
133 |
134 | @register_solver('LAA')
135 | class LAASolver(BaseSolver):
136 | """
137 | Local context-aware Anchor set Augmentation (LAA)
138 | """
139 |
140 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt,
141 | ada_stage, device, cfg, **kwargs):
142 | super(LAASolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader,
143 | joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs)
144 |
145 | def solve(self, epoch, seq_query_loader):
146 | K = self.cfg.LADA.A_K
147 | th = self.cfg.LADA.A_TH
148 |
149 | # create an anchor set
150 | if len(self.tgt_sup_loader) > 0:
151 | tgt_sup_dataset = self.tgt_sup_loader.dataset
152 | tgt_sup_samples = [tgt_sup_dataset.samples[i] for i in self.tgt_sup_loader.sampler.indices]
153 | seed_dataset = ImageList(tgt_sup_samples, root=tgt_sup_dataset.root, transform=tgt_sup_dataset.transform)
154 | seed_dataset.rand_transform = rand_transform
155 | seed_dataset.rand_num = self.cfg.LADA.A_RAND_NUM
156 | seed_loader = torch.utils.data.DataLoader(seed_dataset, shuffle=True,
157 | batch_size=self.tgt_sup_loader.batch_size, num_workers=self.tgt_sup_loader.num_workers)
158 | seed_idxs = self.tgt_sup_loader.sampler.indices.tolist()
159 | seed_iter = iter(seed_loader)
160 | seed_labels = [seed_dataset.samples[i][1] for i in range(len(seed_dataset))]
161 |
162 | if K > 0:
163 | # build nearest neighbors
164 | self.net.eval()
165 | tgt_idxs = []
166 | tgt_embs = []
167 | tgt_labels = []
168 | tgt_data = []
169 | seq_query_loader = copy.deepcopy(seq_query_loader)
170 | seq_query_loader.dataset.transform = copy.deepcopy(self.tgt_loader.dataset.transform)
171 | with torch.no_grad():
172 | for sample_ in seq_query_loader:
173 | sample = copy.deepcopy(sample_)
174 | del sample_
175 | data, label, idx = sample[0], sample[1], sample[2]
176 | data, label = data.to(self.device), label.to(self.device)
177 | score, emb = self.net(data, with_emb=True)
178 | tgt_embs.append(F.normalize(emb).detach().clone().cpu())
179 | tgt_labels.append(label.cpu())
180 | tgt_idxs.append(idx.cpu())
181 | tgt_data.append(data.cpu())
182 |
183 | tgt_embs = torch.cat(tgt_embs)
184 | tgt_data = torch.cat(tgt_data)
185 | tgt_idxs = torch.cat(tgt_idxs)
186 |
187 | self.net.train()
188 |
189 | src_iter = iter(self.src_loader)
190 | iter_per_epoch = len(self.src_loader)
191 |
192 | for batch_idx in range(iter_per_epoch):
193 | if batch_idx % len(self.src_loader) == 0:
194 | src_iter = iter(self.src_loader)
195 |
196 | data_s, label_s, _ = next(src_iter)
197 | data_s, label_s = data_s.to(self.device), label_s.to(self.device)
198 |
199 | self.tgt_opt.zero_grad()
200 | output_s = self.net(data_s)
201 | loss = nn.CrossEntropyLoss()(output_s, label_s)
202 |
203 | if len(self.tgt_sup_loader) > 0:
204 | try:
205 | data_ts, label_ts, idx_ts, *data_rand_ts = next(seed_iter)
206 | except:
207 | seed_iter = iter(seed_loader)
208 | data_ts, label_ts, idx_ts, *data_rand_ts = next(seed_iter)
209 |
210 | if len(data_rand_ts) > 0:
211 | for i, r_data in enumerate(data_rand_ts):
212 | alpha = 0.2
213 | mask = torch.FloatTensor(np.random.beta(alpha, alpha, size=(data_ts.shape[0], 1, 1, 1)))
214 | data_ts = (data_ts * mask) + (r_data * (1 - mask))
215 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)
216 | output_ts, emb_ts = self.net(data_ts, with_emb=True)
217 | loss += nn.CrossEntropyLoss()(output_ts, label_ts)
218 | else:
219 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)
220 | output_ts, emb_ts = self.net(data_ts, with_emb=True)
221 | loss += nn.CrossEntropyLoss()(output_ts, label_ts)
222 |
223 | loss.backward()
224 | self.tgt_opt.step()
225 |
226 | if len(self.tgt_sup_loader) > 0 and K > 0 and len(seed_idxs) < tgt_embs.shape[0]:
227 | mask = torch.ones(tgt_embs.shape[0])
228 | re_idxs = tgt_idxs[mask == 1]
229 |
230 | sim = F.normalize(emb_ts.cpu()).mm(tgt_embs[re_idxs].transpose(1, 0))
231 | sim_topk, topk = torch.topk(sim, k=K, dim=1)
232 |
233 | rand_nn = torch.randint(0, topk.shape[1], (topk.shape[0], 1))
234 | nn_idxs = torch.gather(topk, dim=-1, index=rand_nn).squeeze(1)
235 | nn_idxs = re_idxs[nn_idxs]
236 |
237 | data_nn = tgt_data[nn_idxs].to(self.device)
238 |
239 | with torch.no_grad():
240 | output_nn, emb_nn = self.net(data_nn, with_emb=True)
241 | prob_nn = torch.softmax(output_nn, dim=-1)
242 | tgt_embs[nn_idxs] = F.normalize(emb_nn).detach().clone().cpu()
243 |
244 | conf_samples = []
245 | conf_idx = []
246 | conf_pl = []
247 | dist = np.eye(prob_nn.shape[-1])[np.array(seed_labels)].sum(0) + 1
248 | dist = dist / dist.max()
249 | sp = 1 - dist / dist.max() + dist.min() / dist.max()
250 |
251 | for i in range(prob_nn.shape[0]):
252 | idx = tgt_idxs[nn_idxs[i]].item()
253 | pl_i = prob_nn[i].argmax(-1).item()
254 | if np.random.random() <= sp[pl_i] and prob_nn[i].max(-1)[0] >= th and idx not in seed_idxs:
255 | conf_samples.append((self.tgt_loader.dataset.samples[idx][0], pl_i))
256 | conf_idx.append(idx)
257 | conf_pl.append(pl_i)
258 |
259 | seed_dataset.add_item(conf_samples)
260 | seed_idxs.extend(conf_idx)
261 | seed_labels.extend(conf_pl)
262 |
263 |
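Both RAA and LAA expand the anchor set in a class-balanced way: the acceptance probability `sp` is lower for classes that already dominate the anchor set. A small numeric illustration of the formula used above:

```python
import numpy as np

seed_labels = [0, 0, 0, 1, 2]                         # current anchor labels over 3 classes
dist = np.eye(3)[np.array(seed_labels)].sum(0) + 1    # smoothed per-class counts: [4., 2., 2.]
sp = 1 - dist / dist.max() + dist.min() / dist.max()
print(sp)                                             # [0.5 1.  1. ] -> class 0 is down-weighted
```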
--------------------------------------------------------------------------------
/solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .solver import *
2 | from .MMEsolver import *
3 | from .PAAsolver import *
4 | from .MCCsolver import *
5 | from .CDACsolver import *
--------------------------------------------------------------------------------
/solver/solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .utils import ConditionalEntropyLoss
5 | from model.grl import GradientReverseFunction
6 |
7 |
8 | solvers = {}
9 | def register_solver(name):
10 | def decorator(cls):
11 | solvers[name] = cls
12 | return cls
13 | return decorator
14 |
15 | def get_solver(name, *args, kwargs={}):
16 | solver = solvers[name](*args, **kwargs)
17 | return solver
18 |
19 | class BaseSolver:
20 | """
21 | Base DA solver class
22 | """
23 |
24 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, device, cfg):
25 | self.net = net
26 | self.src_loader = src_loader
27 | self.tgt_loader = tgt_loader
28 | self.tgt_sup_loader = tgt_sup_loader
29 | self.tgt_unsup_loader = tgt_unsup_loader
30 | self.joint_sup_loader = joint_sup_loader
31 | self.tgt_opt = tgt_opt
32 | self.ada_stage = ada_stage
33 | self.device = device
34 | self.cfg = cfg
35 |
36 | def solve(self, epoch):
37 | pass
38 |
39 | @register_solver('ft_joint')
40 | class JointFTSolver(BaseSolver):
41 | """
42 |     Finetune on source and target labels jointly
43 | """
44 |
45 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs):
46 | super(JointFTSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt,
47 | ada_stage, device, cfg, **kwargs)
48 |
49 | def solve(self, epoch):
50 | """
51 | Finetune on source and target labels jointly
52 | """
53 | self.net.train()
54 | joint_sup_iter = iter(self.joint_sup_loader)
55 |
56 | while True:
57 | try:
58 | data, target, _ = next(joint_sup_iter)
59 | data, target = data.to(self.device), target.to(self.device)
60 | except:
61 | break
62 |
63 | self.tgt_opt.zero_grad()
64 | output = self.net(data)
65 | loss = nn.CrossEntropyLoss()(output, target)
66 |
67 | loss.backward()
68 | self.tgt_opt.step()
69 |
70 |
71 | @register_solver('ft_tgt')
72 | class TargetFTSolver(BaseSolver):
73 | """
74 | Finetune on target labels
75 | """
76 |
77 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs):
78 | super(TargetFTSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt,
79 | ada_stage, device, cfg, **kwargs)
80 |
81 | def solve(self, epoch):
82 | """
83 | Finetune on target labels
84 | """
85 | self.net.train()
86 | if self.ada_stage: tgt_sup_iter = iter(self.tgt_sup_loader)
87 |
88 | while True:
89 | try:
90 | data_t, target_t, _ = next(tgt_sup_iter)
91 | data_t, target_t = data_t.to(self.device), target_t.to(self.device)
92 | except:
93 | break
94 |
95 | self.tgt_opt.zero_grad()
96 | output = self.net(data_t)
97 | loss = nn.CrossEntropyLoss()(output, target_t)
98 | loss.backward()
99 | self.tgt_opt.step()
100 |
101 |
102 | @register_solver('ft')
103 | class FTSolver(BaseSolver):
104 | """
105 | Finetune on source and target labels with separate loaders
106 | """
107 |
108 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs):
109 | super(FTSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage,
110 | device, cfg, **kwargs)
111 |
112 | def solve(self, epoch):
113 | self.net.train()
114 |
115 | if self.ada_stage:
116 | src_sup_wt = 1.0
117 | else:
118 | src_sup_wt = self.cfg.ADA.SRC_SUP_WT
119 |
120 | tgt_sup_wt = self.cfg.ADA.TGT_SUP_WT
121 |
122 | tgt_sup_iter = iter(self.tgt_sup_loader)
123 |
124 | for batch_idx, (data_s, label_s, _) in enumerate(self.src_loader):
125 | data_s, label_s = data_s.to(self.device), label_s.to(self.device)
126 |
127 | if self.ada_stage:
128 | try:
129 | data_ts, label_ts, _ = next(tgt_sup_iter)
130 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)
131 | except:
132 | # no labeled target data
133 | try:
134 | tgt_sup_iter = iter(self.tgt_sup_loader)
135 | data_ts, label_ts, _ = next(tgt_sup_iter)
136 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)
137 | except:
138 | data_ts, label_ts = None, None
139 |
140 | # zero gradients for optimizer
141 | self.tgt_opt.zero_grad()
142 |
143 | # extract features
144 | score_s = self.net(data_s)
145 | xeloss_src = src_sup_wt * nn.CrossEntropyLoss()(score_s, label_s)
146 |
147 | xeloss_tgt = 0
148 | if self.ada_stage and data_ts is not None:
149 | score_ts = self.net(data_ts)
150 | xeloss_tgt = tgt_sup_wt * nn.CrossEntropyLoss()(score_ts, label_ts)
151 |
152 | xeloss = xeloss_src + xeloss_tgt
153 | xeloss.backward()
154 | self.tgt_opt.step()
155 |
156 |
157 | @register_solver('dann')
158 | class DANNSolver(BaseSolver):
159 | """
160 | Implements DANN from Unsupervised Domain Adaptation by Backpropagation: https://arxiv.org/abs/1409.7495
161 | """
162 |
163 | def __init__(self, net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt, ada_stage, device, cfg, **kwargs):
164 | super(DANNSolver, self).__init__(net, src_loader, tgt_loader, tgt_sup_loader, tgt_unsup_loader, joint_sup_loader, tgt_opt,
165 | ada_stage, device, cfg, **kwargs)
166 |
167 | def solve(self, epoch, disc, disc_opt):
168 | """
169 | Semi-supervised adaptation via DANN: XE on labeled source + XE on labeled target + \
170 | ent. minimization on target + DANN on source<->target
171 | """
172 | gan_criterion = nn.CrossEntropyLoss()
173 | cent = ConditionalEntropyLoss().to(self.device)
174 |
175 | self.net.train()
176 | disc.train()
177 |
178 | if not self.ada_stage:
179 | src_sup_wt, lambda_unsup, lambda_cent = 1.0, 0.1, 0.01 # Hardcoded for unsupervised DA
180 | else:
181 | src_sup_wt, lambda_unsup, lambda_cent = self.cfg.ADA.SRC_SUP_WT, self.cfg.ADA.UNSUP_WT, self.cfg.ADA.CEN_WT
182 | tgt_sup_iter = iter(self.tgt_sup_loader)
183 |
184 | joint_loader = zip(self.src_loader, self.tgt_loader) # changed to tgt_loader to be consistent with CLUE implementation
185 | for batch_idx, ((data_s, label_s, _), (data_tu, label_tu, _)) in enumerate(joint_loader):
186 | data_s, label_s = data_s.to(self.device), label_s.to(self.device)
187 | data_tu = data_tu.to(self.device)
188 |
189 | if self.ada_stage:
190 | try:
191 | data_ts, label_ts, _ = next(tgt_sup_iter)
192 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)
193 | except:
194 | # no labeled target data
195 | try:
196 | tgt_sup_iter = iter(self.tgt_sup_loader)
197 | data_ts, label_ts, _ = next(tgt_sup_iter)
198 | data_ts, label_ts = data_ts.to(self.device), label_ts.to(self.device)
199 | except:
200 | data_ts, label_ts = None, None
201 |
202 | # zero gradients for optimizers
203 | self.tgt_opt.zero_grad()
204 | disc_opt.zero_grad()
205 |
206 | # Train with target labels
207 | score_s = self.net(data_s)
208 | xeloss_src = src_sup_wt * nn.CrossEntropyLoss()(score_s, label_s)
209 |
210 | xeloss_tgt = 0
211 | if self.ada_stage and data_ts is not None:
212 | score_ts = self.net(data_ts)
213 | xeloss_tgt = nn.CrossEntropyLoss()(score_ts, label_ts)
214 |
215 | # extract and concat features
216 | score_tu = self.net(data_tu)
217 | f = torch.cat((score_s, score_tu), 0)
218 |
219 | # predict with discriminator
220 | f_rev = GradientReverseFunction.apply(f)
221 | pred_concat = disc(f_rev)
222 |
223 | target_dom_s = torch.ones(len(data_s)).long().to(self.device)
224 | target_dom_t = torch.zeros(len(data_tu)).long().to(self.device)
225 | label_concat = torch.cat((target_dom_s, target_dom_t), 0)
226 |
227 |             # compute loss for discriminator
228 | loss_domain = gan_criterion(pred_concat, label_concat)
229 | loss_cent = cent(score_tu)
230 |
231 | loss_final = (xeloss_src + xeloss_tgt) + (lambda_unsup * loss_domain) + (lambda_cent * loss_cent)
232 |
233 | loss_final.backward()
234 |
235 | self.tgt_opt.step()
236 | disc_opt.step()
237 |
238 |
239 |
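The `register_solver` / `get_solver` pair above is a simple plugin registry: decorating a BaseSolver subclass makes it selectable by name (the ADA.DA value seen in run.sh). A hypothetical example of adding one, assuming the repo root is importable; `'noop'` and `NoOpSolver` are not part of the repo:

```python
from solver.solver import BaseSolver, register_solver, get_solver

@register_solver('noop')
class NoOpSolver(BaseSolver):
    def solve(self, epoch):
        print('epoch', epoch, ': doing nothing')

# BaseSolver.__init__ takes 10 positional arguments (net, loaders, optimizer, ...);
# None placeholders are enough for this demonstration.
solver = get_solver('noop', *([None] * 10))
solver.solve(0)
```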
--------------------------------------------------------------------------------
/solver/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | class ConditionalEntropyLoss(torch.nn.Module):
5 | """
6 | Conditional entropy loss utility class
7 | """
8 | def __init__(self):
9 | super(ConditionalEntropyLoss, self).__init__()
10 |
11 | def forward(self, x):
12 | b = F.softmax(x, dim=1) * F.log_softmax(x, dim=1)
13 | b = b.sum(dim=1)
14 | return -1.0 * b.mean(dim=0)
15 |
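ConditionalEntropyLoss is the mean Shannon entropy of the softmax predictions (used as the entropy-minimization term in the DANN solver). A quick sanity check, assuming the repo root is importable:

```python
import torch
import torch.nn.functional as F
from solver.utils import ConditionalEntropyLoss

logits = torch.randn(4, 6)
probs = F.softmax(logits, dim=1)
direct = -(probs * probs.log()).sum(dim=1).mean()                  # mean entropy, computed directly
print(torch.allclose(ConditionalEntropyLoss()(logits), direct))    # True
```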
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | pil_logger = logging.getLogger('PIL')
5 | pil_logger.setLevel(logging.INFO)
6 |
7 | logger_init = False
8 |
9 | def init_logger(_log_file, dir='log/'):
10 | logger = logging.getLogger()
11 | for handler in logger.handlers[:]:
12 | logger.removeHandler(handler)
13 |
14 | logger.setLevel('DEBUG')
15 | BASIC_FORMAT = "%(asctime)s:%(levelname)s:%(message)s"
16 | DATE_FORMAT = '%Y-%m-%d %H.%M.%S'
17 | formatter = logging.Formatter(BASIC_FORMAT, DATE_FORMAT)
18 | chlr = logging.StreamHandler()
19 | chlr.setFormatter(formatter)
20 | logger.addHandler(chlr)
21 |
22 | if _log_file is not None:
23 | if not os.path.exists(dir):
24 | os.makedirs(dir)
25 | log_file = osp.join(dir, _log_file + '.log')
26 | fhlr = logging.FileHandler(log_file)
27 | fhlr.setFormatter(formatter)
28 | logger.addHandler(fhlr)
29 |
30 | global logger_init
31 | logger_init = True
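Typical usage of the helper above: call it once at startup, then use the standard logging module everywhere else (the log name and directory below are illustrative):

```python
import logging
from utils.logger import init_logger   # assumes the repo root is on PYTHONPATH

init_logger('officehome_LADA', dir='log/oh')   # writes log/oh/officehome_LADA.log
logging.info('logging to both console and file')
```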
--------------------------------------------------------------------------------
/utils/lr_schedule.py:
--------------------------------------------------------------------------------
1 |
2 | def inv_lr_scheduler(param_lr, optimizer, iter_num, gamma=0.0001,
3 | power=0.75, init_lr=0.001):
4 |     """Inverse decay of the learning rate: lr = init_lr * (1 + gamma * iter_num) ** (-power)."""
5 | lr = init_lr * (1 + gamma * iter_num) ** (- power)
6 | i = 0
7 | for param_group in optimizer.param_groups:
8 | param_group['lr'] = lr * param_lr[i]
9 | i += 1
10 | return optimizer
11 |
12 |
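A quick look at the inverse-decay curve implemented above, using the default gamma and power (the optimizer and iteration counts are illustrative):

```python
import torch
from utils.lr_schedule import inv_lr_scheduler   # assumes the repo root is on PYTHONPATH

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.001)
for it in (0, 1000, 10000):
    inv_lr_scheduler([1.0], opt, it, init_lr=0.001)
    print(it, opt.param_groups[0]['lr'])   # 0.001 at it=0, then smoothly decayed
```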
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from tqdm import tqdm
6 | import os
7 | # from model import get_model
8 | import torch.optim as optim
9 | # from solver import get_solver
10 | import logging
11 |
12 | def resetRNGseed(seed):
13 | np.random.seed(seed)
14 | random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 |
19 |
20 | def train(model, device, train_loader, optimizer, epoch):
21 | """
22 |     Train model on provided data for a single epoch
23 | """
24 | model.train()
25 | total_loss, correct = 0.0, 0
26 | for batch_idx, (data, target, _) in enumerate(tqdm(train_loader)):
27 | data, target = data.to(device), target.to(device)
28 | optimizer.zero_grad()
29 | output = model(data)
30 | loss = nn.CrossEntropyLoss()(output, target)
31 | total_loss += loss.item()
32 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
33 | corr = pred.eq(target.view_as(pred)).sum().item()
34 | correct += corr
35 | loss.backward()
36 | optimizer.step()
37 |
38 | train_acc = 100. * correct / len(train_loader.sampler)
39 | avg_loss = total_loss / len(train_loader.sampler)
40 | logging.info('Train Epoch: {} | Avg. Loss: {:.3f} | Train Acc: {:.3f}'.format(epoch, avg_loss, train_acc))
41 | return avg_loss
42 |
43 | def test(model, device, test_loader, split="target test"):
44 | """
45 | Test model on provided data
46 | """
47 | # logging.info('Evaluating model on {}...'.format(split))
48 | model.eval()
49 | test_loss = 0
50 | correct = 0
51 | with torch.no_grad():
52 | for data, target, _ in test_loader:
53 | data, target = data.to(device), target.to(device)
54 | output = model(data)
55 | loss = nn.CrossEntropyLoss()(output, target)
56 | test_loss += loss.item() # sum up batch loss
57 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
58 | corr = pred.eq(target.view_as(pred)).sum().item()
59 | correct += corr
60 | del loss, output
61 |
62 | test_loss /= len(test_loader.sampler)
63 | test_acc = 100. * correct / len(test_loader.sampler)
64 |
65 | return test_acc, test_loss
66 |
67 |
68 | # def run_unsupervised_da(model, src_train_loader, tgt_sup_loader, tgt_unsup_loader, train_idx, num_classes, device,
69 | # cfg):
70 | # """
71 | # Unsupervised adaptation of source model to target at round 0
72 | # Returns:
73 | # Model post adaptation
74 | # """
75 | # source = cfg.DATASET.SOURCE_DOMAIN
76 | # target = cfg.DATASET.TARGET_DOMAIN
77 | # da_strat = cfg.ADA.DA
78 | #
79 | # adapt_dir = os.path.join('checkpoints', 'adapt')
80 | # adapt_net_file = os.path.join(adapt_dir, '{}_{}_{}_{}.pth'.format(da_strat, source, target, cfg.MODEL.BACKBONE.NAME))
81 | #
82 | # if not os.path.exists(adapt_dir):
83 | # os.makedirs(adapt_dir)
84 | #
85 | # if os.path.exists(adapt_net_file):
86 | # logging.info('Found pretrained checkpoint, loading...')
87 | # adapt_model = get_model('AdaptNet', num_cls=num_classes, weights_init=adapt_net_file, model=cfg.MODEL.BACKBONE.NAME)
88 | # else:
89 | # logging.info('No pretrained checkpoint found, training...')
90 | # source_file = '{}_{}_source.pth'.format(source, cfg.MODEL.BACKBONE.NAME)
91 | # source_path = os.path.join('checkpoints', 'source', source_file)
92 | # adapt_model = get_model('AdaptNet', num_cls=num_classes, src_weights_init=source_path, model=cfg.MODEL.BACKBONE.NAME)
93 | # opt_net_tgt = optim.Adadelta(adapt_model.tgt_net.parameters(cfg.OPTIM.UDA_LR, cfg.OPTIM.BASE_LR_MULT), lr=cfg.OPTIM.UDA_LR, weight_decay=0.00001)
94 | # uda_solver = get_solver(da_strat, adapt_model.tgt_net, src_train_loader, tgt_sup_loader, tgt_unsup_loader,
95 | # train_idx, opt_net_tgt, 0, device, cfg)
96 | # for epoch in range(cfg.TRAINER.MAX_UDA_EPOCHS):
97 | # if da_strat == 'dann':
98 | # opt_dis_adapt = optim.Adadelta(adapt_model.discriminator.parameters(), lr=cfg.OPTIM.UDA_LR, weight_decay=0.00001)
99 | # uda_solver.solve(epoch, adapt_model.discriminator, opt_dis_adapt)
100 | # elif da_strat in ['mme', 'ft']:
101 | # uda_solver.solve(epoch)
102 | # adapt_model.save(adapt_net_file)
103 | #
104 | # model, src_model, discriminator = adapt_model.tgt_net, adapt_model.src_net, adapt_model.discriminator
105 | # return model, src_model, discriminator
106 |
107 |
108 | def get_optim(name, *args, **kwargs):
109 |     if name == 'Adadelta':
110 |         return optim.Adadelta(*args, **kwargs)
111 |     elif name == 'Adam':
112 |         return optim.Adam(*args, **kwargs)
113 |     elif name == 'SGD':
114 |         return optim.SGD(*args, **kwargs, momentum=0.9, nesterov=True)
115 |     else:
116 |         raise ValueError('Unknown optimizer: {}'.format(name))
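Example of the factory above (the model and hyper-parameters are placeholders, not the repo's configured values):

```python
import torch
from utils.utils import get_optim, resetRNGseed   # assumes the repo root is on PYTHONPATH

resetRNGseed(0)
model = torch.nn.Linear(10, 2)
opt = get_optim('SGD', model.parameters(), lr=0.01, weight_decay=1e-5)
print(opt)
```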
--------------------------------------------------------------------------------