├── Dockerfile
├── README.md
├── code
│   ├── center_loss.py
│   ├── compute_score.py
│   ├── data.py
│   ├── networks.py
│   ├── steps_separation_adaptation.py
│   ├── train.py
│   ├── transformations.py
│   └── utilities.py
├── data
│   ├── amazon_0-9_20-30_test.txt
│   ├── amazon_0-9_train_all.txt
│   ├── art_0-24_train_all.txt
│   ├── art_0-64_test.txt
│   ├── clipart_0-24_train_all.txt
│   ├── clipart_0-64_test.txt
│   ├── dslr_0-9_20-30_test.txt
│   ├── dslr_0-9_train_all.txt
│   ├── product_0-24_train_all.txt
│   ├── product_0-64_test.txt
│   ├── real_world_0-24_train_all.txt
│   ├── real_world_0-64_test.txt
│   ├── webcam_0-9_20-30_test.txt
│   └── webcam_0-9_train_all.txt
├── image.jpg
├── requirements_ROS.txt
├── train_resnet50_office31.sh
├── train_resnet50_officehome.sh
└── train_vgg_office31.sh


/Dockerfile:
--------------------------------------------------------------------------------
FROM pytorch/pytorch
MAINTAINER Silvia Bucci & Mohammad Reza Loghmani
ARG DEBIAN_FRONTEND=noninteractive
RUN apt-get update -y
RUN apt-get install -y --upgrade vim git
WORKDIR /ROS
COPY . .
RUN pip install --upgrade pip
RUN pip install --upgrade -r requirements_ROS.txt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ROS (Rotation-based Open Set)

Official PyTorch implementation of "[On the Effectiveness of Image Rotation for Open Set Domain Adaptation](https://arxiv.org/abs/2007.12360)", *European Conference on Computer Vision 2020* (**ECCV 2020**).

![Test Image 1](image.jpg)

## Experiments
To replicate the results shown in the paper (Tables 1 and 2), follow these instructions:

1. Download the Office-31 and Office-Home datasets:

   - Office-31:
     https://drive.google.com/file/d/0B4IapRTv9pJ1WGZVd1VDMmhwdlE/view
     Save the folder as "office" in the ROS folder.
     At the end you should have `office/amazon`, `office/webcam` and `office/dslr`.

   - Office-Home:
     https://drive.google.com/file/d/0B81rNlvomiwed0V1YUxQdC1uOTg/view
     Save the folder as "office-home" in the ROS folder.
     At the end you should have `office-home/OfficeHomeDataset_10072016/art`, `office-home/OfficeHomeDataset_10072016/clipart`, `office-home/OfficeHomeDataset_10072016/real_world` and `office-home/OfficeHomeDataset_10072016/product`.

2. Use Python 3.6.8.

   Install all the required libraries with the command:

       pip3 install -r requirements_ROS.txt

   For convenience, we also provide a Dockerfile that builds a Docker image with all the necessary requirements.

3. Go into the ROS folder and:

   3a. To replicate the experiments on the Office-31 dataset with ResNet-50 (Table 1), run `train_resnet50_office31.sh` after replacing "/.../" with "/path_in_which_you_save_ROS/".

   3b. To replicate the experiments on the Office-Home dataset with ResNet-50 (Table 2), run `train_resnet50_officehome.sh` after replacing "/.../" with "/path_in_which_you_save_ROS/".

   3c. To replicate the experiments on the Office-31 dataset with VggNet (Table 1), run `train_vgg_office31.sh` after replacing "/.../" with "/path_in_which_you_save_ROS/".


You can also replicate the results obtained for STA_max, STA_sum, OSBP and UAN (Tables 1 and 2) by following the instructions in the GitHub repositories released by their authors:

- STA: https://github.com/thuml/Separate_to_Adapt

- OSBP: https://github.com/ksaito-ut/OPDA_BP

- UAN: https://github.com/thuml/Universal-Domain-Adaptation

## Citation

To cite, please use the following reference:
```
@inproceedings{BucciLoghmaniTommasi2020,
  title={On the Effectiveness of Image Rotation for Open Set Domain Adaptation},
  author={Silvia Bucci and Mohammad Reza Loghmani and Tatiana Tommasi},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2020}
}
```

--------------------------------------------------------------------------------
/code/center_loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
        use_gpu (bool): if True, the centers and the class indices are moved to `device`.
        device (torch.device): device used when `use_gpu` is True.
    """
    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True, device=None):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu
        self.device = device

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).to(self.device))
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).
        """
        batch_size = x.size(0)
        # squared Euclidean distances between every feature and every center:
        # ||x||^2 + ||c||^2 - 2 * x @ c^T
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # distmat = 1 * distmat + (-2) * (x @ centers^T); keyword form of addmm_,
        # since the positional (beta, alpha, mat1, mat2) signature is deprecated in PyTorch
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_classes).long()
        if self.use_gpu:
            classes = classes.to(self.device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)

        # keep only the distance of each sample to the center of its own class
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

        return loss, self.centers
--------------------------------------------------------------------------------
/code/compute_score.py:
--------------------------------------------------------------------------------
from data import *
from transformations import *
from utilities import *
from networks import *
import numpy as np
from random import sample, random
from PIL import Image
import torchvision.transforms as transforms
import sys
import torchvision
from itertools import chain
from skimage.transform import resize
from center_loss import CenterLoss
import math
from sklearn.metrics import roc_auc_score
import random
from sklearn import preprocessing
from itertools import cycle

def create_txt_target(type_subset, subset_index_high, source, target, folder_txt_files, folder_txt_files_saving, n_classes):
    path = folder_txt_files + target + '_test.txt'
    new_file = folder_txt_files_saving + source + '_' + target + '_test_' + type_subset + '.txt'

    with open(path, 'r') as f:
        list_images = f.readlines()

    with open(new_file, 'w') as w:
        for cont, i in enumerate(list_images):
            for k in subset_index_high:
                if cont == k:
                    # compare strings with == (not `is`)
                    if type_subset == 'low':
                        # low-score target samples are relabeled as the unknown class (index n_classes)
                        words = i.split(' ')
                        w.write(words[0] + ' ' + str(n_classes) + '\n')
                    else:
                        w.write(i)
        if type_subset == 'low':
            # the 'low' file also contains all the source training samples
            path = folder_txt_files + source + '_train_all.txt'
            with open(path, 'r') as f:
                for j in f.readlines():
                    w.write(j)


def compute_scores_all_target(target_test, feature_extractor, discriminator_p, net, vgg, n_classes, ss_classes, device, source, target, folder_txt_files, folder_txt_files_saving):

    all_target_labels = []
    all_target_predictions = []
    len_target = len(target_test)

    with torch.no_grad():
        len_target = len(target_test)
        scores = torch.zeros(len_target)
        scores_entropy_ss = torch.zeros(len_target)

        if vgg:
            target_original = torch.zeros(len_target, 4096)
        else:
            target_original = torch.zeros(len_target, 2048)

        for (i, (im_target, label_target)) in enumerate(target_test):
            all_target_labels.append(label_target[0].item())
            k_list = torch.zeros(n_classes)
            logit = 0

            for j in range(ss_classes):
                ss_data_orig = tl.prepro.crop(im_target[0], 224, 224, is_random=False)
                # rotate the center crop by j * 90 degrees
                if j == 0 or j == 1 or j == 2 or j == 3:
                    ss_data = np.rot90(ss_data_orig, k=j)

                if j == 0:
                    # features of the unrotated image, extracted once per sample
                    ss_data_orig = np.transpose(ss_data_orig, [2, 0, 1])
                    ss_data_orig = np.asarray(ss_data_orig, np.float32) / 255.0
                    ss_data_orig = torch.from_numpy(ss_data_orig).to(device)
                    net.eval()
                    (ft1, _, _, class_label) = net.forward(ss_data_orig)
                    target_original[i] = ft1
                    net.train()

                ss_data = np.transpose(ss_data, [2, 0, 1])
                ss_data = np.asarray(ss_data, np.float32) / 255.0
                ss_data = torch.from_numpy(ss_data).to(device)

                feature_extractor.eval()
                ft_ss = feature_extractor.forward(ss_data)
                feature_extractor.train()

                # the rotation classifier takes the original and the rotated features concatenated
                double_input = torch.cat((ft1.cpu(), ft_ss.cpu()), 1)
                ft_ss = double_input
                ft_ss = ft_ss.to(device)

                discriminator_p.eval()
                p0, features = discriminator_p.forward(ft_ss)
                discriminator_p.train()

                p0 = nn.Softmax(dim=-1)(p0)

                scores_entropy_ss[i] = scores_entropy_ss[i] + EntropyLoss(p0)

                # probability assigned to the pair (class k, rotation j), accumulated over rotations
                for k in range(n_classes):
                    logit = p0[0][(ss_classes * k) + j]
                    k_list[k] = k_list[k] + logit

            k_list = k_list / ss_classes
            # normality score
            scores[i] = max(k_list)
            # entropy ss score
            scores_entropy_ss[i] = scores_entropy_ss[i] / ss_classes

    all_target_labels = np.asarray(all_target_labels)

    # min-max normalization of the entropy ss score
    scores_entropy_ss = (scores_entropy_ss - min(scores_entropy_ss)) / (max(scores_entropy_ss) - min(scores_entropy_ss))
    scores_entropy_ss = 1 - scores_entropy_ss

    # score = max(entropy, normality)
    score_sum = np.maximum(scores_entropy_ss, scores)

    scores_ordered_entropy_ss_sum = (score_sum).argsort().numpy()
    scores_entropy_ss_ordered_sum = score_sum[scores_ordered_entropy_ss_sum]
    label_target_sorted_entropy_ss_sum = all_target_labels[scores_ordered_entropy_ss_sum]

    scores_for_mean = np.asarray(scores_entropy_ss_ordered_sum)
    mean_scores = sum(scores_for_mean) / len(scores_for_mean)
    # threshold built from the first decimal digit of the mean score (e.g. a mean of 0.63 gives 0.7)
    number = int((str(mean_scores).split('.'))[1][0]) + 1
    threshold = float('0.' + str(number))
    # count the target samples whose score exceeds the threshold
    num_low = 0
    for value in scores_for_mean:
        if value > threshold:
            num_low = num_low + 1
    select_high = num_low
    select_low = num_low

    # highest-scoring samples keep their label; lowest-scoring ones are relabeled as unknown
    create_txt_target('high', np.asarray(scores_ordered_entropy_ss_sum)[-select_high:], source, target, folder_txt_files, folder_txt_files_saving, n_classes)
    create_txt_target('low', np.asarray(scores_ordered_entropy_ss_sum)[0:select_low], source, target, folder_txt_files, folder_txt_files_saving, n_classes)


    return select_low
--------------------------------------------------------------------------------
/code/data.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from random import sample, random
import scipy.io as sio
import codecs
import os
import os.path


def _dataset_info(txt_labels, folder_dataset):
    # each row of the txt file is "<relative image path> <label>"
    with open(txt_labels, 'r') as f:
        images_list = f.readlines()

    file_names = []
    labels = []
    for row in images_list:
        row = row.split(' ')
        file_name = folder_dataset + row[0]
        file_names.append(file_name)
        labels.append(int(row[1]))

    return file_names, labels


def get_split_dataset_info(txt_list, folder_dataset):
    names, labels = _dataset_info(txt_list, folder_dataset)
    return names, labels

class CustomDataset(data.Dataset):
    def __init__(self, names, labels, img_transformer=None, returns=None, is_train=None, ss_classes=None, n_classes=None, only_4_rotations=None, n_classes_target=None):
        self.data_path = ""
        self.names = names
        self.labels = labels
        self.N = len(self.names)
        self._image_transformer = img_transformer
        self.is_train = is_train
        self.returns = returns
        self.ss_classes = ss_classes
        self.n_classes = n_classes
        self.only_4_rotations = only_4_rotations
        self.n_classes_target = n_classes_target

    def __getitem__(self, index):
        framename = self.data_path + '/' + self.names[index]
        img = Image.open(framename).convert('RGB')

        # `returns` is the number of outputs produced by the image transformer
        if self.returns == 3:
            data, data_ss, label_ss = self._image_transformer(img, self.labels[index], self.is_train, self.ss_classes, self.n_classes, self.only_4_rotations, self.n_classes_target)
            return data, data_ss, label_ss
        elif self.returns == 4:
            data, data_ss, label, label_ss = self._image_transformer(img, self.labels[index], self.is_train, self.ss_classes, self.n_classes, self.only_4_rotations, self.n_classes_target)
            return data, data_ss, label, label_ss
        elif self.returns == 5:
            data, data_ss, label, label_ss, label_ss_center = self._image_transformer(img, self.labels[index], self.is_train, self.ss_classes, self.n_classes, self.only_4_rotations, self.n_classes_target)
            return data, data_ss, label, label_ss, label_ss_center
        elif self.returns == 6:
            data, data_ss, label, label_ss, label_ss_center, label_object_center = self._image_transformer(img, self.labels[index], self.is_train, self.ss_classes, self.n_classes, self.only_4_rotations, self.n_classes_target)
            return data, data_ss, label, label_ss, label_ss_center, label_object_center
        elif self.returns == 2:
            data, label = self._image_transformer(img, self.labels[index], self.is_train, self.ss_classes, self.n_classes, self.only_4_rotations, self.n_classes_target)
            return data, label

    def __len__(self):
        return len(self.names)
--------------------------------------------------------------------------------
/code/networks.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd.variable import *
from torchvision import models
import os
import numpy as np
from utilities import *
import torch.nn.functional as F
import scipy.io as sio


class BaseFeatureExtractor(nn.Module):
    def forward(self, *input):
        pass

    def __init__(self):
        super(BaseFeatureExtractor, self).__init__()

    def output_num(self):
        pass

resnet_dict = {"resnet50": models.resnet50}

class ResNetFc(BaseFeatureExtractor):
    def __init__(self, device, model_name='resnet50', normalize=True):
        super(ResNetFc, self).__init__()
        self.model_resnet = models.resnet50(pretrained=True)
        self.normalize = normalize
        self.mean = False
        self.std = False

        model_resnet = self.model_resnet
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.__in_features = model_resnet.fc.in_features
        self.device = device
        self.fc = nn.Linear(self.__in_features, self.__in_features)
        self.bn_sharedfc = nn.BatchNorm1d(self.__in_features)
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)

    def get_mean(self):
        # ImageNet channel means, created lazily on self.device
        if self.mean is False:
            self.mean = Variable(
                torch.from_numpy(np.asarray([0.485, 0.456, 0.406], dtype=np.float32).reshape((1, 3, 1, 1)))).to(self.device)
        return self.mean

    def get_std(self):
        # ImageNet channel standard deviations, created lazily on self.device
        if self.std is False:
            self.std = Variable(
                torch.from_numpy(np.asarray([0.229, 0.224, 0.225], dtype=np.float32).reshape((1, 3, 1, 1)))).to(self.device)
        return self.std

    def forward(self, x):
        if self.normalize:
            x = (x - self.get_mean()) / self.get_std()
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return x

    def output_num(self):

        return self.__in_features, 1024


class VGGFc(BaseFeatureExtractor):
    def __init__(self, device, model_name='vgg19', normalize=True):
        super(VGGFc, self).__init__()
        self.model_vgg = models.vgg19(pretrained=True)

        self.normalize = normalize
        self.mean = False
        self.std = False
        model_vgg = self.model_vgg
        mod = list(model_vgg.features.children())
        self.features = nn.Sequential(*mod)
        # drop the last (1000-way) layer of the VGG classifier: output is 4096-d
        mod2 = list(model_vgg.classifier.children())[:-1]
        self.classifier = nn.Sequential(*mod2)
        self.__in_features = 4096
        self.device = device

    def get_mean(self):
        if self.mean is False:
            self.mean = Variable(
                torch.from_numpy(np.asarray([0.485, 0.456, 0.406], dtype=np.float32).reshape((1, 3, 1, 1)))).to(self.device)
        return self.mean

    def get_std(self):
        if self.std is False:
            self.std = Variable(
                torch.from_numpy(np.asarray([0.229, 0.224, 0.225], dtype=np.float32).reshape((1, 3, 1, 1)))).to(self.device)
        return self.std

    def forward(self, x):

        if self.normalize:
            x = (x - self.get_mean()) / self.get_std()
        x = self.features(x)
        x = x.view(x.size(0), 512 * 7 * 7)
        x = self.classifier(x)

        return x

    def output_num(self):

        return self.__in_features

class CLS(nn.Module):
    def __init__(self, in_dim, out_dim, bottle_neck_dim=256, vgg=None):
        super(CLS, self).__init__()
        self.vgg = vgg
        if bottle_neck_dim:
            if not vgg:
                # ResNetFc.output_num() returns a tuple: use its first element
                self.bottleneck = nn.Linear(in_dim[0], bottle_neck_dim)
                self.fc = nn.Linear(bottle_neck_dim, out_dim)
                self.main = nn.Sequential(
                    self.bottleneck,
                    nn.Sequential(
                        nn.BatchNorm1d(bottle_neck_dim),
                        nn.LeakyReLU(0.2, inplace=True),
                        self.fc
                    ),
                    nn.Softmax(dim=-1)
                )
            else:
                self.bottleneck = nn.Linear(in_dim, bottle_neck_dim)
                self.fc = nn.Linear(bottle_neck_dim, out_dim)
                self.main = nn.Sequential(
                    self.bottleneck,
                    nn.Sequential(
                        nn.BatchNorm1d(bottle_neck_dim),
                        nn.LeakyReLU(0.2, inplace=True),
                        self.fc
                    ),
                    nn.Softmax(dim=-1)
                )

    def forward(self, x):
        # returns [input features, bottleneck output, logits, softmax probabilities]
        out_last = [x]
        x_last = x
        for module in self.main.children():
            x_last = module(x_last)
            out_last.append(x_last)

        return out_last


class Discriminator(nn.Module):
    def __init__(self, n=None, n_s=None, vgg=None):
        super(Discriminator, self).__init__()
        self.n = n
        # n heads, each predicting n_s rotation classes from concatenated (original, rotated) features
        def f():
            if vgg:
                return nn.Sequential(
                    nn.Linear(4096*2, 256),
                    nn.BatchNorm1d(256),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Linear(256, n_s))
            else:
                return nn.Sequential(
                    nn.Linear(2048*2, 256),
                    nn.BatchNorm1d(256),
                    nn.LeakyReLU(0.2, inplace=True),
                    nn.Linear(256, n_s))

        def f_feat():
            if vgg:
                return nn.Sequential(
                    nn.Linear(4096*2, 256),
                    nn.BatchNorm1d(256),
                    nn.LeakyReLU(0.2, inplace=True))
            else:
                return nn.Sequential(
                    nn.Linear(2048*2, 256),
                    nn.BatchNorm1d(256),
                    nn.LeakyReLU(0.2, inplace=True))

        for i in range(n):
            self.__setattr__('discriminator_%04d' % i, f())
            self.__setattr__('discriminator_feat_%04d' % i, f_feat())

    def forward(self, x):

        outs = [self.__getattr__('discriminator_%04d' % i)(x) for i in range(self.n)]
        outs_feat = [self.__getattr__('discriminator_feat_%04d' % i)(x) for i in range(self.n)]

        return torch.cat(outs, dim=-1), torch.cat(outs_feat, dim=-1)
--------------------------------------------------------------------------------
/code/steps_separation_adaptation.py:
--------------------------------------------------------------------------------
from data import *
from transformations import *
from utilities import *
from networks import *
import numpy as np
from random import sample, random
from PIL import Image
import torchvision.transforms as transforms
import sys
import torchvision
from itertools import chain
from skimage.transform import resize
from center_loss import CenterLoss
import math
from sklearn.metrics import roc_auc_score
import random
from sklearn import preprocessing
from itertools import cycle
from compute_score import compute_scores_all_target


def skip(data, label, is_train):
    return False


class Trainer:
    def __init__(self, args, device, rand):
        self.args = args
        self.device = device
        self.source = self.args.source
        self.target = self.args.target
        self.batch_size = self.args.batch_size
        self.learning_rate = self.args.learning_rate
        self.epochs_step1 = self.args.epochs_step1
        self.epochs_step2 = self.args.epochs_step2
        self.n_classes = self.args.n_classes
        self.n_classes_target = self.args.n_classes_target
        self.ss_classes = self.args.ss_classes
        self.cls_weight_source = self.args.cls_weight_source
        self.ss_weight_target = self.args.ss_weight_target
        self.ss_weight_source = self.args.ss_weight_source
        self.entropy_weight = self.args.entropy_weight
        self.folder_dataset = self.args.folder_dataset
        self.folder_name = self.args.folder_name
        self.folder_txt_files = self.args.folder_txt_files
        self.folder_txt_files_saving = self.args.folder_txt_files_saving
        self.folder_log = self.args.folder_log
        self.divison_learning_rate_backbone = self.args.divison_learning_rate_backbone
        self.only_4_rotations = self.args.only_4_rotations
        self.use_weight_net_first_part = self.args.use_weight_net_first_part
        self.weight_class_unknown = self.args.weight_class_unknown
        self.weight_center_loss = self.args.weight_center_loss
        self.use_VGG = self.args.use_VGG
        self.n_workers = self.args.n_workers


    def _do_train(self):

        # STEP 1 -------------------------------------------------------------------------------------

        # data -------------------------------------------------------------------------------------

        # enable cuDNN autotuning (a bare attribute access has no effect)
        torch.backends.cudnn.benchmark = True

        if self.use_VGG:
            feature_extractor = VGGFc(self.device, model_name='vgg19')
        else:
            feature_extractor = ResNetFc(self.device, model_name='resnet50')

        #### source samples used to train the classifier and the self-supervised (rotation) task
        images, labels = get_split_dataset_info(self.folder_txt_files+self.source+'_train_all.txt', self.folder_dataset)
        ds_source_ss = CustomDataset(images, labels, img_transformer=transform_source_ss, returns=6, is_train=True, ss_classes=self.ss_classes, n_classes=self.n_classes, only_4_rotations=self.only_4_rotations, n_classes_target=self.n_classes_target)
        source_train_ss = torch.utils.data.DataLoader(ds_source_ss, batch_size=self.batch_size, shuffle=True, num_workers=self.n_workers, pin_memory=True, drop_last=True)

        images, labels = get_split_dataset_info(self.folder_txt_files+self.target+'_test.txt', self.folder_dataset)
        ds_target_train = CustomDataset(images, labels, img_transformer=transform_target_train, returns=2, is_train=True, ss_classes=self.ss_classes, n_classes=self.n_classes, only_4_rotations=self.only_4_rotations, n_classes_target=self.n_classes_target)
        target_train = torch.utils.data.DataLoader(ds_target_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_workers, pin_memory=True, drop_last=True)

        #### target samples on which the scores are computed: the highest-scoring ones are added to the
        #### training of the ss task, the lowest-scoring ones to the training of cls as the unknown class
        images, labels = get_split_dataset_info(self.folder_txt_files+self.target+'_test.txt', self.folder_dataset)
        ds_target_test_for_scores = CustomDataset(images, labels, img_transformer=transform_target_test_for_scores, returns=2, is_train=False, ss_classes=self.ss_classes, n_classes=self.n_classes, only_4_rotations=self.only_4_rotations, n_classes_target=self.n_classes_target)
        target_test_for_scores = torch.utils.data.DataLoader(ds_target_test_for_scores, batch_size=1, shuffle=False, num_workers=self.n_workers, pin_memory=True, drop_last=False)


        #### target for the final evaluation
        images, labels = get_split_dataset_info(self.folder_txt_files+self.target+'_test.txt', self.folder_dataset)
        ds_target_test = CustomDataset(images, labels, img_transformer=transform_target_test, returns=2, is_train=False, ss_classes=self.ss_classes, n_classes=self.n_classes, only_4_rotations=self.only_4_rotations, n_classes_target=self.n_classes_target)
        target_test = torch.utils.data.DataLoader(ds_target_test, batch_size=1, shuffle=False, num_workers=self.n_workers, pin_memory=True, drop_last=False)

        # network -----------------------------------------------------------------------------------------------
        if self.only_4_rotations:
            discriminator_p = Discriminator(n=1, n_s=self.ss_classes, vgg=self.use_VGG)
        else:
            discriminator_p = Discriminator(n=self.n_classes, n_s=self.ss_classes, vgg=self.use_VGG)

        cls = CLS(feature_extractor.output_num(), self.n_classes+1, bottle_neck_dim=256, vgg=self.use_VGG)

        discriminator_p.to(self.device)
        feature_extractor.to(self.device)
        cls.to(self.device)

        net = nn.Sequential(feature_extractor, cls).to(self.device)

        center_loss = CenterLoss(num_classes=self.ss_classes*self.n_classes, feat_dim=256*self.n_classes, use_gpu=True, device=self.device)
        if self.use_VGG:
            center_loss_object = CenterLoss(num_classes=self.n_classes, feat_dim=4096, use_gpu=True, device=self.device)
        else:
            center_loss_object = CenterLoss(num_classes=self.n_classes, feat_dim=2048, use_gpu=True, device=self.device)
109 | 110 | # scheduler, optimizer --------------------------------------------------------- 111 | max_iter = int(self.epochs_step1*len(source_train_ss)) 112 | scheduler = lambda step, initial_lr : inverseDecaySheduler(step, initial_lr, gamma=10, power=0.75, max_iter=max_iter) 113 | 114 | params = list(discriminator_p.parameters()) 115 | if self.weight_center_loss>0: 116 | params = params+ list(center_loss.parameters()) 117 | 118 | optimizer_discriminator_p = OptimWithSheduler(optim.SGD(params, lr=self.learning_rate, weight_decay=5e-4, momentum=0.9, nesterov=True),scheduler) 119 | 120 | if not self.use_VGG: 121 | for name,param in feature_extractor.named_parameters(): 122 | words= name.split('.') 123 | if words[1] =='layer4': 124 | param.requires_grad = True 125 | else: 126 | param.requires_grad = False 127 | 128 | params_cls = list(cls.parameters()) 129 | optimizer_cls = OptimWithSheduler(optim.SGD([{'params': params_cls},{'params': feature_extractor.parameters(), 'lr': (self.learning_rate/self.divison_learning_rate_backbone)}], lr=self.learning_rate, weight_decay=5e-4, momentum=0.9, nesterov=True),scheduler) 130 | 131 | else: 132 | for name,param in feature_extractor.named_parameters(): 133 | words= name.split('.') 134 | if words[1] =='classifier': 135 | param.requires_grad = True 136 | else: 137 | param.requires_grad = False 138 | params_cls = list(cls.parameters()) 139 | optimizer_cls = OptimWithSheduler(optim.SGD([{'params': params_cls},{'params': feature_extractor.parameters(), 'lr': (self.learning_rate/self.divison_learning_rate_backbone)}], lr=self.learning_rate, weight_decay=5e-4, momentum=0.9, nesterov=True),scheduler) 140 | 141 | 142 | log = Logger(self.folder_log+'/step', clear=True) 143 | target_train = cycle(target_train) 144 | 145 | k=0 146 | print('\n') 147 | print('Separation known/unknown phase------------------------------------------------------------------------------------------') 148 | print('\n') 149 | 150 | while k 0: 196 | for param in 
center_loss.parameters(): 197 | param.grad.data *= (1./self.weight_center_loss) 198 | 199 | 200 | log.step += 1 201 | 202 | k += 1 203 | counter = AccuracyCounter() 204 | counter.addOntBatch(variable_to_numpy(predict_prob_source), variable_to_numpy(label_source)) 205 | acc_train = torch.from_numpy(np.asarray([counter.reportAccuracy()], dtype=np.float32)).to(self.device) 206 | counter_ss = AccuracyCounter() 207 | counter_ss.addOntBatch(variable_to_numpy(p0), variable_to_numpy(label_source_ss)) 208 | acc_train_rot = torch.from_numpy(np.asarray([counter_ss.reportAccuracy()], dtype=np.float32)).to(self.device) 209 | track_scalars(log, ['loss_object_class', 'acc_train', 'loss_rotation','acc_train_rot','loss_center'],globals()) 210 | 211 | select_low = compute_scores_all_target(target_test_for_scores,feature_extractor,discriminator_p,net,self.use_VGG,self.n_classes,self.ss_classes,self.device,self.source,self.target,self.folder_txt_files,self.folder_txt_files_saving) 212 | 213 | # ========================= Add target samples to cls and discriminator_p classifiers in function of the score 214 | #data--------------------------------------------------------------------------------------------------------------- 215 | self.only_4_rotations = True 216 | 217 | images,labels = get_split_dataset_info(self.folder_txt_files_saving+self.source+'_'+self.target+'_test_high.txt',self.folder_dataset) 218 | ds_target_high = CustomDataset(images,labels,img_transformer=transform_target_ss_step2,returns=3,is_train=True,ss_classes=self.ss_classes,n_classes=self.n_classes,only_4_rotations=self.only_4_rotations,n_classes_target=self.n_classes_target) 219 | target_train_high = torch.utils.data.DataLoader(ds_target_high, batch_size=self.batch_size, shuffle=True, num_workers=self.n_workers, pin_memory=True, drop_last=True) 220 | 221 | 222 | images,labels = get_split_dataset_info(self.folder_txt_files+self.target+'_test.txt',self.folder_dataset) 223 | ds_target = 
CustomDataset(images,labels,img_transformer=transform_target_ss_step2,returns=3,is_train=True,ss_classes=self.ss_classes,n_classes=self.n_classes,only_4_rotations=self.only_4_rotations,n_classes_target=self.n_classes_target) 224 | target_train = torch.utils.data.DataLoader(ds_target, batch_size=self.batch_size, shuffle=True, num_workers=self.n_workers, pin_memory=True, drop_last=True) 225 | 226 | images,labels = get_split_dataset_info(self.folder_txt_files_saving+self.source+'_'+self.target+'_test_low.txt',self.folder_dataset) 227 | ds_target_low = CustomDataset(images,labels,img_transformer=transform_source_ss_step2,returns=6,is_train=True,ss_classes=self.ss_classes,n_classes=self.n_classes,only_4_rotations=self.only_4_rotations,n_classes_target=self.n_classes_target) 228 | target_train_low = torch.utils.data.DataLoader(ds_target_low, batch_size=self.batch_size, shuffle=True, num_workers=self.n_workers, pin_memory=True, drop_last=True) 229 | 230 | # network -------------------------------------------------------------------------------------------------------------------------- 231 | discriminator_p = Discriminator(n = 1,n_s = self.ss_classes,vgg=self.use_VGG) 232 | discriminator_p.to(self.device) 233 | 234 | if not self.use_weight_net_first_part: 235 | if self.use_VGG: 236 | feature_extractor = VGGFc(self.device,model_name='vgg19') 237 | else: 238 | feature_extractor = ResNetFc(self.device,model_name='resnet50') 239 | cls = CLS(feature_extractor.output_num(), self.n_classes+1, bottle_neck_dim=256,vgg=self.use_VGG) 240 | feature_extractor.to(self.device) 241 | cls.to(self.device) 242 | net = nn.Sequential(feature_extractor, cls).to(self.device) 243 | 244 | if len(target_train_low) >= len(target_train_high): 245 | length = len(target_train_low) 246 | else: 247 | length = len(target_train_high) 248 | 249 | max_iter = int(self.epochs_step2*length) 250 | 251 | scheduler = lambda step, initial_lr : inverseDecaySheduler(step, initial_lr, gamma=10, power=0.75, 
max_iter=max_iter) 252 | params = list(discriminator_p.parameters()) 253 | 254 | optimizer_discriminator_p = OptimWithSheduler(optim.SGD(params, lr=self.learning_rate, weight_decay=5e-4, momentum=0.9, nesterov=True),scheduler) 255 | 256 | if not self.use_VGG: 257 | for name,param in feature_extractor.named_parameters(): 258 | words= name.split('.') 259 | if words[1] =='layer4': 260 | param.requires_grad = True 261 | else: 262 | param.requires_grad = False 263 | 264 | params_cls = list(cls.parameters()) 265 | optimizer_cls = OptimWithSheduler(optim.SGD([{'params': params_cls},{'params': feature_extractor.parameters(), 'lr': (self.learning_rate/self.divison_learning_rate_backbone)}], lr=self.learning_rate, weight_decay=5e-4, momentum=0.9, nesterov=True),scheduler) 266 | 267 | else: 268 | for name,param in feature_extractor.named_parameters(): 269 | words= name.split('.') 270 | if words[1] =='classifier': 271 | param.requires_grad = True 272 | else: 273 | param.requires_grad = False 274 | params_cls = list(cls.parameters()) 275 | optimizer_cls = OptimWithSheduler(optim.SGD([{'params': params_cls},{'params': feature_extractor.parameters(), 'lr': (self.learning_rate/self.divison_learning_rate_backbone)}], lr=self.learning_rate, weight_decay=5e-4, momentum=0.9, nesterov=True),scheduler) 276 | 277 | k=0 278 | print('\n') 279 | print('Adaptation phase--------------------------------------------------------------------------------------------------------') 280 | print('\n') 281 | ss_weight_target = self.ss_weight_target 282 | weight_class_unknown = 1/(select_low*(self.n_classes/(len(source_train_ss)*self.batch_size))) 283 | 284 | while k len(target_train_high): 289 | num_iterations = len(target_train_low) 290 | num_iterations_smaller = len(target_train_high) 291 | target_train_low_iter = iter(target_train_low) 292 | target_train_high_iter = cycle(target_train_high) 293 | else: 294 | num_iterations = len(target_train_high) 295 | num_iterations_smaller = len(target_train_low) 
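The step-2 optimizers configured here wrap SGD in `OptimWithSheduler` with `inverseDecaySheduler` (`gamma=10`, `power=0.75`, `max_iter = epochs_step2 * length`). `code/utilities.py` defines that schedule as lr = initial_lr * (1 + gamma * min(1, step/max_iter))^(-power). Each parameter group decays from its own `initial_lr`, so the backbone group (learning rate divided by `divison_learning_rate_backbone`) follows the same curve at a smaller scale. The sketch below restates the formula as a standalone function. The name `inverse_decay_lr` and the `max_iter` value in the demo are hypothetical.

```python
def inverse_decay_lr(step, initial_lr, gamma=10.0, power=0.75, max_iter=1000):
    """Return initial_lr * (1 + gamma * p) ** (-power) with p = min(1, step / max_iter)."""
    # Training progress, clamped so the rate stops decaying after max_iter steps.
    progress = min(1.0, step / float(max_iter))
    return initial_lr * (1.0 + gamma * progress) ** (-power)


if __name__ == "__main__":
    # 0.003 is the --learning_rate default in train.py; 10.0 is the
    # --divison_learning_rate_backbone default used for the backbone group.
    max_iter = 1000  # hypothetical; step 2 uses epochs_step2 * len(loader)
    for step in (0, 250, 500, 1000, 1500):
        head_lr = inverse_decay_lr(step, 0.003, max_iter=max_iter)
        backbone_lr = inverse_decay_lr(step, 0.003 / 10.0, max_iter=max_iter)
        print(f"step {step}: head {head_lr:.6f}, backbone {backbone_lr:.7f}")
```

Because progress is clamped at 1, any step beyond `max_iter` keeps the final rate of initial_lr * 11^(-0.75), which is roughly initial_lr / 6.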
296 | target_train_low_iter = cycle(target_train_low) 297 | target_train_high_iter = iter(target_train_high) 298 | 299 | for i in range(num_iterations): 300 | 301 | global entropy_loss 302 | 303 | (im_target_entropy,_,_) = next(iteration) 304 | (im_source,im_source_ss,label_source,label_source_ss,_,_) = next(target_train_low_iter) 305 | (im_target,im_target_ss,label_target_ss) = next(target_train_high_iter) 306 | 307 | im_source = im_source.to(self.device) 308 | im_source_ss = im_source_ss.to(self.device) 309 | label_source = label_source.to(self.device) 310 | label_source_ss = label_source_ss.to(self.device) 311 | im_target = im_target.to(self.device) 312 | im_target_ss = im_target_ss.to(self.device) 313 | label_target_ss = label_target_ss.to(self.device) 314 | im_target_entropy = im_target_entropy.to(self.device) 315 | 316 | 317 | ft1_ss = feature_extractor.forward(im_target_ss) 318 | ft1_original = feature_extractor.forward(im_target) 319 | double_input_t = torch.cat((ft1_original, ft1_ss), 1) 320 | ft1_ss=double_input_t 321 | 322 | (_, _, _, predict_prob_source) = net.forward(im_source) 323 | 324 | (_ ,_, _, _) = net.forward(im_target_entropy) 325 | (_, _, _, predict_prob_target) = net.forward(im_target) 326 | 327 | p0_t,_ = discriminator_p.forward(ft1_ss) 328 | p0_t = nn.Softmax(dim=-1)(p0_t) 329 | 330 | # =========================loss function 331 | class_weight = np.ones((self.n_classes+1),dtype=np.dtype('f')) 332 | class_weight[self.n_classes]= weight_class_unknown*self.weight_class_unknown 333 | class_weight = (torch.from_numpy(class_weight)).to(self.device) 334 | ce = CrossEntropyLoss(label_source, predict_prob_source,class_weight) 335 | 336 | entropy = EntropyLoss(predict_prob_target) 337 | d1_t = CrossEntropyLoss(label_target_ss,p0_t) 338 | 339 | with OptimizerManager([optimizer_cls, optimizer_discriminator_p]): 340 | loss_object_class = self.cls_weight_source*ce 341 | loss_rotation = ss_weight_target*d1_t 342 | entropy_loss = 
self.entropy_weight*entropy 343 | 344 | loss = loss_object_class + loss_rotation + entropy_loss 345 | loss.backward() 346 | log.step += 1 347 | 348 | k += 1 349 | counter = AccuracyCounter() 350 | counter.addOntBatch(variable_to_numpy(predict_prob_source), variable_to_numpy(label_source)) 351 | acc_train = torch.from_numpy(np.asarray([counter.reportAccuracy()], dtype=np.float32)).to(self.device) 352 | 353 | counter_ss = AccuracyCounter() 354 | counter_ss.addOntBatch(variable_to_numpy(p0_t), variable_to_numpy(label_target_ss)) 355 | acc_train_rot = torch.from_numpy(np.asarray([counter_ss.reportAccuracy()], dtype=np.float32)).to(self.device) 356 | track_scalars(log, ['loss_object_class', 'acc_train', 'loss_rotation', 'acc_train_rot','entropy_loss'], globals()) 357 | 358 | global predict_prob 359 | global label 360 | global predict_index 361 | 362 | # =================================evaluation 363 | if k%10==0 or k==(self.epochs_step2): 364 | with TrainingModeManager([feature_extractor, cls], train=False) as mgr, Accumulator(['predict_prob','predict_index', 'label']) as accumulator: 365 | for (i, (im, label)) in enumerate(target_test): 366 | with torch.no_grad(): 367 | im = im.to(self.device) 368 | label = label.to(self.device) 369 | (ss, fs,_, predict_prob) = net.forward(im) 370 | predict_prob,label = [variable_to_numpy(x) for x in (predict_prob,label)] 371 | label = np.argmax(label, axis=-1).reshape(-1, 1) 372 | predict_index = np.argmax(predict_prob, axis=-1).reshape(-1, 1) 373 | accumulator.updateData(globals()) 374 | 375 | for x in accumulator.keys(): 376 | globals()[x] = accumulator[x] 377 | y_true = label.flatten() 378 | y_pred = predict_index.flatten() 379 | m = extended_confusion_matrix(y_true, y_pred, true_labels=range(self.n_classes_target), pred_labels=range(self.n_classes+1)) 380 | 381 | cm = m 382 | cm = cm.astype(np.float64) / np.sum(cm, axis=1, keepdims=True) 383 | acc_os_star = sum([cm[i][i] for i in range(self.n_classes)]) / (self.n_classes) 384 |
unkn = sum([cm[i][self.n_classes] for i in range(self.n_classes,self.n_classes_target)]) / (self.n_classes_target - (self.n_classes)) 385 | acc_os = (acc_os_star * (self.n_classes) + unkn) / (self.n_classes+1) 386 | hos = (2*acc_os_star*unkn)/(acc_os_star+unkn) 387 | print('os',acc_os) 388 | print('os*', acc_os_star) 389 | print('unkn',unkn) 390 | print('hos',hos) 391 | 392 | net.train() 393 | 394 | #torch.save(net.state_dict(),self.folder_name+'/model.pkl') -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from steps_separation_adaptation import Trainer 4 | import numpy as np 5 | import torch 6 | import os 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser(description="Script to launch training",formatter_class=argparse.ArgumentDefaultsHelpFormatter) 11 | 12 | #domains 13 | parser.add_argument("--source", help="Source domain") 14 | parser.add_argument("--target", help="Target domain") 15 | 16 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size") 17 | parser.add_argument("--learning_rate", type=float, default=0.003, help="Learning rate") 18 | parser.add_argument("--divison_learning_rate_backbone", type=float, default=10.0, help="Divisor applied to the learning rate of the part of the backbone that is not frozen") 19 | 20 | #epochs step1 and step2 21 | parser.add_argument("--epochs_step1", type=int, default=80, help="Epochs of step1") 22 | parser.add_argument("--epochs_step2", type=int, default=80,help="Epochs of step2") 23 | 24 | #number of classes: known, unknown and the classes of self-sup task 25 | parser.add_argument("--n_classes", type=int, default=25, help="Number of classes of source domain -- known classes") 26 | parser.add_argument("--n_classes_target", type=int, default=65,help="Number of classes of target domain -- known+unknown classes") 27 |
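The evaluation code at the end of `steps_separation_adaptation.py` derives four open-set scores from the row-normalized confusion matrix `cm`. That matrix has one row per true target class (`n_classes_target`) and one column per prediction: `n_classes` known columns plus one "unknown" column. The scores are:

- OS*: accuracy averaged over the known classes.
- UNK: how often each unknown class is predicted as unknown, averaged over the unknown classes.
- OS: the average with all unknown classes counted as one extra class.
- HOS: the harmonic mean of OS* and UNK.

The pure-Python sketch below reproduces the same arithmetic on a raw count matrix. The function name `open_set_metrics` and the toy counts are hypothetical. The sketch assumes every true class has at least one test sample. It also adds a 0/0 guard for HOS that the original code does not have.

```python
def open_set_metrics(counts, n_known):
    """Return (OS*, UNK, OS, HOS) from a raw open-set confusion matrix.

    counts[t][p] = number of test samples of true class t predicted as p.
    Rows 0..n_known-1 are known classes, later rows are unknown classes;
    columns 0..n_known-1 are known predictions, column n_known is "unknown".
    Assumes every row (true class) has at least one sample.
    """
    # Row-normalize: each entry becomes a per-true-class prediction rate.
    rates = [[c / sum(row) for c in row] for row in counts]
    n_target = len(rates)
    # OS*: mean per-class accuracy over the known classes only.
    os_star = sum(rates[i][i] for i in range(n_known)) / n_known
    # UNK: mean rate at which each unknown class is predicted as "unknown".
    unk = sum(rates[i][n_known] for i in range(n_known, n_target)) / (n_target - n_known)
    # OS: all unknown classes count as one extra class in the average.
    os_all = (os_star * n_known + unk) / (n_known + 1)
    # HOS: harmonic mean of OS* and UNK (guarded here against 0/0).
    hos = 2 * os_star * unk / (os_star + unk) if (os_star + unk) > 0 else 0.0
    return os_star, unk, os_all, hos


if __name__ == "__main__":
    # Toy matrix: 2 known classes, 2 unknown classes, 3 prediction columns.
    counts = [
        [8, 1, 1],
        [2, 6, 2],
        [0, 1, 9],
        [1, 2, 7],
    ]
    os_star, unk, os_all, hos = open_set_metrics(counts, n_known=2)
    print("os", os_all, "os*", os_star, "unkn", unk, "hos", hos)
```

HOS is only high when both OS* and UNK are high. OS, in contrast, can look good even when the unknown class is mostly missed, because UNK enters it as just one of `n_known + 1` terms.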
parser.add_argument("--ss_classes", "-rc", type=int, default=4, help="Number of classes for the self-supervised task") 28 | 29 | #weights used during training 30 | parser.add_argument("--ss_weight_source", type=float, default=3.0, help="Weight of the ss (rotation) loss on the source domain (it acts in step1)") 31 | parser.add_argument("--ss_weight_target", type=float, default=3.0, help="Weight of the ss (rotation) loss on the target domain (it acts in step2)") 32 | parser.add_argument("--cls_weight_source", type=float, default=1.0, help="Weight of the classification loss (it acts in step1 and step2)") 33 | parser.add_argument("--entropy_weight", type=float, default=0.1, help="Weight of the entropy loss on target predictions (it acts in step2)") 34 | parser.add_argument("--weight_center_loss", type=float, default=0.0, help="Weight of the center loss for the ss task (it acts in step1)") 35 | parser.add_argument("--weight_class_unknown", type=float, default=1.0, help="Multiplier of the loss weight of the unknown class (it acts in step2)") 36 | 37 | #path of the folders used 38 | parser.add_argument("--folder_dataset",default=None, help="Path to the dataset") 39 | parser.add_argument("--folder_txt_files", default='/.../ROS/data/',help="Path to the txt files of the dataset") 40 | parser.add_argument("--folder_txt_files_saving", default='/.../ROS/data/',help="Path where the new txt files are saved") 41 | parser.add_argument("--folder_log", default=None, help="Path of the log folder") 42 | 43 | #to select gpu/num of workers 44 | parser.add_argument("--gpu", type=int, default=0, help="Index of the GPU used for training") 45 | parser.add_argument("--n_workers", type=int, default=4, help="Number of data-loading workers") 46 | 47 | parser.add_argument("--use_VGG", action='store_true', default=False, help="Use VGG-19 instead of ResNet-50 as the backbone") 48 | parser.add_argument("--use_weight_net_first_part", action='store_true', default=False, help="Reuse the feature extractor and classifier trained in step1 for step2 instead of re-initializing them") 49 | parser.add_argument("--only_4_rotations", action='store_true', default=False,help="If not use
rotation for class") 50 | return parser.parse_args() 51 | 52 | 53 | args = get_args() 54 | 55 | orig_stdout = sys.stdout 56 | rand = np.random.randint(200000) 57 | 58 | words = args.folder_txt_files.split('/ROS/') 59 | args.folder_log = words[0]+'/'+'ROS/outputs/logs/' + str(rand) 60 | args.folder_name = words[0]+'/'+'ROS/outputs/' + str(rand) 61 | args.folder_txt_files_saving = args.folder_txt_files + str(rand) 62 | 63 | gpu = str(args.gpu) 64 | device = torch.device("cuda:"+gpu) 65 | 66 | if not os.path.exists(args.folder_name): 67 | os.makedirs(args.folder_name) 68 | 69 | print('\n') 70 | print('TRAIN START!') 71 | print('\n') 72 | print('THE OUTPUT IS SAVED IN A TXT FILE HERE -------------------------------------------> ', args.folder_name) 73 | print('\n') 74 | 75 | f = open(args.folder_name + '/out.txt', 'w') 76 | sys.stdout = f 77 | print("\n%s to %s - %d ss classes" % (args.source, args.target, args.ss_classes)) 78 | 79 | trainer = Trainer(args, device, rand) 80 | trainer._do_train() 81 | 82 | print(args) 83 | sys.stdout = orig_stdout 84 | f.close() 85 | -------------------------------------------------------------------------------- /code/transformations.py: -------------------------------------------------------------------------------- 1 | from data import * 2 | from utilities import * 3 | from networks import * 4 | import numpy as np 5 | from random import sample, random 6 | from PIL import Image 7 | import torchvision.transforms as transforms 8 | import sys 9 | import torchvision 10 | from itertools import chain 11 | from skimage.transform import resize 12 | from center_loss import CenterLoss 13 | import math 14 | from sklearn.metrics import roc_auc_score 15 | import random 16 | from sklearn import preprocessing 17 | from itertools import cycle 18 | from scipy.misc import imread, imresize 19 | 20 | 21 | def transform_target_train(data, label, is_train,ss_classes,n_classes,only_4_rotations,n_classes_target): 22 | data = imresize(data, (256,256)) 23 | 
original_image = tl.prepro.crop(data, 224, 224, is_random=is_train) 24 | original_image = np.transpose(original_image, [2, 0, 1]) 25 | original_image = np.asarray(original_image, np.float32) / 255.0 26 | label = one_hot(n_classes_target, label) 27 | return original_image, label 28 | 29 | def transform_source_ss(data, label, is_train,ss_classes,n_classes,only_4_rotations,n_classes_target): 30 | ss_transformation = np.random.randint(ss_classes) 31 | data = imresize(data, (256,256)) 32 | original_image = tl.prepro.crop(data, 224, 224, is_random=is_train) 33 | data = tl.prepro.crop(data, 224, 224, is_random=is_train) 34 | 35 | if ss_transformation==0: 36 | ss_data=data 37 | if ss_transformation==1: 38 | ss_data=np.rot90(data,k=1) 39 | if ss_transformation==2: 40 | ss_data=np.rot90(data,k=2) 41 | if ss_transformation==3: 42 | ss_data=np.rot90(data,k=3) 43 | 44 | if only_4_rotations: 45 | ss_label = one_hot(ss_classes,ss_transformation) 46 | label_ss_center = ss_transformation 47 | else: 48 | ss_label = one_hot(ss_classes*n_classes,(ss_classes*label)+ss_transformation) 49 | label_ss_center = (ss_classes*label)+ss_transformation 50 | 51 | ss_data = np.transpose(ss_data, [2, 0, 1]) 52 | ss_data = np.asarray(ss_data, np.float32) / 255.0 53 | 54 | original_image = np.transpose(original_image, [2, 0, 1]) 55 | original_image = np.asarray(original_image, np.float32) / 255.0 56 | label_object_center = label 57 | label = one_hot(n_classes+1, label) 58 | 59 | return original_image,ss_data,label,ss_label,label_ss_center,label_object_center 60 | 61 | 62 | def transform_source_ss_step2(data, label, is_train,ss_classes,n_classes,only_4_rotations,n_classes_target): 63 | data = imresize(data, (256,256)) 64 | ss_transformation = np.random.randint(ss_classes) 65 | 66 | original_image = tl.prepro.crop(data, 224, 224, is_random=is_train) 67 | data = tl.prepro.crop(data, 224, 224, is_random=is_train) 68 | 69 | if ss_transformation==0: 70 | ss_data=data 71 | if ss_transformation==1: 72 | 
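`transform_source_ss` performs the following steps:

1. It draws a random rotation index in `[0, ss_classes)`.
2. It rotates the crop by that many quarter turns with `np.rot90`.
3. It encodes the self-supervised target in one of two ways. With `only_4_rotations=True`, the target is the rotation alone. Otherwise, the target combines rotation and object class as `ss_classes*label + rotation`, over `ss_classes*n_classes` classes.

The NumPy sketch below covers only this labelling step. It skips the resize and random crop. The function name `rotate_with_label` is hypothetical.

```python
import numpy as np


def rotate_with_label(image, label, rotation, n_rotations=4, n_classes=None):
    """Rotate an HxWxC uint8 image by `rotation` quarter turns and build its ss target.

    n_classes=None mirrors only_4_rotations=True (target = rotation index);
    otherwise the target is the joint index n_rotations * label + rotation
    over n_rotations * n_classes classes.
    Returns (CHW float32 image in [0, 1], one-hot target, target index).
    """
    # np.rot90 rotates counter-clockwise in the plane of the first two axes (H, W).
    rotated = np.rot90(image, k=rotation)
    if n_classes is None:
        ss_index, n_ss = rotation, n_rotations
    else:
        ss_index, n_ss = n_rotations * label + rotation, n_rotations * n_classes
    target = np.zeros(n_ss, dtype=np.float32)
    target[ss_index] = 1.0
    # HWC -> CHW and scale to [0, 1], matching the transforms in this file.
    chw = np.asarray(np.transpose(rotated, (2, 0, 1)), dtype=np.float32) / 255.0
    return chw, target, ss_index


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    crop = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
    rotation = int(rng.integers(4))  # like np.random.randint(ss_classes)
    x, y, idx = rotate_with_label(crop, label=3, rotation=rotation, n_classes=25)
    print(x.shape, y.shape, idx)
```

With the defaults `ss_classes=4` and `n_classes=25`, the joint encoding gives 100 self-supervised classes. That is why the discriminator output size depends on `only_4_rotations`.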
ss_data=np.rot90(data,k=1) 73 | if ss_transformation==2: 74 | ss_data=np.rot90(data,k=2) 75 | if ss_transformation==3: 76 | ss_data=np.rot90(data,k=3) 77 | 78 | ss_label = one_hot(ss_classes,ss_transformation) 79 | 80 | ss_data = np.transpose(ss_data, [2, 0, 1]) 81 | ss_data = np.asarray(ss_data, np.float32) / 255.0 82 | 83 | original_image = np.transpose(original_image, [2, 0, 1]) 84 | original_image = np.asarray(original_image, np.float32) / 255.0 85 | label_object_center = label 86 | label = one_hot(n_classes+1, label) 87 | label_ss_center = ss_transformation 88 | 89 | return original_image,ss_data,label,ss_label,label_ss_center,label_object_center 90 | 91 | 92 | def transform_target_ss_step2(data, label, is_train,ss_classes,n_classes,only_4_rotations,n_classes_target): 93 | data = imresize(data, (256,256)) 94 | ss_transformation = np.random.randint(ss_classes) 95 | 96 | original_image = tl.prepro.crop(data, 224, 224, is_random=is_train) 97 | data = tl.prepro.crop(data, 224, 224, is_random=is_train) 98 | 99 | if ss_transformation==0: 100 | ss_data=data 101 | if ss_transformation==1: 102 | ss_data=np.rot90(data,k=1) 103 | if ss_transformation==2: 104 | ss_data=np.rot90(data,k=2) 105 | if ss_transformation==3: 106 | ss_data=np.rot90(data,k=3) 107 | 108 | ss_label = one_hot(ss_classes,ss_transformation) 109 | 110 | ss_data = np.transpose(ss_data, [2, 0, 1]) 111 | ss_data = np.asarray(ss_data, np.float32) / 255.0 112 | 113 | original_image = np.transpose(original_image, [2, 0, 1]) 114 | original_image = np.asarray(original_image, np.float32) / 255.0 115 | 116 | return original_image,ss_data,ss_label 117 | 118 | def transform_target_test(data, label, is_train,ss_classes,n_classes,only_4_rotations,n_classes_target): 119 | data = imresize(data, (256,256)) 120 | label = one_hot(n_classes_target, label) 121 | data = tl.prepro.crop(data, 224, 224, is_random=is_train) 122 | data = np.transpose(data, [2, 0, 1]) 123 | data = np.asarray(data, np.float32) / 255.0 124 | return 
data, label 125 | 126 | def transform_target_test_mnist(data, label, is_train,ss_classes,n_classes,only_4_rotations,n_classes_target,mean,std): 127 | 128 | label = one_hot(n_classes_target, label) 129 | data = imresize(data, (32,32)) 130 | data = np.transpose(data, [2, 0, 1]) 131 | data = np.asarray(data, np.float32) / 255.0 132 | 133 | mean = mean.repeat(3) 134 | mean = np.asarray(mean, dtype=np.float32).reshape((3, 1, 1)) 135 | std = std.repeat(3) 136 | std = np.asarray(std, dtype=np.float32).reshape((3, 1, 1)) 137 | 138 | data = (data - mean)/std 139 | 140 | return data, label 141 | 142 | def transform_target_test_for_scores(data, label, is_train,ss_classes,n_classes,only_4_rotations,n_classes_target): 143 | data = imresize(data, (256,256)) 144 | return data, label 145 | -------------------------------------------------------------------------------- /code/utilities.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorlayer as tl 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.autograd.variable import * 8 | import os 9 | from collections import Counter 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | class Accumulator(dict): 14 | def __init__(self, name_or_names, accumulate_fn=np.concatenate): 15 | super(Accumulator, self).__init__() 16 | self.names = [name_or_names] if isinstance(name_or_names, str) else name_or_names 17 | self.accumulate_fn = accumulate_fn 18 | for name in self.names: 19 | self.__setitem__(name, []) 20 | 21 | def updateData(self, scope): 22 | for name in self.names: 23 | self.__getitem__(name).append(scope[name]) 24 | 25 | def __enter__(self): 26 | return self 27 | 28 | def __exit__(self, exc_type, exc_val, exc_tb): 29 | if exc_tb: 30 | print(exc_tb) 31 | return False 32 | 33 | for name in self.names: 34 | self.__setitem__(name, self.accumulate_fn(self.__getitem__(name))) 35 | 36 | return True 37 | 38 
class TrainingModeManager: 39 | def __init__(self, nets, train=False): 40 | self.nets = nets 41 | self.modes = [net.training for net in nets] 42 | self.train = train 43 | def __enter__(self): 44 | for net in self.nets: 45 | net.train(self.train) 46 | def __exit__(self, exceptionType, exception, exceptionTraceback): 47 | for (mode, net) in zip(self.modes, self.nets): 48 | net.train(mode) 49 | self.nets = None # release reference, to avoid an implicit reference 50 | if exceptionTraceback: 51 | print(exceptionTraceback) 52 | return False 53 | return True 54 | 55 | def clear_output(): 56 | def clear(): 57 | return 58 | try: 59 | from IPython.display import clear_output as clear 60 | except ImportError as e: 61 | pass 62 | import os 63 | def cls(): 64 | os.system('cls' if os.name == 'nt' else 'clear') 65 | 66 | clear() 67 | cls() 68 | 69 | def addkey(diction, key, global_vars): 70 | diction[key] = global_vars[key] 71 | 72 | def track_scalars(logger, names, global_vars): 73 | values = {} 74 | for name in names: 75 | addkey(values, name, global_vars) 76 | for k in values: 77 | values[k] = variable_to_numpy(values[k]) 78 | for k, v in values.items(): 79 | logger.log_scalar(k, v) 80 | print(values) 81 | 82 | def variable_to_numpy(x): 83 | ans = x.cpu().data.numpy() 84 | if torch.numel(x) == 1: 85 | return float(np.sum(ans)) 86 | return ans 87 | 88 | def inverseDecaySheduler(step, initial_lr, gamma=10, power=0.75, max_iter=1000): 89 | 90 | return initial_lr * ((1 + gamma * min(1.0, step / float(max_iter))) ** (- power)) 91 | 92 | def aToBSheduler(step, A, B, gamma=10, max_iter=10000): 93 | 94 | ans = A + (2.0 / (1 + np.exp(- gamma * step * 1.0 / max_iter)) - 1.0) * (B - A) 95 | return float(ans) 96 | 97 | def stepScheduler(step, initial_lr, gamma=0.1,max_iter=1000): 98 | 99 | step = int(step) 100 | step_size = int(max_iter * .8) 101 | 102 | if step>step_size: 103 | ans = initial_lr*gamma 104 | else: 105 | ans = initial_lr 106 | 107 | return ans 108 | 109 | def
one_hot(n_class, index): 110 | tmp = np.zeros((n_class,), dtype=np.float32) 111 | tmp[index] = 1.0 112 | return tmp 113 | 114 | class OptimWithSheduler: 115 | def __init__(self, optimizer, scheduler_func): 116 | self.optimizer = optimizer 117 | self.scheduler_func = scheduler_func 118 | self.global_step = 0.0 119 | for g in self.optimizer.param_groups: 120 | g['initial_lr'] = g['lr'] 121 | def zero_grad(self): 122 | self.optimizer.zero_grad() 123 | def step(self): 124 | for g in self.optimizer.param_groups: 125 | g['lr'] = self.scheduler_func(step=self.global_step, initial_lr = g['initial_lr']) 126 | self.optimizer.step() 127 | self.global_step += 1 128 | 129 | class OptimizerManager: 130 | def __init__(self, optims): 131 | self.optims = optims #if isinstance(optims, Iterable) else [optims] 132 | def __enter__(self): 133 | for op in self.optims: 134 | op.zero_grad() 135 | def __exit__(self, exceptionType, exception, exceptionTraceback): 136 | for op in self.optims: 137 | op.step() 138 | self.optims = None 139 | if exceptionTraceback: 140 | print(exceptionTraceback) 141 | return False 142 | return True 143 | 144 | def setGPU(i): 145 | global os 146 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 147 | os.environ["CUDA_VISIBLE_DEVICES"] = "%s"%(i) 148 | gpus = [x.strip() for x in (str(i)).split(',')] 149 | NGPU = len(gpus) 150 | print('gpu(s) to be used: %s'%str(gpus)) 151 | return NGPU 152 | 153 | class Logger(object): 154 | def __init__(self, log_dir, clear=False): 155 | tl.files.exists_or_mkdir(log_dir) 156 | self.writer = tf.compat.v1.summary.FileWriter(log_dir) 157 | self.step = 0 158 | self.log_dir = log_dir 159 | 160 | def log_scalar(self, tag, value, step = None): 161 | if not step: 162 | step = self.step 163 | summary = tf.compat.v1.Summary(value = [tf.compat.v1.Summary.Value(tag = tag, 164 | simple_value = value)]) 165 | self.writer.add_summary(summary, step) 166 | self.writer.flush() 167 | 168 | def log_images(self, tag, images, step = None): 169 | if not 
step: 170 | step = self.step 171 | 172 | im_summaries = [] 173 | for nr, img in enumerate(images): 174 | s = StringIO() 175 | 176 | if len(img.shape) == 2: 177 | img = np.expand_dims(img, axis=-1) 178 | 179 | if img.shape[-1] == 1: 180 | img = np.tile(img, [1, 1, 3]) 181 | img = to_rgb_np(img) 182 | plt.imsave(s, img, format = 'png') 183 | 184 | img_sum = tf.Summary.Image(encoded_image_string = s.getvalue(), 185 | height = img.shape[0], 186 | width = img.shape[1]) 187 | im_summaries.append(tf.Summary.Value(tag = '%s/%d' % (tag, nr), 188 | image = img_sum)) 189 | summary = tf.Summary(value = im_summaries) 190 | self.writer.add_summary(summary, step) 191 | self.writer.flush() 192 | 193 | def log_histogram(self, tag, values, step = None, bins = 1000): 194 | if not step: 195 | step = self.step 196 | values = np.array(values) 197 | counts, bin_edges = np.histogram(values, bins=bins) 198 | hist = tf.HistogramProto() 199 | hist.min = float(np.min(values)) 200 | hist.max = float(np.max(values)) 201 | hist.num = int(np.prod(values.shape)) 202 | hist.sum = float(np.sum(values)) 203 | hist.sum_squares = float(np.sum(values**2)) 204 | for edge in bin_edges: 205 | hist.bucket_limit.append(edge) 206 | for c in counts: 207 | hist.bucket.append(c) 208 | 209 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 210 | self.writer.add_summary(summary, step) 211 | self.writer.flush() 212 | 213 | def log_bar(self, tag, values, xs = None, step = None): 214 | if not step: 215 | step = self.step 216 | 217 | values = np.asarray(values).flatten() 218 | if not xs: 219 | axises = list(range(len(values))) 220 | else: 221 | axises = xs 222 | hist = tf.HistogramProto() 223 | hist.min = float(min(axises)) 224 | hist.max = float(max(axises)) 225 | hist.num = sum(values) 226 | hist.sum = sum([y * x for (x, y) in zip(axises, values)]) 227 | hist.sum_squares = sum([y * (x ** 2) for (x, y) in zip(axises, values)]) 228 | 229 | for edge in axises: 230 | hist.bucket_limit.append(edge - 
1e-10) 231 | hist.bucket_limit.append(edge + 1e-10) 232 | for c in values: 233 | hist.bucket.append(0) 234 | hist.bucket.append(c) 235 | 236 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 237 | self.writer.add_summary(summary, step) 238 | self.writer.flush() 239 | 240 | class AccuracyCounter: 241 | def __init__(self): 242 | self.Ncorrect = 0.0 243 | self.Ntotal = 0.0 244 | 245 | def addOntBatch(self, predict, label): 246 | assert predict.shape == label.shape 247 | correct_prediction = np.equal(np.argmax(predict, 1), np.argmax(label, 1)) 248 | Ncorrect = np.sum(correct_prediction.astype(np.float32)) 249 | Ntotal = len(label) 250 | self.Ncorrect += Ncorrect 251 | self.Ntotal += Ntotal 252 | return Ncorrect / Ntotal 253 | 254 | def reportAccuracy(self): 255 | return np.asarray(self.Ncorrect, dtype=float) / np.asarray(self.Ntotal, dtype=float) 256 | 257 | def CrossEntropyLoss(label, predict_prob, class_level_weight = None, instance_level_weight = None, epsilon = 1e-12): 258 | N, C = label.size() 259 | N_, C_ = predict_prob.size() 260 | 261 | assert N == N_ and C == C_, 'fatal error: dimension mismatch!' 262 | 263 | if class_level_weight is None: 264 | class_level_weight = 1.0 265 | else: 266 | if len(class_level_weight.size()) == 1: 267 | class_level_weight = class_level_weight.view(1, class_level_weight.size(0)) 268 | assert class_level_weight.size(1) == C, 'fatal error: dimension mismatch!' 269 | 270 | if instance_level_weight is None: 271 | instance_level_weight = 1.0 272 | else: 273 | if len(instance_level_weight.size()) == 1: 274 | instance_level_weight = instance_level_weight.view(instance_level_weight.size(0), 1) 275 |
276 | 277 | ce = -label * torch.log(predict_prob + epsilon) 278 | return torch.sum(instance_level_weight * ce * class_level_weight) / float(N) 279 | 280 | def BCELossForMultiClassification(label, predict_prob, class_level_weight=None, instance_level_weight=None, epsilon = 1e-12): 281 | N, C = label.size() 282 | N_, C_ = predict_prob.size() 283 | 284 | assert N == N_ and C == C_, 'fatal error: dimension mismatch!' 285 | 286 | if class_level_weight is None: 287 | class_level_weight = 1.0 288 | else: 289 | if len(class_level_weight.size()) == 1: 290 | class_level_weight = class_level_weight.view(1, class_level_weight.size(0)) 291 | assert class_level_weight.size(1) == C, 'fatal error: dimension mismatch!' 292 | 293 | if instance_level_weight is None: 294 | instance_level_weight = 1.0 295 | else: 296 | if len(instance_level_weight.size()) == 1: 297 | instance_level_weight = instance_level_weight.view(instance_level_weight.size(0), 1) 298 | assert instance_level_weight.size(0) == N, 'fatal error: dimension mismatch!' 299 | 300 | bce = -label * torch.log(predict_prob + epsilon) - (1.0 - label) * torch.log(1.0 - predict_prob + epsilon) 301 | return torch.sum(instance_level_weight * bce * class_level_weight) / float(N) 302 | 303 | def EntropyLoss(predict_prob, class_level_weight=None, instance_level_weight=None, epsilon= 1e-20): 304 | 305 | N, C = predict_prob.size() 306 | 307 | if class_level_weight is None: 308 | class_level_weight = 1.0 309 | else: 310 | if len(class_level_weight.size()) == 1: 311 | class_level_weight = class_level_weight.view(1, class_level_weight.size(0)) 312 | assert class_level_weight.size(1) == C, 'fatal error: dimension mismatch!' 313 | 314 | if instance_level_weight is None: 315 | instance_level_weight = 1.0 316 | else: 317 | if len(instance_level_weight.size()) == 1: 318 | instance_level_weight = instance_level_weight.view(instance_level_weight.size(0), 1) 319 | assert instance_level_weight.size(0) == N, 'fatal error: dimension mismatch!' 
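Step 2 combines two losses, both defined in `code/utilities.py` and both operating on probability tensors with one-hot labels:

- `CrossEntropyLoss` on the source and low-score target batch. It uses a per-class weight vector whose last ("unknown") entry is scaled by `weight_class_unknown * self.weight_class_unknown`.
- `EntropyLoss` on the target predictions.

The pure-Python sketch below restates both formulas for small lists so the arithmetic can be checked by hand. All function names are hypothetical. `unknown_class_weight` mirrors the step-2 expression `1/(select_low*(n_classes/(len(source_train_ss)*batch_size)))`. It assumes that `select_low` counts the low-score target samples and that `len(source_train_ss)*batch_size` approximates the number of source samples.

```python
import math


def weighted_cross_entropy(labels, probs, class_weight, eps=1e-12):
    """Batch mean of -sum_c w_c * y_c * log(p_c + eps), with one-hot labels y."""
    total = 0.0
    for y_row, p_row in zip(labels, probs):
        total -= sum(w * y * math.log(p + eps) for w, y, p in zip(class_weight, y_row, p_row))
    return total / len(labels)


def entropy_loss(probs, eps=1e-20):
    """Batch mean of the prediction entropy -sum_c p_c * log(p_c + eps)."""
    return -sum(p * math.log(p + eps) for row in probs for p in row) / len(probs)


def unknown_class_weight(n_low, n_classes, n_source):
    """Mirror of 1 / (n_low * (n_classes / n_source)), i.e.
    (source samples per known class) / (low-score target samples)."""
    return 1.0 / (n_low * (n_classes / n_source))


if __name__ == "__main__":
    n_classes = 2  # known classes; index n_classes is the "unknown" class
    w = [1.0] * (n_classes + 1)
    w[n_classes] = unknown_class_weight(n_low=40, n_classes=n_classes, n_source=100)
    labels = [[1, 0, 0], [0, 0, 1]]
    probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
    print("ce", weighted_cross_entropy(labels, probs, w), "entropy", entropy_loss(probs))
```

Under these assumptions, the unknown class gets a weight above 1 when it has fewer pseudo-labelled samples than an average known class, and a weight below 1 when it has more.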
320 | 321 | entropy = -predict_prob*torch.log(predict_prob + epsilon) 322 | return torch.sum(instance_level_weight * entropy * class_level_weight) / float(N) 323 | 324 | def plot_confusion_matrix(cm, true_classes,pred_classes=None, 325 | normalize=False, 326 | title='Confusion matrix', 327 | cmap=plt.cm.Blues): 328 | import itertools 329 | pred_classes = pred_classes or true_classes 330 | if normalize: 331 | cm = cm.astype(np.float64) / np.sum(cm, axis=1, keepdims=True) 332 | 333 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 334 | plt.title(title) 335 | plt.colorbar(fraction=0.046, pad=0.04) 336 | true_tick_marks = np.arange(len(true_classes)) 337 | plt.yticks(true_tick_marks, true_classes) 338 | pred_tick_marks = np.arange(len(pred_classes)) 339 | plt.xticks(pred_tick_marks, pred_classes, rotation=45) 340 | 341 | 342 | fmt = '.2f' if normalize else 'd' 343 | thresh = cm.max() / 2. 344 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 345 | plt.text(j, i, format(cm[i, j], fmt), 346 | horizontalalignment="center", 347 | color="white" if cm[i, j] > thresh else "black") 348 | 349 | plt.tight_layout() 350 | plt.ylabel('True label') 351 | plt.xlabel('Predicted label') 352 | plt.show() 353 | 354 | def extended_confusion_matrix(y_true, y_pred, true_labels=None, pred_labels=None): 355 | 356 | if not true_labels: 357 | true_labels = sorted(list(set(list(y_true)))) 358 | true_label_to_id = {x : i for (i, x) in enumerate(true_labels)} 359 | if not pred_labels: 360 | pred_labels = true_labels 361 | pred_label_to_id = {x : i for (i, x) in enumerate(pred_labels)} 362 | confusion_matrix = np.zeros([len(true_labels), len(pred_labels)]) 363 | for (true, pred) in zip(y_true, y_pred): 364 | confusion_matrix[true_label_to_id[true]][pred_label_to_id[pred]] += 1.0 365 | return confusion_matrix -------------------------------------------------------------------------------- /data/amazon_0-9_train_all.txt:
-------------------------------------------------------------------------------- 1 | office/amazon/backpack/frame_0061.jpg 0 2 | office/amazon/backpack/frame_0059.jpg 0 3 | office/amazon/backpack/frame_0083.jpg 0 4 | office/amazon/backpack/frame_0021.jpg 0 5 | office/amazon/backpack/frame_0062.jpg 0 6 | office/amazon/backpack/frame_0075.jpg 0 7 | office/amazon/backpack/frame_0054.jpg 0 8 | office/amazon/backpack/frame_0091.jpg 0 9 | office/amazon/backpack/frame_0028.jpg 0 10 | office/amazon/backpack/frame_0002.jpg 0 11 | office/amazon/backpack/frame_0049.jpg 0 12 | office/amazon/backpack/frame_0084.jpg 0 13 | office/amazon/backpack/frame_0060.jpg 0 14 | office/amazon/backpack/frame_0050.jpg 0 15 | office/amazon/backpack/frame_0073.jpg 0 16 | office/amazon/backpack/frame_0030.jpg 0 17 | office/amazon/backpack/frame_0047.jpg 0 18 | office/amazon/backpack/frame_0005.jpg 0 19 | office/amazon/backpack/frame_0064.jpg 0 20 | office/amazon/backpack/frame_0082.jpg 0 21 | office/amazon/backpack/frame_0077.jpg 0 22 | office/amazon/backpack/frame_0078.jpg 0 23 | office/amazon/backpack/frame_0025.jpg 0 24 | office/amazon/backpack/frame_0020.jpg 0 25 | office/amazon/backpack/frame_0063.jpg 0 26 | office/amazon/backpack/frame_0006.jpg 0 27 | office/amazon/backpack/frame_0009.jpg 0 28 | office/amazon/backpack/frame_0069.jpg 0 29 | office/amazon/backpack/frame_0086.jpg 0 30 | office/amazon/backpack/frame_0071.jpg 0 31 | office/amazon/backpack/frame_0048.jpg 0 32 | office/amazon/backpack/frame_0018.jpg 0 33 | office/amazon/backpack/frame_0065.jpg 0 34 | office/amazon/backpack/frame_0024.jpg 0 35 | office/amazon/backpack/frame_0032.jpg 0 36 | office/amazon/backpack/frame_0019.jpg 0 37 | office/amazon/backpack/frame_0090.jpg 0 38 | office/amazon/backpack/frame_0088.jpg 0 39 | office/amazon/backpack/frame_0087.jpg 0 40 | office/amazon/backpack/frame_0029.jpg 0 41 | office/amazon/backpack/frame_0031.jpg 0 42 | office/amazon/backpack/frame_0012.jpg 0 43 | 
office/amazon/backpack/frame_0053.jpg 0 44 | office/amazon/backpack/frame_0008.jpg 0 45 | office/amazon/backpack/frame_0013.jpg 0 46 | office/amazon/backpack/frame_0051.jpg 0 47 | office/amazon/backpack/frame_0036.jpg 0 48 | office/amazon/backpack/frame_0007.jpg 0 49 | office/amazon/backpack/frame_0072.jpg 0 50 | office/amazon/backpack/frame_0089.jpg 0 51 | office/amazon/backpack/frame_0039.jpg 0 52 | office/amazon/backpack/frame_0068.jpg 0 53 | office/amazon/backpack/frame_0001.jpg 0 54 | office/amazon/backpack/frame_0037.jpg 0 55 | office/amazon/backpack/frame_0079.jpg 0 56 | office/amazon/backpack/frame_0038.jpg 0 57 | office/amazon/backpack/frame_0085.jpg 0 58 | office/amazon/backpack/frame_0023.jpg 0 59 | office/amazon/backpack/frame_0056.jpg 0 60 | office/amazon/backpack/frame_0033.jpg 0 61 | office/amazon/backpack/frame_0004.jpg 0 62 | office/amazon/backpack/frame_0017.jpg 0 63 | office/amazon/backpack/frame_0042.jpg 0 64 | office/amazon/backpack/frame_0044.jpg 0 65 | office/amazon/backpack/frame_0046.jpg 0 66 | office/amazon/backpack/frame_0015.jpg 0 67 | office/amazon/backpack/frame_0070.jpg 0 68 | office/amazon/backpack/frame_0041.jpg 0 69 | office/amazon/backpack/frame_0052.jpg 0 70 | office/amazon/backpack/frame_0080.jpg 0 71 | office/amazon/backpack/frame_0055.jpg 0 72 | office/amazon/backpack/frame_0026.jpg 0 73 | office/amazon/backpack/frame_0040.jpg 0 74 | office/amazon/backpack/frame_0016.jpg 0 75 | office/amazon/backpack/frame_0035.jpg 0 76 | office/amazon/backpack/frame_0043.jpg 0 77 | office/amazon/backpack/frame_0027.jpg 0 78 | office/amazon/backpack/frame_0092.jpg 0 79 | office/amazon/backpack/frame_0066.jpg 0 80 | office/amazon/backpack/frame_0076.jpg 0 81 | office/amazon/backpack/frame_0074.jpg 0 82 | office/amazon/backpack/frame_0058.jpg 0 83 | office/amazon/backpack/frame_0011.jpg 0 84 | office/amazon/backpack/frame_0045.jpg 0 85 | office/amazon/backpack/frame_0081.jpg 0 86 | office/amazon/backpack/frame_0034.jpg 0 87 | 
office/amazon/backpack/frame_0067.jpg 0 88 | office/amazon/backpack/frame_0022.jpg 0 89 | office/amazon/backpack/frame_0014.jpg 0 90 | office/amazon/backpack/frame_0003.jpg 0 91 | office/amazon/backpack/frame_0010.jpg 0 92 | office/amazon/backpack/frame_0057.jpg 0 93 | office/amazon/bike/frame_0074.jpg 1 94 | office/amazon/bike/frame_0002.jpg 1 95 | office/amazon/bike/frame_0068.jpg 1 96 | office/amazon/bike/frame_0052.jpg 1 97 | office/amazon/bike/frame_0041.jpg 1 98 | office/amazon/bike/frame_0014.jpg 1 99 | office/amazon/bike/frame_0057.jpg 1 100 | office/amazon/bike/frame_0040.jpg 1 101 | office/amazon/bike/frame_0020.jpg 1 102 | office/amazon/bike/frame_0079.jpg 1 103 | office/amazon/bike/frame_0059.jpg 1 104 | office/amazon/bike/frame_0009.jpg 1 105 | office/amazon/bike/frame_0047.jpg 1 106 | office/amazon/bike/frame_0031.jpg 1 107 | office/amazon/bike/frame_0038.jpg 1 108 | office/amazon/bike/frame_0023.jpg 1 109 | office/amazon/bike/frame_0032.jpg 1 110 | office/amazon/bike/frame_0022.jpg 1 111 | office/amazon/bike/frame_0044.jpg 1 112 | office/amazon/bike/frame_0013.jpg 1 113 | office/amazon/bike/frame_0066.jpg 1 114 | office/amazon/bike/frame_0058.jpg 1 115 | office/amazon/bike/frame_0071.jpg 1 116 | office/amazon/bike/frame_0015.jpg 1 117 | office/amazon/bike/frame_0005.jpg 1 118 | office/amazon/bike/frame_0043.jpg 1 119 | office/amazon/bike/frame_0008.jpg 1 120 | office/amazon/bike/frame_0042.jpg 1 121 | office/amazon/bike/frame_0035.jpg 1 122 | office/amazon/bike/frame_0003.jpg 1 123 | office/amazon/bike/frame_0081.jpg 1 124 | office/amazon/bike/frame_0054.jpg 1 125 | office/amazon/bike/frame_0011.jpg 1 126 | office/amazon/bike/frame_0049.jpg 1 127 | office/amazon/bike/frame_0021.jpg 1 128 | office/amazon/bike/frame_0077.jpg 1 129 | office/amazon/bike/frame_0004.jpg 1 130 | office/amazon/bike/frame_0063.jpg 1 131 | office/amazon/bike/frame_0061.jpg 1 132 | office/amazon/bike/frame_0019.jpg 1 133 | office/amazon/bike/frame_0062.jpg 1 134 | 
office/amazon/bike/frame_0029.jpg 1 135 | office/amazon/bike/frame_0030.jpg 1 136 | office/amazon/bike/frame_0033.jpg 1 137 | office/amazon/bike/frame_0051.jpg 1 138 | office/amazon/bike/frame_0039.jpg 1 139 | office/amazon/bike/frame_0076.jpg 1 140 | office/amazon/bike/frame_0012.jpg 1 141 | office/amazon/bike/frame_0025.jpg 1 142 | office/amazon/bike/frame_0036.jpg 1 143 | office/amazon/bike/frame_0016.jpg 1 144 | office/amazon/bike/frame_0050.jpg 1 145 | office/amazon/bike/frame_0010.jpg 1 146 | office/amazon/bike/frame_0024.jpg 1 147 | office/amazon/bike/frame_0065.jpg 1 148 | office/amazon/bike/frame_0007.jpg 1 149 | office/amazon/bike/frame_0053.jpg 1 150 | office/amazon/bike/frame_0075.jpg 1 151 | office/amazon/bike/frame_0001.jpg 1 152 | office/amazon/bike/frame_0017.jpg 1 153 | office/amazon/bike/frame_0006.jpg 1 154 | office/amazon/bike/frame_0082.jpg 1 155 | office/amazon/bike/frame_0056.jpg 1 156 | office/amazon/bike/frame_0060.jpg 1 157 | office/amazon/bike/frame_0026.jpg 1 158 | office/amazon/bike/frame_0037.jpg 1 159 | office/amazon/bike/frame_0034.jpg 1 160 | office/amazon/bike/frame_0069.jpg 1 161 | office/amazon/bike/frame_0073.jpg 1 162 | office/amazon/bike/frame_0067.jpg 1 163 | office/amazon/bike/frame_0027.jpg 1 164 | office/amazon/bike/frame_0078.jpg 1 165 | office/amazon/bike/frame_0018.jpg 1 166 | office/amazon/bike/frame_0045.jpg 1 167 | office/amazon/bike/frame_0072.jpg 1 168 | office/amazon/bike/frame_0064.jpg 1 169 | office/amazon/bike/frame_0070.jpg 1 170 | office/amazon/bike/frame_0080.jpg 1 171 | office/amazon/bike/frame_0028.jpg 1 172 | office/amazon/bike/frame_0048.jpg 1 173 | office/amazon/bike/frame_0046.jpg 1 174 | office/amazon/bike/frame_0055.jpg 1 175 | office/amazon/bike_helmet/frame_0061.jpg 2 176 | office/amazon/bike_helmet/frame_0013.jpg 2 177 | office/amazon/bike_helmet/frame_0066.jpg 2 178 | office/amazon/bike_helmet/frame_0011.jpg 2 179 | office/amazon/bike_helmet/frame_0008.jpg 2 180 | 
office/amazon/bike_helmet/frame_0010.jpg 2 181 | office/amazon/bike_helmet/frame_0027.jpg 2 182 | office/amazon/bike_helmet/frame_0014.jpg 2 183 | office/amazon/bike_helmet/frame_0047.jpg 2 184 | office/amazon/bike_helmet/frame_0041.jpg 2 185 | office/amazon/bike_helmet/frame_0006.jpg 2 186 | office/amazon/bike_helmet/frame_0032.jpg 2 187 | office/amazon/bike_helmet/frame_0021.jpg 2 188 | office/amazon/bike_helmet/frame_0023.jpg 2 189 | office/amazon/bike_helmet/frame_0040.jpg 2 190 | office/amazon/bike_helmet/frame_0056.jpg 2 191 | office/amazon/bike_helmet/frame_0055.jpg 2 192 | office/amazon/bike_helmet/frame_0058.jpg 2 193 | office/amazon/bike_helmet/frame_0072.jpg 2 194 | office/amazon/bike_helmet/frame_0045.jpg 2 195 | office/amazon/bike_helmet/frame_0052.jpg 2 196 | office/amazon/bike_helmet/frame_0057.jpg 2 197 | office/amazon/bike_helmet/frame_0059.jpg 2 198 | office/amazon/bike_helmet/frame_0001.jpg 2 199 | office/amazon/bike_helmet/frame_0046.jpg 2 200 | office/amazon/bike_helmet/frame_0038.jpg 2 201 | office/amazon/bike_helmet/frame_0049.jpg 2 202 | office/amazon/bike_helmet/frame_0060.jpg 2 203 | office/amazon/bike_helmet/frame_0009.jpg 2 204 | office/amazon/bike_helmet/frame_0003.jpg 2 205 | office/amazon/bike_helmet/frame_0053.jpg 2 206 | office/amazon/bike_helmet/frame_0071.jpg 2 207 | office/amazon/bike_helmet/frame_0050.jpg 2 208 | office/amazon/bike_helmet/frame_0026.jpg 2 209 | office/amazon/bike_helmet/frame_0051.jpg 2 210 | office/amazon/bike_helmet/frame_0022.jpg 2 211 | office/amazon/bike_helmet/frame_0069.jpg 2 212 | office/amazon/bike_helmet/frame_0036.jpg 2 213 | office/amazon/bike_helmet/frame_0035.jpg 2 214 | office/amazon/bike_helmet/frame_0018.jpg 2 215 | office/amazon/bike_helmet/frame_0034.jpg 2 216 | office/amazon/bike_helmet/frame_0037.jpg 2 217 | office/amazon/bike_helmet/frame_0033.jpg 2 218 | office/amazon/bike_helmet/frame_0002.jpg 2 219 | office/amazon/bike_helmet/frame_0070.jpg 2 220 | 
office/amazon/bike_helmet/frame_0062.jpg 2 221 | office/amazon/bike_helmet/frame_0068.jpg 2 222 | office/amazon/bike_helmet/frame_0024.jpg 2 223 | office/amazon/bike_helmet/frame_0044.jpg 2 224 | office/amazon/bike_helmet/frame_0067.jpg 2 225 | office/amazon/bike_helmet/frame_0016.jpg 2 226 | office/amazon/bike_helmet/frame_0042.jpg 2 227 | office/amazon/bike_helmet/frame_0012.jpg 2 228 | office/amazon/bike_helmet/frame_0019.jpg 2 229 | office/amazon/bike_helmet/frame_0039.jpg 2 230 | office/amazon/bike_helmet/frame_0031.jpg 2 231 | office/amazon/bike_helmet/frame_0063.jpg 2 232 | office/amazon/bike_helmet/frame_0064.jpg 2 233 | office/amazon/bike_helmet/frame_0054.jpg 2 234 | office/amazon/bike_helmet/frame_0020.jpg 2 235 | office/amazon/bike_helmet/frame_0029.jpg 2 236 | office/amazon/bike_helmet/frame_0005.jpg 2 237 | office/amazon/bike_helmet/frame_0030.jpg 2 238 | office/amazon/bike_helmet/frame_0065.jpg 2 239 | office/amazon/bike_helmet/frame_0007.jpg 2 240 | office/amazon/bike_helmet/frame_0004.jpg 2 241 | office/amazon/bike_helmet/frame_0017.jpg 2 242 | office/amazon/bike_helmet/frame_0043.jpg 2 243 | office/amazon/bike_helmet/frame_0025.jpg 2 244 | office/amazon/bike_helmet/frame_0048.jpg 2 245 | office/amazon/bike_helmet/frame_0028.jpg 2 246 | office/amazon/bike_helmet/frame_0015.jpg 2 247 | office/amazon/bookcase/frame_0054.jpg 3 248 | office/amazon/bookcase/frame_0004.jpg 3 249 | office/amazon/bookcase/frame_0059.jpg 3 250 | office/amazon/bookcase/frame_0075.jpg 3 251 | office/amazon/bookcase/frame_0022.jpg 3 252 | office/amazon/bookcase/frame_0011.jpg 3 253 | office/amazon/bookcase/frame_0026.jpg 3 254 | office/amazon/bookcase/frame_0049.jpg 3 255 | office/amazon/bookcase/frame_0068.jpg 3 256 | office/amazon/bookcase/frame_0024.jpg 3 257 | office/amazon/bookcase/frame_0072.jpg 3 258 | office/amazon/bookcase/frame_0069.jpg 3 259 | office/amazon/bookcase/frame_0079.jpg 3 260 | office/amazon/bookcase/frame_0025.jpg 3 261 | 
office/amazon/bookcase/frame_0003.jpg 3 262 | office/amazon/bookcase/frame_0036.jpg 3 263 | office/amazon/bookcase/frame_0063.jpg 3 264 | office/amazon/bookcase/frame_0034.jpg 3 265 | office/amazon/bookcase/frame_0037.jpg 3 266 | office/amazon/bookcase/frame_0032.jpg 3 267 | office/amazon/bookcase/frame_0016.jpg 3 268 | office/amazon/bookcase/frame_0065.jpg 3 269 | office/amazon/bookcase/frame_0060.jpg 3 270 | office/amazon/bookcase/frame_0050.jpg 3 271 | office/amazon/bookcase/frame_0035.jpg 3 272 | office/amazon/bookcase/frame_0041.jpg 3 273 | office/amazon/bookcase/frame_0028.jpg 3 274 | office/amazon/bookcase/frame_0018.jpg 3 275 | office/amazon/bookcase/frame_0040.jpg 3 276 | office/amazon/bookcase/frame_0064.jpg 3 277 | office/amazon/bookcase/frame_0053.jpg 3 278 | office/amazon/bookcase/frame_0015.jpg 3 279 | office/amazon/bookcase/frame_0046.jpg 3 280 | office/amazon/bookcase/frame_0014.jpg 3 281 | office/amazon/bookcase/frame_0066.jpg 3 282 | office/amazon/bookcase/frame_0033.jpg 3 283 | office/amazon/bookcase/frame_0048.jpg 3 284 | office/amazon/bookcase/frame_0005.jpg 3 285 | office/amazon/bookcase/frame_0052.jpg 3 286 | office/amazon/bookcase/frame_0031.jpg 3 287 | office/amazon/bookcase/frame_0013.jpg 3 288 | office/amazon/bookcase/frame_0029.jpg 3 289 | office/amazon/bookcase/frame_0010.jpg 3 290 | office/amazon/bookcase/frame_0077.jpg 3 291 | office/amazon/bookcase/frame_0006.jpg 3 292 | office/amazon/bookcase/frame_0008.jpg 3 293 | office/amazon/bookcase/frame_0051.jpg 3 294 | office/amazon/bookcase/frame_0030.jpg 3 295 | office/amazon/bookcase/frame_0042.jpg 3 296 | office/amazon/bookcase/frame_0021.jpg 3 297 | office/amazon/bookcase/frame_0076.jpg 3 298 | office/amazon/bookcase/frame_0044.jpg 3 299 | office/amazon/bookcase/frame_0023.jpg 3 300 | office/amazon/bookcase/frame_0045.jpg 3 301 | office/amazon/bookcase/frame_0017.jpg 3 302 | office/amazon/bookcase/frame_0002.jpg 3 303 | office/amazon/bookcase/frame_0019.jpg 3 304 | 
office/amazon/bookcase/frame_0058.jpg 3 305 | office/amazon/bookcase/frame_0070.jpg 3 306 | office/amazon/bookcase/frame_0067.jpg 3 307 | office/amazon/bookcase/frame_0027.jpg 3 308 | office/amazon/bookcase/frame_0001.jpg 3 309 | office/amazon/bookcase/frame_0057.jpg 3 310 | office/amazon/bookcase/frame_0020.jpg 3 311 | office/amazon/bookcase/frame_0009.jpg 3 312 | office/amazon/bookcase/frame_0047.jpg 3 313 | office/amazon/bookcase/frame_0073.jpg 3 314 | office/amazon/bookcase/frame_0038.jpg 3 315 | office/amazon/bookcase/frame_0039.jpg 3 316 | office/amazon/bookcase/frame_0082.jpg 3 317 | office/amazon/bookcase/frame_0012.jpg 3 318 | office/amazon/bookcase/frame_0080.jpg 3 319 | office/amazon/bookcase/frame_0081.jpg 3 320 | office/amazon/bookcase/frame_0007.jpg 3 321 | office/amazon/bookcase/frame_0074.jpg 3 322 | office/amazon/bookcase/frame_0071.jpg 3 323 | office/amazon/bookcase/frame_0061.jpg 3 324 | office/amazon/bookcase/frame_0056.jpg 3 325 | office/amazon/bookcase/frame_0062.jpg 3 326 | office/amazon/bookcase/frame_0055.jpg 3 327 | office/amazon/bookcase/frame_0043.jpg 3 328 | office/amazon/bookcase/frame_0078.jpg 3 329 | office/amazon/bottle/frame_0033.jpg 4 330 | office/amazon/bottle/frame_0010.jpg 4 331 | office/amazon/bottle/frame_0021.jpg 4 332 | office/amazon/bottle/frame_0013.jpg 4 333 | office/amazon/bottle/frame_0006.jpg 4 334 | office/amazon/bottle/frame_0027.jpg 4 335 | office/amazon/bottle/frame_0008.jpg 4 336 | office/amazon/bottle/frame_0028.jpg 4 337 | office/amazon/bottle/frame_0017.jpg 4 338 | office/amazon/bottle/frame_0003.jpg 4 339 | office/amazon/bottle/frame_0019.jpg 4 340 | office/amazon/bottle/frame_0034.jpg 4 341 | office/amazon/bottle/frame_0029.jpg 4 342 | office/amazon/bottle/frame_0023.jpg 4 343 | office/amazon/bottle/frame_0004.jpg 4 344 | office/amazon/bottle/frame_0012.jpg 4 345 | office/amazon/bottle/frame_0031.jpg 4 346 | office/amazon/bottle/frame_0009.jpg 4 347 | office/amazon/bottle/frame_0026.jpg 4 348 | 
office/amazon/bottle/frame_0020.jpg 4 349 | office/amazon/bottle/frame_0002.jpg 4 350 | office/amazon/bottle/frame_0036.jpg 4 351 | office/amazon/bottle/frame_0016.jpg 4 352 | office/amazon/bottle/frame_0015.jpg 4 353 | office/amazon/bottle/frame_0025.jpg 4 354 | office/amazon/bottle/frame_0018.jpg 4 355 | office/amazon/bottle/frame_0011.jpg 4 356 | office/amazon/bottle/frame_0005.jpg 4 357 | office/amazon/bottle/frame_0022.jpg 4 358 | office/amazon/bottle/frame_0024.jpg 4 359 | office/amazon/bottle/frame_0007.jpg 4 360 | office/amazon/bottle/frame_0014.jpg 4 361 | office/amazon/bottle/frame_0030.jpg 4 362 | office/amazon/bottle/frame_0001.jpg 4 363 | office/amazon/bottle/frame_0032.jpg 4 364 | office/amazon/bottle/frame_0035.jpg 4 365 | office/amazon/calculator/frame_0086.jpg 5 366 | office/amazon/calculator/frame_0061.jpg 5 367 | office/amazon/calculator/frame_0053.jpg 5 368 | office/amazon/calculator/frame_0031.jpg 5 369 | office/amazon/calculator/frame_0034.jpg 5 370 | office/amazon/calculator/frame_0085.jpg 5 371 | office/amazon/calculator/frame_0041.jpg 5 372 | office/amazon/calculator/frame_0046.jpg 5 373 | office/amazon/calculator/frame_0067.jpg 5 374 | office/amazon/calculator/frame_0072.jpg 5 375 | office/amazon/calculator/frame_0056.jpg 5 376 | office/amazon/calculator/frame_0054.jpg 5 377 | office/amazon/calculator/frame_0015.jpg 5 378 | office/amazon/calculator/frame_0084.jpg 5 379 | office/amazon/calculator/frame_0006.jpg 5 380 | office/amazon/calculator/frame_0064.jpg 5 381 | office/amazon/calculator/frame_0047.jpg 5 382 | office/amazon/calculator/frame_0058.jpg 5 383 | office/amazon/calculator/frame_0009.jpg 5 384 | office/amazon/calculator/frame_0073.jpg 5 385 | office/amazon/calculator/frame_0020.jpg 5 386 | office/amazon/calculator/frame_0012.jpg 5 387 | office/amazon/calculator/frame_0089.jpg 5 388 | office/amazon/calculator/frame_0066.jpg 5 389 | office/amazon/calculator/frame_0007.jpg 5 390 | office/amazon/calculator/frame_0048.jpg 5 391 | 
office/amazon/calculator/frame_0024.jpg 5 392 | office/amazon/calculator/frame_0040.jpg 5 393 | office/amazon/calculator/frame_0087.jpg 5 394 | office/amazon/calculator/frame_0055.jpg 5 395 | office/amazon/calculator/frame_0068.jpg 5 396 | office/amazon/calculator/frame_0050.jpg 5 397 | office/amazon/calculator/frame_0059.jpg 5 398 | office/amazon/calculator/frame_0077.jpg 5 399 | office/amazon/calculator/frame_0082.jpg 5 400 | office/amazon/calculator/frame_0026.jpg 5 401 | office/amazon/calculator/frame_0088.jpg 5 402 | office/amazon/calculator/frame_0042.jpg 5 403 | office/amazon/calculator/frame_0022.jpg 5 404 | office/amazon/calculator/frame_0002.jpg 5 405 | office/amazon/calculator/frame_0027.jpg 5 406 | office/amazon/calculator/frame_0094.jpg 5 407 | office/amazon/calculator/frame_0035.jpg 5 408 | office/amazon/calculator/frame_0001.jpg 5 409 | office/amazon/calculator/frame_0057.jpg 5 410 | office/amazon/calculator/frame_0036.jpg 5 411 | office/amazon/calculator/frame_0074.jpg 5 412 | office/amazon/calculator/frame_0071.jpg 5 413 | office/amazon/calculator/frame_0045.jpg 5 414 | office/amazon/calculator/frame_0030.jpg 5 415 | office/amazon/calculator/frame_0004.jpg 5 416 | office/amazon/calculator/frame_0092.jpg 5 417 | office/amazon/calculator/frame_0016.jpg 5 418 | office/amazon/calculator/frame_0090.jpg 5 419 | office/amazon/calculator/frame_0005.jpg 5 420 | office/amazon/calculator/frame_0049.jpg 5 421 | office/amazon/calculator/frame_0038.jpg 5 422 | office/amazon/calculator/frame_0051.jpg 5 423 | office/amazon/calculator/frame_0065.jpg 5 424 | office/amazon/calculator/frame_0013.jpg 5 425 | office/amazon/calculator/frame_0060.jpg 5 426 | office/amazon/calculator/frame_0023.jpg 5 427 | office/amazon/calculator/frame_0052.jpg 5 428 | office/amazon/calculator/frame_0008.jpg 5 429 | office/amazon/calculator/frame_0010.jpg 5 430 | office/amazon/calculator/frame_0076.jpg 5 431 | office/amazon/calculator/frame_0018.jpg 5 432 | 
office/amazon/calculator/frame_0033.jpg 5 433 | office/amazon/calculator/frame_0079.jpg 5 434 | office/amazon/calculator/frame_0039.jpg 5 435 | office/amazon/calculator/frame_0044.jpg 5 436 | office/amazon/calculator/frame_0063.jpg 5 437 | office/amazon/calculator/frame_0043.jpg 5 438 | office/amazon/calculator/frame_0078.jpg 5 439 | office/amazon/calculator/frame_0011.jpg 5 440 | office/amazon/calculator/frame_0029.jpg 5 441 | office/amazon/calculator/frame_0025.jpg 5 442 | office/amazon/calculator/frame_0037.jpg 5 443 | office/amazon/calculator/frame_0083.jpg 5 444 | office/amazon/calculator/frame_0081.jpg 5 445 | office/amazon/calculator/frame_0003.jpg 5 446 | office/amazon/calculator/frame_0075.jpg 5 447 | office/amazon/calculator/frame_0032.jpg 5 448 | office/amazon/calculator/frame_0062.jpg 5 449 | office/amazon/calculator/frame_0017.jpg 5 450 | office/amazon/calculator/frame_0093.jpg 5 451 | office/amazon/calculator/frame_0021.jpg 5 452 | office/amazon/calculator/frame_0019.jpg 5 453 | office/amazon/calculator/frame_0028.jpg 5 454 | office/amazon/calculator/frame_0014.jpg 5 455 | office/amazon/calculator/frame_0069.jpg 5 456 | office/amazon/calculator/frame_0091.jpg 5 457 | office/amazon/calculator/frame_0080.jpg 5 458 | office/amazon/calculator/frame_0070.jpg 5 459 | office/amazon/desk_chair/frame_0024.jpg 6 460 | office/amazon/desk_chair/frame_0081.jpg 6 461 | office/amazon/desk_chair/frame_0070.jpg 6 462 | office/amazon/desk_chair/frame_0017.jpg 6 463 | office/amazon/desk_chair/frame_0091.jpg 6 464 | office/amazon/desk_chair/frame_0063.jpg 6 465 | office/amazon/desk_chair/frame_0049.jpg 6 466 | office/amazon/desk_chair/frame_0006.jpg 6 467 | office/amazon/desk_chair/frame_0003.jpg 6 468 | office/amazon/desk_chair/frame_0048.jpg 6 469 | office/amazon/desk_chair/frame_0016.jpg 6 470 | office/amazon/desk_chair/frame_0090.jpg 6 471 | office/amazon/desk_chair/frame_0004.jpg 6 472 | office/amazon/desk_chair/frame_0010.jpg 6 473 | 
office/amazon/desk_chair/frame_0077.jpg 6 474 | office/amazon/desk_chair/frame_0042.jpg 6 475 | office/amazon/desk_chair/frame_0008.jpg 6 476 | office/amazon/desk_chair/frame_0021.jpg 6 477 | office/amazon/desk_chair/frame_0068.jpg 6 478 | office/amazon/desk_chair/frame_0054.jpg 6 479 | office/amazon/desk_chair/frame_0034.jpg 6 480 | office/amazon/desk_chair/frame_0051.jpg 6 481 | office/amazon/desk_chair/frame_0030.jpg 6 482 | office/amazon/desk_chair/frame_0058.jpg 6 483 | office/amazon/desk_chair/frame_0069.jpg 6 484 | office/amazon/desk_chair/frame_0061.jpg 6 485 | office/amazon/desk_chair/frame_0038.jpg 6 486 | office/amazon/desk_chair/frame_0002.jpg 6 487 | office/amazon/desk_chair/frame_0027.jpg 6 488 | office/amazon/desk_chair/frame_0064.jpg 6 489 | office/amazon/desk_chair/frame_0013.jpg 6 490 | office/amazon/desk_chair/frame_0073.jpg 6 491 | office/amazon/desk_chair/frame_0005.jpg 6 492 | office/amazon/desk_chair/frame_0041.jpg 6 493 | office/amazon/desk_chair/frame_0047.jpg 6 494 | office/amazon/desk_chair/frame_0015.jpg 6 495 | office/amazon/desk_chair/frame_0029.jpg 6 496 | office/amazon/desk_chair/frame_0009.jpg 6 497 | office/amazon/desk_chair/frame_0050.jpg 6 498 | office/amazon/desk_chair/frame_0079.jpg 6 499 | office/amazon/desk_chair/frame_0026.jpg 6 500 | office/amazon/desk_chair/frame_0059.jpg 6 501 | office/amazon/desk_chair/frame_0007.jpg 6 502 | office/amazon/desk_chair/frame_0040.jpg 6 503 | office/amazon/desk_chair/frame_0035.jpg 6 504 | office/amazon/desk_chair/frame_0023.jpg 6 505 | office/amazon/desk_chair/frame_0033.jpg 6 506 | office/amazon/desk_chair/frame_0072.jpg 6 507 | office/amazon/desk_chair/frame_0018.jpg 6 508 | office/amazon/desk_chair/frame_0032.jpg 6 509 | office/amazon/desk_chair/frame_0031.jpg 6 510 | office/amazon/desk_chair/frame_0019.jpg 6 511 | office/amazon/desk_chair/frame_0065.jpg 6 512 | office/amazon/desk_chair/frame_0056.jpg 6 513 | office/amazon/desk_chair/frame_0022.jpg 6 514 | 
office/amazon/desk_chair/frame_0001.jpg 6 515 | office/amazon/desk_chair/frame_0053.jpg 6 516 | office/amazon/desk_chair/frame_0080.jpg 6 517 | office/amazon/desk_chair/frame_0012.jpg 6 518 | office/amazon/desk_chair/frame_0043.jpg 6 519 | office/amazon/desk_chair/frame_0060.jpg 6 520 | office/amazon/desk_chair/frame_0062.jpg 6 521 | office/amazon/desk_chair/frame_0028.jpg 6 522 | office/amazon/desk_chair/frame_0025.jpg 6 523 | office/amazon/desk_chair/frame_0086.jpg 6 524 | office/amazon/desk_chair/frame_0071.jpg 6 525 | office/amazon/desk_chair/frame_0020.jpg 6 526 | office/amazon/desk_chair/frame_0078.jpg 6 527 | office/amazon/desk_chair/frame_0088.jpg 6 528 | office/amazon/desk_chair/frame_0083.jpg 6 529 | office/amazon/desk_chair/frame_0087.jpg 6 530 | office/amazon/desk_chair/frame_0036.jpg 6 531 | office/amazon/desk_chair/frame_0076.jpg 6 532 | office/amazon/desk_chair/frame_0082.jpg 6 533 | office/amazon/desk_chair/frame_0011.jpg 6 534 | office/amazon/desk_chair/frame_0044.jpg 6 535 | office/amazon/desk_chair/frame_0046.jpg 6 536 | office/amazon/desk_chair/frame_0085.jpg 6 537 | office/amazon/desk_chair/frame_0055.jpg 6 538 | office/amazon/desk_chair/frame_0074.jpg 6 539 | office/amazon/desk_chair/frame_0084.jpg 6 540 | office/amazon/desk_chair/frame_0045.jpg 6 541 | office/amazon/desk_chair/frame_0037.jpg 6 542 | office/amazon/desk_chair/frame_0075.jpg 6 543 | office/amazon/desk_chair/frame_0039.jpg 6 544 | office/amazon/desk_chair/frame_0067.jpg 6 545 | office/amazon/desk_chair/frame_0052.jpg 6 546 | office/amazon/desk_chair/frame_0057.jpg 6 547 | office/amazon/desk_chair/frame_0089.jpg 6 548 | office/amazon/desk_chair/frame_0014.jpg 6 549 | office/amazon/desk_chair/frame_0066.jpg 6 550 | office/amazon/desk_lamp/frame_0010.jpg 7 551 | office/amazon/desk_lamp/frame_0089.jpg 7 552 | office/amazon/desk_lamp/frame_0040.jpg 7 553 | office/amazon/desk_lamp/frame_0043.jpg 7 554 | office/amazon/desk_lamp/frame_0058.jpg 7 555 | 
office/amazon/desk_lamp/frame_0061.jpg 7 556 | office/amazon/desk_lamp/frame_0005.jpg 7 557 | office/amazon/desk_lamp/frame_0088.jpg 7 558 | office/amazon/desk_lamp/frame_0068.jpg 7 559 | office/amazon/desk_lamp/frame_0015.jpg 7 560 | office/amazon/desk_lamp/frame_0067.jpg 7 561 | office/amazon/desk_lamp/frame_0059.jpg 7 562 | office/amazon/desk_lamp/frame_0062.jpg 7 563 | office/amazon/desk_lamp/frame_0037.jpg 7 564 | office/amazon/desk_lamp/frame_0007.jpg 7 565 | office/amazon/desk_lamp/frame_0077.jpg 7 566 | office/amazon/desk_lamp/frame_0006.jpg 7 567 | office/amazon/desk_lamp/frame_0053.jpg 7 568 | office/amazon/desk_lamp/frame_0087.jpg 7 569 | office/amazon/desk_lamp/frame_0080.jpg 7 570 | office/amazon/desk_lamp/frame_0012.jpg 7 571 | office/amazon/desk_lamp/frame_0008.jpg 7 572 | office/amazon/desk_lamp/frame_0093.jpg 7 573 | office/amazon/desk_lamp/frame_0095.jpg 7 574 | office/amazon/desk_lamp/frame_0022.jpg 7 575 | office/amazon/desk_lamp/frame_0078.jpg 7 576 | office/amazon/desk_lamp/frame_0041.jpg 7 577 | office/amazon/desk_lamp/frame_0097.jpg 7 578 | office/amazon/desk_lamp/frame_0073.jpg 7 579 | office/amazon/desk_lamp/frame_0047.jpg 7 580 | office/amazon/desk_lamp/frame_0034.jpg 7 581 | office/amazon/desk_lamp/frame_0074.jpg 7 582 | office/amazon/desk_lamp/frame_0025.jpg 7 583 | office/amazon/desk_lamp/frame_0066.jpg 7 584 | office/amazon/desk_lamp/frame_0081.jpg 7 585 | office/amazon/desk_lamp/frame_0056.jpg 7 586 | office/amazon/desk_lamp/frame_0013.jpg 7 587 | office/amazon/desk_lamp/frame_0045.jpg 7 588 | office/amazon/desk_lamp/frame_0044.jpg 7 589 | office/amazon/desk_lamp/frame_0020.jpg 7 590 | office/amazon/desk_lamp/frame_0051.jpg 7 591 | office/amazon/desk_lamp/frame_0021.jpg 7 592 | office/amazon/desk_lamp/frame_0072.jpg 7 593 | office/amazon/desk_lamp/frame_0084.jpg 7 594 | office/amazon/desk_lamp/frame_0009.jpg 7 595 | office/amazon/desk_lamp/frame_0017.jpg 7 596 | office/amazon/desk_lamp/frame_0096.jpg 7 597 | 
office/amazon/desk_lamp/frame_0070.jpg 7 598 | office/amazon/desk_lamp/frame_0038.jpg 7 599 | office/amazon/desk_lamp/frame_0024.jpg 7 600 | office/amazon/desk_lamp/frame_0052.jpg 7 601 | office/amazon/desk_lamp/frame_0030.jpg 7 602 | office/amazon/desk_lamp/frame_0011.jpg 7 603 | office/amazon/desk_lamp/frame_0069.jpg 7 604 | office/amazon/desk_lamp/frame_0076.jpg 7 605 | office/amazon/desk_lamp/frame_0064.jpg 7 606 | office/amazon/desk_lamp/frame_0055.jpg 7 607 | office/amazon/desk_lamp/frame_0065.jpg 7 608 | office/amazon/desk_lamp/frame_0004.jpg 7 609 | office/amazon/desk_lamp/frame_0046.jpg 7 610 | office/amazon/desk_lamp/frame_0018.jpg 7 611 | office/amazon/desk_lamp/frame_0031.jpg 7 612 | office/amazon/desk_lamp/frame_0001.jpg 7 613 | office/amazon/desk_lamp/frame_0039.jpg 7 614 | office/amazon/desk_lamp/frame_0075.jpg 7 615 | office/amazon/desk_lamp/frame_0083.jpg 7 616 | office/amazon/desk_lamp/frame_0033.jpg 7 617 | office/amazon/desk_lamp/frame_0054.jpg 7 618 | office/amazon/desk_lamp/frame_0063.jpg 7 619 | office/amazon/desk_lamp/frame_0082.jpg 7 620 | office/amazon/desk_lamp/frame_0091.jpg 7 621 | office/amazon/desk_lamp/frame_0071.jpg 7 622 | office/amazon/desk_lamp/frame_0036.jpg 7 623 | office/amazon/desk_lamp/frame_0032.jpg 7 624 | office/amazon/desk_lamp/frame_0028.jpg 7 625 | office/amazon/desk_lamp/frame_0029.jpg 7 626 | office/amazon/desk_lamp/frame_0019.jpg 7 627 | office/amazon/desk_lamp/frame_0049.jpg 7 628 | office/amazon/desk_lamp/frame_0094.jpg 7 629 | office/amazon/desk_lamp/frame_0027.jpg 7 630 | office/amazon/desk_lamp/frame_0026.jpg 7 631 | office/amazon/desk_lamp/frame_0003.jpg 7 632 | office/amazon/desk_lamp/frame_0048.jpg 7 633 | office/amazon/desk_lamp/frame_0085.jpg 7 634 | office/amazon/desk_lamp/frame_0057.jpg 7 635 | office/amazon/desk_lamp/frame_0086.jpg 7 636 | office/amazon/desk_lamp/frame_0060.jpg 7 637 | office/amazon/desk_lamp/frame_0042.jpg 7 638 | office/amazon/desk_lamp/frame_0023.jpg 7 639 | 
office/amazon/desk_lamp/frame_0090.jpg 7 640 | office/amazon/desk_lamp/frame_0014.jpg 7 641 | office/amazon/desk_lamp/frame_0002.jpg 7 642 | office/amazon/desk_lamp/frame_0050.jpg 7 643 | office/amazon/desk_lamp/frame_0092.jpg 7 644 | office/amazon/desk_lamp/frame_0016.jpg 7 645 | office/amazon/desk_lamp/frame_0035.jpg 7 646 | office/amazon/desk_lamp/frame_0079.jpg 7 647 | office/amazon/desktop_computer/frame_0044.jpg 8 648 | office/amazon/desktop_computer/frame_0065.jpg 8 649 | office/amazon/desktop_computer/frame_0073.jpg 8 650 | office/amazon/desktop_computer/frame_0085.jpg 8 651 | office/amazon/desktop_computer/frame_0014.jpg 8 652 | office/amazon/desktop_computer/frame_0080.jpg 8 653 | office/amazon/desktop_computer/frame_0021.jpg 8 654 | office/amazon/desktop_computer/frame_0071.jpg 8 655 | office/amazon/desktop_computer/frame_0041.jpg 8 656 | office/amazon/desktop_computer/frame_0083.jpg 8 657 | office/amazon/desktop_computer/frame_0040.jpg 8 658 | office/amazon/desktop_computer/frame_0090.jpg 8 659 | office/amazon/desktop_computer/frame_0039.jpg 8 660 | office/amazon/desktop_computer/frame_0072.jpg 8 661 | office/amazon/desktop_computer/frame_0061.jpg 8 662 | office/amazon/desktop_computer/frame_0010.jpg 8 663 | office/amazon/desktop_computer/frame_0063.jpg 8 664 | office/amazon/desktop_computer/frame_0062.jpg 8 665 | office/amazon/desktop_computer/frame_0092.jpg 8 666 | office/amazon/desktop_computer/frame_0054.jpg 8 667 | office/amazon/desktop_computer/frame_0012.jpg 8 668 | office/amazon/desktop_computer/frame_0036.jpg 8 669 | office/amazon/desktop_computer/frame_0002.jpg 8 670 | office/amazon/desktop_computer/frame_0037.jpg 8 671 | office/amazon/desktop_computer/frame_0018.jpg 8 672 | office/amazon/desktop_computer/frame_0078.jpg 8 673 | office/amazon/desktop_computer/frame_0051.jpg 8 674 | office/amazon/desktop_computer/frame_0033.jpg 8 675 | office/amazon/desktop_computer/frame_0049.jpg 8 676 | office/amazon/desktop_computer/frame_0089.jpg 8 677 | 
office/amazon/desktop_computer/frame_0060.jpg 8 678 | office/amazon/desktop_computer/frame_0038.jpg 8 679 | office/amazon/desktop_computer/frame_0029.jpg 8 680 | office/amazon/desktop_computer/frame_0087.jpg 8 681 | office/amazon/desktop_computer/frame_0019.jpg 8 682 | office/amazon/desktop_computer/frame_0057.jpg 8 683 | office/amazon/desktop_computer/frame_0015.jpg 8 684 | office/amazon/desktop_computer/frame_0020.jpg 8 685 | office/amazon/desktop_computer/frame_0059.jpg 8 686 | office/amazon/desktop_computer/frame_0068.jpg 8 687 | office/amazon/desktop_computer/frame_0001.jpg 8 688 | office/amazon/desktop_computer/frame_0026.jpg 8 689 | office/amazon/desktop_computer/frame_0079.jpg 8 690 | office/amazon/desktop_computer/frame_0095.jpg 8 691 | office/amazon/desktop_computer/frame_0030.jpg 8 692 | office/amazon/desktop_computer/frame_0084.jpg 8 693 | office/amazon/desktop_computer/frame_0077.jpg 8 694 | office/amazon/desktop_computer/frame_0055.jpg 8 695 | office/amazon/desktop_computer/frame_0008.jpg 8 696 | office/amazon/desktop_computer/frame_0086.jpg 8 697 | office/amazon/desktop_computer/frame_0042.jpg 8 698 | office/amazon/desktop_computer/frame_0069.jpg 8 699 | office/amazon/desktop_computer/frame_0004.jpg 8 700 | office/amazon/desktop_computer/frame_0032.jpg 8 701 | office/amazon/desktop_computer/frame_0070.jpg 8 702 | office/amazon/desktop_computer/frame_0053.jpg 8 703 | office/amazon/desktop_computer/frame_0027.jpg 8 704 | office/amazon/desktop_computer/frame_0011.jpg 8 705 | office/amazon/desktop_computer/frame_0005.jpg 8 706 | office/amazon/desktop_computer/frame_0013.jpg 8 707 | office/amazon/desktop_computer/frame_0076.jpg 8 708 | office/amazon/desktop_computer/frame_0009.jpg 8 709 | office/amazon/desktop_computer/frame_0093.jpg 8 710 | office/amazon/desktop_computer/frame_0045.jpg 8 711 | office/amazon/desktop_computer/frame_0081.jpg 8 712 | office/amazon/desktop_computer/frame_0050.jpg 8 713 | office/amazon/desktop_computer/frame_0017.jpg 8 714 | 
office/amazon/desktop_computer/frame_0034.jpg 8 715 | office/amazon/desktop_computer/frame_0075.jpg 8 716 | office/amazon/desktop_computer/frame_0091.jpg 8 717 | office/amazon/desktop_computer/frame_0056.jpg 8 718 | office/amazon/desktop_computer/frame_0024.jpg 8 719 | office/amazon/desktop_computer/frame_0066.jpg 8 720 | office/amazon/desktop_computer/frame_0097.jpg 8 721 | office/amazon/desktop_computer/frame_0048.jpg 8 722 | office/amazon/desktop_computer/frame_0064.jpg 8 723 | office/amazon/desktop_computer/frame_0046.jpg 8 724 | office/amazon/desktop_computer/frame_0007.jpg 8 725 | office/amazon/desktop_computer/frame_0023.jpg 8 726 | office/amazon/desktop_computer/frame_0043.jpg 8 727 | office/amazon/desktop_computer/frame_0067.jpg 8 728 | office/amazon/desktop_computer/frame_0025.jpg 8 729 | office/amazon/desktop_computer/frame_0074.jpg 8 730 | office/amazon/desktop_computer/frame_0088.jpg 8 731 | office/amazon/desktop_computer/frame_0035.jpg 8 732 | office/amazon/desktop_computer/frame_0028.jpg 8 733 | office/amazon/desktop_computer/frame_0016.jpg 8 734 | office/amazon/desktop_computer/frame_0058.jpg 8 735 | office/amazon/desktop_computer/frame_0096.jpg 8 736 | office/amazon/desktop_computer/frame_0003.jpg 8 737 | office/amazon/desktop_computer/frame_0082.jpg 8 738 | office/amazon/desktop_computer/frame_0031.jpg 8 739 | office/amazon/desktop_computer/frame_0022.jpg 8 740 | office/amazon/desktop_computer/frame_0006.jpg 8 741 | office/amazon/desktop_computer/frame_0094.jpg 8 742 | office/amazon/desktop_computer/frame_0047.jpg 8 743 | office/amazon/desktop_computer/frame_0052.jpg 8 744 | office/amazon/file_cabinet/frame_0051.jpg 9 745 | office/amazon/file_cabinet/frame_0073.jpg 9 746 | office/amazon/file_cabinet/frame_0029.jpg 9 747 | office/amazon/file_cabinet/frame_0050.jpg 9 748 | office/amazon/file_cabinet/frame_0022.jpg 9 749 | office/amazon/file_cabinet/frame_0059.jpg 9 750 | office/amazon/file_cabinet/frame_0027.jpg 9 751 | 
office/amazon/file_cabinet/frame_0071.jpg 9 752 | office/amazon/file_cabinet/frame_0013.jpg 9 753 | office/amazon/file_cabinet/frame_0068.jpg 9 754 | office/amazon/file_cabinet/frame_0020.jpg 9 755 | office/amazon/file_cabinet/frame_0056.jpg 9 756 | office/amazon/file_cabinet/frame_0009.jpg 9 757 | office/amazon/file_cabinet/frame_0048.jpg 9 758 | office/amazon/file_cabinet/frame_0001.jpg 9 759 | office/amazon/file_cabinet/frame_0023.jpg 9 760 | office/amazon/file_cabinet/frame_0021.jpg 9 761 | office/amazon/file_cabinet/frame_0077.jpg 9 762 | office/amazon/file_cabinet/frame_0063.jpg 9 763 | office/amazon/file_cabinet/frame_0044.jpg 9 764 | office/amazon/file_cabinet/frame_0018.jpg 9 765 | office/amazon/file_cabinet/frame_0067.jpg 9 766 | office/amazon/file_cabinet/frame_0028.jpg 9 767 | office/amazon/file_cabinet/frame_0039.jpg 9 768 | office/amazon/file_cabinet/frame_0014.jpg 9 769 | office/amazon/file_cabinet/frame_0007.jpg 9 770 | office/amazon/file_cabinet/frame_0019.jpg 9 771 | office/amazon/file_cabinet/frame_0045.jpg 9 772 | office/amazon/file_cabinet/frame_0016.jpg 9 773 | office/amazon/file_cabinet/frame_0072.jpg 9 774 | office/amazon/file_cabinet/frame_0057.jpg 9 775 | office/amazon/file_cabinet/frame_0017.jpg 9 776 | office/amazon/file_cabinet/frame_0078.jpg 9 777 | office/amazon/file_cabinet/frame_0053.jpg 9 778 | office/amazon/file_cabinet/frame_0031.jpg 9 779 | office/amazon/file_cabinet/frame_0024.jpg 9 780 | office/amazon/file_cabinet/frame_0040.jpg 9 781 | office/amazon/file_cabinet/frame_0005.jpg 9 782 | office/amazon/file_cabinet/frame_0064.jpg 9 783 | office/amazon/file_cabinet/frame_0054.jpg 9 784 | office/amazon/file_cabinet/frame_0041.jpg 9 785 | office/amazon/file_cabinet/frame_0037.jpg 9 786 | office/amazon/file_cabinet/frame_0074.jpg 9 787 | office/amazon/file_cabinet/frame_0060.jpg 9 788 | office/amazon/file_cabinet/frame_0049.jpg 9 789 | office/amazon/file_cabinet/frame_0038.jpg 9 790 | office/amazon/file_cabinet/frame_0070.jpg 9 791 | 
office/amazon/file_cabinet/frame_0008.jpg 9 792 | office/amazon/file_cabinet/frame_0055.jpg 9 793 | office/amazon/file_cabinet/frame_0033.jpg 9 794 | office/amazon/file_cabinet/frame_0062.jpg 9 795 | office/amazon/file_cabinet/frame_0080.jpg 9 796 | office/amazon/file_cabinet/frame_0046.jpg 9 797 | office/amazon/file_cabinet/frame_0035.jpg 9 798 | office/amazon/file_cabinet/frame_0043.jpg 9 799 | office/amazon/file_cabinet/frame_0075.jpg 9 800 | office/amazon/file_cabinet/frame_0010.jpg 9 801 | office/amazon/file_cabinet/frame_0002.jpg 9 802 | office/amazon/file_cabinet/frame_0042.jpg 9 803 | office/amazon/file_cabinet/frame_0025.jpg 9 804 | office/amazon/file_cabinet/frame_0052.jpg 9 805 | office/amazon/file_cabinet/frame_0047.jpg 9 806 | office/amazon/file_cabinet/frame_0034.jpg 9 807 | office/amazon/file_cabinet/frame_0032.jpg 9 808 | office/amazon/file_cabinet/frame_0076.jpg 9 809 | office/amazon/file_cabinet/frame_0004.jpg 9 810 | office/amazon/file_cabinet/frame_0026.jpg 9 811 | office/amazon/file_cabinet/frame_0036.jpg 9 812 | office/amazon/file_cabinet/frame_0011.jpg 9 813 | office/amazon/file_cabinet/frame_0006.jpg 9 814 | office/amazon/file_cabinet/frame_0079.jpg 9 815 | office/amazon/file_cabinet/frame_0012.jpg 9 816 | office/amazon/file_cabinet/frame_0065.jpg 9 817 | office/amazon/file_cabinet/frame_0069.jpg 9 818 | office/amazon/file_cabinet/frame_0066.jpg 9 819 | office/amazon/file_cabinet/frame_0081.jpg 9 820 | office/amazon/file_cabinet/frame_0015.jpg 9 821 | office/amazon/file_cabinet/frame_0030.jpg 9 822 | office/amazon/file_cabinet/frame_0003.jpg 9 823 | office/amazon/file_cabinet/frame_0061.jpg 9 824 | office/amazon/file_cabinet/frame_0058.jpg 9 -------------------------------------------------------------------------------- /data/dslr_0-9_20-30_test.txt: -------------------------------------------------------------------------------- 1 | office/dslr/backpack/frame_0009.jpg 0 2 | office/dslr/backpack/frame_0003.jpg 0 3 | 
office/dslr/backpack/frame_0011.jpg 0 4 | office/dslr/backpack/frame_0001.jpg 0 5 | office/dslr/backpack/frame_0006.jpg 0 6 | office/dslr/backpack/frame_0005.jpg 0 7 | office/dslr/backpack/frame_0008.jpg 0 8 | office/dslr/backpack/frame_0004.jpg 0 9 | office/dslr/backpack/frame_0012.jpg 0 10 | office/dslr/backpack/frame_0002.jpg 0 11 | office/dslr/backpack/frame_0010.jpg 0 12 | office/dslr/backpack/frame_0007.jpg 0 13 | office/dslr/bike/frame_0005.jpg 1 14 | office/dslr/bike/frame_0001.jpg 1 15 | office/dslr/bike/frame_0018.jpg 1 16 | office/dslr/bike/frame_0004.jpg 1 17 | office/dslr/bike/frame_0014.jpg 1 18 | office/dslr/bike/frame_0012.jpg 1 19 | office/dslr/bike/frame_0002.jpg 1 20 | office/dslr/bike/frame_0009.jpg 1 21 | office/dslr/bike/frame_0003.jpg 1 22 | office/dslr/bike/frame_0008.jpg 1 23 | office/dslr/bike/frame_0015.jpg 1 24 | office/dslr/bike/frame_0007.jpg 1 25 | office/dslr/bike/frame_0016.jpg 1 26 | office/dslr/bike/frame_0019.jpg 1 27 | office/dslr/bike/frame_0006.jpg 1 28 | office/dslr/bike/frame_0017.jpg 1 29 | office/dslr/bike/frame_0013.jpg 1 30 | office/dslr/bike/frame_0021.jpg 1 31 | office/dslr/bike/frame_0020.jpg 1 32 | office/dslr/bike/frame_0011.jpg 1 33 | office/dslr/bike/frame_0010.jpg 1 34 | office/dslr/bike_helmet/frame_0001.jpg 2 35 | office/dslr/bike_helmet/frame_0007.jpg 2 36 | office/dslr/bike_helmet/frame_0021.jpg 2 37 | office/dslr/bike_helmet/frame_0010.jpg 2 38 | office/dslr/bike_helmet/frame_0023.jpg 2 39 | office/dslr/bike_helmet/frame_0017.jpg 2 40 | office/dslr/bike_helmet/frame_0003.jpg 2 41 | office/dslr/bike_helmet/frame_0019.jpg 2 42 | office/dslr/bike_helmet/frame_0016.jpg 2 43 | office/dslr/bike_helmet/frame_0012.jpg 2 44 | office/dslr/bike_helmet/frame_0015.jpg 2 45 | office/dslr/bike_helmet/frame_0011.jpg 2 46 | office/dslr/bike_helmet/frame_0009.jpg 2 47 | office/dslr/bike_helmet/frame_0020.jpg 2 48 | office/dslr/bike_helmet/frame_0006.jpg 2 49 | office/dslr/bike_helmet/frame_0005.jpg 2 50 | 
office/dslr/bike_helmet/frame_0008.jpg 2 51 | office/dslr/bike_helmet/frame_0002.jpg 2 52 | office/dslr/bike_helmet/frame_0013.jpg 2 53 | office/dslr/bike_helmet/frame_0022.jpg 2 54 | office/dslr/bike_helmet/frame_0018.jpg 2 55 | office/dslr/bike_helmet/frame_0024.jpg 2 56 | office/dslr/bike_helmet/frame_0004.jpg 2 57 | office/dslr/bike_helmet/frame_0014.jpg 2 58 | office/dslr/bookcase/frame_0010.jpg 3 59 | office/dslr/bookcase/frame_0012.jpg 3 60 | office/dslr/bookcase/frame_0002.jpg 3 61 | office/dslr/bookcase/frame_0007.jpg 3 62 | office/dslr/bookcase/frame_0001.jpg 3 63 | office/dslr/bookcase/frame_0005.jpg 3 64 | office/dslr/bookcase/frame_0009.jpg 3 65 | office/dslr/bookcase/frame_0011.jpg 3 66 | office/dslr/bookcase/frame_0004.jpg 3 67 | office/dslr/bookcase/frame_0006.jpg 3 68 | office/dslr/bookcase/frame_0003.jpg 3 69 | office/dslr/bookcase/frame_0008.jpg 3 70 | office/dslr/bottle/frame_0011.jpg 4 71 | office/dslr/bottle/frame_0008.jpg 4 72 | office/dslr/bottle/frame_0001.jpg 4 73 | office/dslr/bottle/frame_0016.jpg 4 74 | office/dslr/bottle/frame_0002.jpg 4 75 | office/dslr/bottle/frame_0003.jpg 4 76 | office/dslr/bottle/frame_0004.jpg 4 77 | office/dslr/bottle/frame_0014.jpg 4 78 | office/dslr/bottle/frame_0006.jpg 4 79 | office/dslr/bottle/frame_0013.jpg 4 80 | office/dslr/bottle/frame_0007.jpg 4 81 | office/dslr/bottle/frame_0012.jpg 4 82 | office/dslr/bottle/frame_0009.jpg 4 83 | office/dslr/bottle/frame_0005.jpg 4 84 | office/dslr/bottle/frame_0010.jpg 4 85 | office/dslr/bottle/frame_0015.jpg 4 86 | office/dslr/calculator/frame_0004.jpg 5 87 | office/dslr/calculator/frame_0005.jpg 5 88 | office/dslr/calculator/frame_0010.jpg 5 89 | office/dslr/calculator/frame_0009.jpg 5 90 | office/dslr/calculator/frame_0008.jpg 5 91 | office/dslr/calculator/frame_0011.jpg 5 92 | office/dslr/calculator/frame_0006.jpg 5 93 | office/dslr/calculator/frame_0007.jpg 5 94 | office/dslr/calculator/frame_0012.jpg 5 95 | office/dslr/calculator/frame_0003.jpg 5 96 | 
office/dslr/calculator/frame_0001.jpg 5 97 | office/dslr/calculator/frame_0002.jpg 5 98 | office/dslr/desk_chair/frame_0013.jpg 6 99 | office/dslr/desk_chair/frame_0006.jpg 6 100 | office/dslr/desk_chair/frame_0005.jpg 6 101 | office/dslr/desk_chair/frame_0002.jpg 6 102 | office/dslr/desk_chair/frame_0009.jpg 6 103 | office/dslr/desk_chair/frame_0008.jpg 6 104 | office/dslr/desk_chair/frame_0004.jpg 6 105 | office/dslr/desk_chair/frame_0007.jpg 6 106 | office/dslr/desk_chair/frame_0003.jpg 6 107 | office/dslr/desk_chair/frame_0011.jpg 6 108 | office/dslr/desk_chair/frame_0012.jpg 6 109 | office/dslr/desk_chair/frame_0010.jpg 6 110 | office/dslr/desk_chair/frame_0001.jpg 6 111 | office/dslr/desk_lamp/frame_0007.jpg 7 112 | office/dslr/desk_lamp/frame_0011.jpg 7 113 | office/dslr/desk_lamp/frame_0010.jpg 7 114 | office/dslr/desk_lamp/frame_0009.jpg 7 115 | office/dslr/desk_lamp/frame_0014.jpg 7 116 | office/dslr/desk_lamp/frame_0013.jpg 7 117 | office/dslr/desk_lamp/frame_0001.jpg 7 118 | office/dslr/desk_lamp/frame_0012.jpg 7 119 | office/dslr/desk_lamp/frame_0003.jpg 7 120 | office/dslr/desk_lamp/frame_0008.jpg 7 121 | office/dslr/desk_lamp/frame_0006.jpg 7 122 | office/dslr/desk_lamp/frame_0005.jpg 7 123 | office/dslr/desk_lamp/frame_0004.jpg 7 124 | office/dslr/desk_lamp/frame_0002.jpg 7 125 | office/dslr/desktop_computer/frame_0010.jpg 8 126 | office/dslr/desktop_computer/frame_0005.jpg 8 127 | office/dslr/desktop_computer/frame_0008.jpg 8 128 | office/dslr/desktop_computer/frame_0004.jpg 8 129 | office/dslr/desktop_computer/frame_0011.jpg 8 130 | office/dslr/desktop_computer/frame_0002.jpg 8 131 | office/dslr/desktop_computer/frame_0001.jpg 8 132 | office/dslr/desktop_computer/frame_0014.jpg 8 133 | office/dslr/desktop_computer/frame_0013.jpg 8 134 | office/dslr/desktop_computer/frame_0009.jpg 8 135 | office/dslr/desktop_computer/frame_0015.jpg 8 136 | office/dslr/desktop_computer/frame_0007.jpg 8 137 | office/dslr/desktop_computer/frame_0012.jpg 8 138 | 
office/dslr/desktop_computer/frame_0003.jpg 8 139 | office/dslr/desktop_computer/frame_0006.jpg 8 140 | office/dslr/file_cabinet/frame_0014.jpg 9 141 | office/dslr/file_cabinet/frame_0003.jpg 9 142 | office/dslr/file_cabinet/frame_0015.jpg 9 143 | office/dslr/file_cabinet/frame_0008.jpg 9 144 | office/dslr/file_cabinet/frame_0011.jpg 9 145 | office/dslr/file_cabinet/frame_0010.jpg 9 146 | office/dslr/file_cabinet/frame_0002.jpg 9 147 | office/dslr/file_cabinet/frame_0006.jpg 9 148 | office/dslr/file_cabinet/frame_0012.jpg 9 149 | office/dslr/file_cabinet/frame_0013.jpg 9 150 | office/dslr/file_cabinet/frame_0007.jpg 9 151 | office/dslr/file_cabinet/frame_0004.jpg 9 152 | office/dslr/file_cabinet/frame_0001.jpg 9 153 | office/dslr/file_cabinet/frame_0005.jpg 9 154 | office/dslr/file_cabinet/frame_0009.jpg 9 155 | office/dslr/phone/frame_0011.jpg 10 156 | office/dslr/phone/frame_0002.jpg 10 157 | office/dslr/phone/frame_0010.jpg 10 158 | office/dslr/phone/frame_0001.jpg 10 159 | office/dslr/phone/frame_0003.jpg 10 160 | office/dslr/phone/frame_0013.jpg 10 161 | office/dslr/phone/frame_0007.jpg 10 162 | office/dslr/phone/frame_0006.jpg 10 163 | office/dslr/phone/frame_0012.jpg 10 164 | office/dslr/phone/frame_0005.jpg 10 165 | office/dslr/phone/frame_0004.jpg 10 166 | office/dslr/phone/frame_0008.jpg 10 167 | office/dslr/phone/frame_0009.jpg 10 168 | office/dslr/printer/frame_0004.jpg 11 169 | office/dslr/printer/frame_0010.jpg 11 170 | office/dslr/printer/frame_0005.jpg 11 171 | office/dslr/printer/frame_0001.jpg 11 172 | office/dslr/printer/frame_0012.jpg 11 173 | office/dslr/printer/frame_0003.jpg 11 174 | office/dslr/printer/frame_0009.jpg 11 175 | office/dslr/printer/frame_0011.jpg 11 176 | office/dslr/printer/frame_0007.jpg 11 177 | office/dslr/printer/frame_0006.jpg 11 178 | office/dslr/printer/frame_0013.jpg 11 179 | office/dslr/printer/frame_0002.jpg 11 180 | office/dslr/printer/frame_0008.jpg 11 181 | office/dslr/printer/frame_0015.jpg 11 182 | 
office/dslr/printer/frame_0014.jpg 11 183 | office/dslr/projector/frame_0012.jpg 12 184 | office/dslr/projector/frame_0020.jpg 12 185 | office/dslr/projector/frame_0005.jpg 12 186 | office/dslr/projector/frame_0017.jpg 12 187 | office/dslr/projector/frame_0003.jpg 12 188 | office/dslr/projector/frame_0022.jpg 12 189 | office/dslr/projector/frame_0009.jpg 12 190 | office/dslr/projector/frame_0019.jpg 12 191 | office/dslr/projector/frame_0007.jpg 12 192 | office/dslr/projector/frame_0016.jpg 12 193 | office/dslr/projector/frame_0018.jpg 12 194 | office/dslr/projector/frame_0001.jpg 12 195 | office/dslr/projector/frame_0010.jpg 12 196 | office/dslr/projector/frame_0023.jpg 12 197 | office/dslr/projector/frame_0014.jpg 12 198 | office/dslr/projector/frame_0006.jpg 12 199 | office/dslr/projector/frame_0008.jpg 12 200 | office/dslr/projector/frame_0013.jpg 12 201 | office/dslr/projector/frame_0002.jpg 12 202 | office/dslr/projector/frame_0021.jpg 12 203 | office/dslr/projector/frame_0015.jpg 12 204 | office/dslr/projector/frame_0004.jpg 12 205 | office/dslr/projector/frame_0011.jpg 12 206 | office/dslr/punchers/frame_0006.jpg 13 207 | office/dslr/punchers/frame_0017.jpg 13 208 | office/dslr/punchers/frame_0016.jpg 13 209 | office/dslr/punchers/frame_0013.jpg 13 210 | office/dslr/punchers/frame_0003.jpg 13 211 | office/dslr/punchers/frame_0018.jpg 13 212 | office/dslr/punchers/frame_0001.jpg 13 213 | office/dslr/punchers/frame_0010.jpg 13 214 | office/dslr/punchers/frame_0008.jpg 13 215 | office/dslr/punchers/frame_0015.jpg 13 216 | office/dslr/punchers/frame_0005.jpg 13 217 | office/dslr/punchers/frame_0014.jpg 13 218 | office/dslr/punchers/frame_0009.jpg 13 219 | office/dslr/punchers/frame_0004.jpg 13 220 | office/dslr/punchers/frame_0012.jpg 13 221 | office/dslr/punchers/frame_0007.jpg 13 222 | office/dslr/punchers/frame_0011.jpg 13 223 | office/dslr/punchers/frame_0002.jpg 13 224 | office/dslr/ring_binder/frame_0010.jpg 14 225 | office/dslr/ring_binder/frame_0003.jpg 
14 226 | office/dslr/ring_binder/frame_0009.jpg 14 227 | office/dslr/ring_binder/frame_0004.jpg 14 228 | office/dslr/ring_binder/frame_0001.jpg 14 229 | office/dslr/ring_binder/frame_0006.jpg 14 230 | office/dslr/ring_binder/frame_0005.jpg 14 231 | office/dslr/ring_binder/frame_0002.jpg 14 232 | office/dslr/ring_binder/frame_0007.jpg 14 233 | office/dslr/ring_binder/frame_0008.jpg 14 234 | office/dslr/ruler/frame_0006.jpg 15 235 | office/dslr/ruler/frame_0004.jpg 15 236 | office/dslr/ruler/frame_0001.jpg 15 237 | office/dslr/ruler/frame_0007.jpg 15 238 | office/dslr/ruler/frame_0002.jpg 15 239 | office/dslr/ruler/frame_0003.jpg 15 240 | office/dslr/ruler/frame_0005.jpg 15 241 | office/dslr/scissors/frame_0017.jpg 16 242 | office/dslr/scissors/frame_0013.jpg 16 243 | office/dslr/scissors/frame_0006.jpg 16 244 | office/dslr/scissors/frame_0015.jpg 16 245 | office/dslr/scissors/frame_0016.jpg 16 246 | office/dslr/scissors/frame_0012.jpg 16 247 | office/dslr/scissors/frame_0002.jpg 16 248 | office/dslr/scissors/frame_0004.jpg 16 249 | office/dslr/scissors/frame_0010.jpg 16 250 | office/dslr/scissors/frame_0007.jpg 16 251 | office/dslr/scissors/frame_0011.jpg 16 252 | office/dslr/scissors/frame_0001.jpg 16 253 | office/dslr/scissors/frame_0018.jpg 16 254 | office/dslr/scissors/frame_0005.jpg 16 255 | office/dslr/scissors/frame_0008.jpg 16 256 | office/dslr/scissors/frame_0014.jpg 16 257 | office/dslr/scissors/frame_0009.jpg 16 258 | office/dslr/scissors/frame_0003.jpg 16 259 | office/dslr/speaker/frame_0021.jpg 17 260 | office/dslr/speaker/frame_0019.jpg 17 261 | office/dslr/speaker/frame_0004.jpg 17 262 | office/dslr/speaker/frame_0016.jpg 17 263 | office/dslr/speaker/frame_0012.jpg 17 264 | office/dslr/speaker/frame_0009.jpg 17 265 | office/dslr/speaker/frame_0017.jpg 17 266 | office/dslr/speaker/frame_0018.jpg 17 267 | office/dslr/speaker/frame_0015.jpg 17 268 | office/dslr/speaker/frame_0025.jpg 17 269 | office/dslr/speaker/frame_0022.jpg 17 270 | 
office/dslr/speaker/frame_0024.jpg 17 271 | office/dslr/speaker/frame_0023.jpg 17 272 | office/dslr/speaker/frame_0006.jpg 17 273 | office/dslr/speaker/frame_0005.jpg 17 274 | office/dslr/speaker/frame_0020.jpg 17 275 | office/dslr/speaker/frame_0010.jpg 17 276 | office/dslr/speaker/frame_0011.jpg 17 277 | office/dslr/speaker/frame_0007.jpg 17 278 | office/dslr/speaker/frame_0026.jpg 17 279 | office/dslr/speaker/frame_0008.jpg 17 280 | office/dslr/speaker/frame_0002.jpg 17 281 | office/dslr/speaker/frame_0013.jpg 17 282 | office/dslr/speaker/frame_0001.jpg 17 283 | office/dslr/speaker/frame_0003.jpg 17 284 | office/dslr/speaker/frame_0014.jpg 17 285 | office/dslr/stapler/frame_0013.jpg 18 286 | office/dslr/stapler/frame_0010.jpg 18 287 | office/dslr/stapler/frame_0021.jpg 18 288 | office/dslr/stapler/frame_0006.jpg 18 289 | office/dslr/stapler/frame_0002.jpg 18 290 | office/dslr/stapler/frame_0003.jpg 18 291 | office/dslr/stapler/frame_0004.jpg 18 292 | office/dslr/stapler/frame_0011.jpg 18 293 | office/dslr/stapler/frame_0001.jpg 18 294 | office/dslr/stapler/frame_0008.jpg 18 295 | office/dslr/stapler/frame_0009.jpg 18 296 | office/dslr/stapler/frame_0015.jpg 18 297 | office/dslr/stapler/frame_0018.jpg 18 298 | office/dslr/stapler/frame_0016.jpg 18 299 | office/dslr/stapler/frame_0019.jpg 18 300 | office/dslr/stapler/frame_0012.jpg 18 301 | office/dslr/stapler/frame_0014.jpg 18 302 | office/dslr/stapler/frame_0007.jpg 18 303 | office/dslr/stapler/frame_0005.jpg 18 304 | office/dslr/stapler/frame_0020.jpg 18 305 | office/dslr/stapler/frame_0017.jpg 18 306 | office/dslr/tape_dispenser/frame_0005.jpg 19 307 | office/dslr/tape_dispenser/frame_0003.jpg 19 308 | office/dslr/tape_dispenser/frame_0012.jpg 19 309 | office/dslr/tape_dispenser/frame_0022.jpg 19 310 | office/dslr/tape_dispenser/frame_0018.jpg 19 311 | office/dslr/tape_dispenser/frame_0017.jpg 19 312 | office/dslr/tape_dispenser/frame_0021.jpg 19 313 | office/dslr/tape_dispenser/frame_0015.jpg 19 314 | 
office/dslr/tape_dispenser/frame_0011.jpg 19 315 | office/dslr/tape_dispenser/frame_0014.jpg 19 316 | office/dslr/tape_dispenser/frame_0013.jpg 19 317 | office/dslr/tape_dispenser/frame_0004.jpg 19 318 | office/dslr/tape_dispenser/frame_0010.jpg 19 319 | office/dslr/tape_dispenser/frame_0009.jpg 19 320 | office/dslr/tape_dispenser/frame_0001.jpg 19 321 | office/dslr/tape_dispenser/frame_0007.jpg 19 322 | office/dslr/tape_dispenser/frame_0008.jpg 19 323 | office/dslr/tape_dispenser/frame_0016.jpg 19 324 | office/dslr/tape_dispenser/frame_0006.jpg 19 325 | office/dslr/tape_dispenser/frame_0019.jpg 19 326 | office/dslr/tape_dispenser/frame_0002.jpg 19 327 | office/dslr/tape_dispenser/frame_0020.jpg 19 328 | office/dslr/trash_can/frame_0011.jpg 20 329 | office/dslr/trash_can/frame_0006.jpg 20 330 | office/dslr/trash_can/frame_0015.jpg 20 331 | office/dslr/trash_can/frame_0002.jpg 20 332 | office/dslr/trash_can/frame_0009.jpg 20 333 | office/dslr/trash_can/frame_0004.jpg 20 334 | office/dslr/trash_can/frame_0008.jpg 20 335 | office/dslr/trash_can/frame_0001.jpg 20 336 | office/dslr/trash_can/frame_0013.jpg 20 337 | office/dslr/trash_can/frame_0007.jpg 20 338 | office/dslr/trash_can/frame_0010.jpg 20 339 | office/dslr/trash_can/frame_0005.jpg 20 340 | office/dslr/trash_can/frame_0012.jpg 20 341 | office/dslr/trash_can/frame_0003.jpg 20 342 | office/dslr/trash_can/frame_0014.jpg 20 -------------------------------------------------------------------------------- /data/dslr_0-9_train_all.txt: -------------------------------------------------------------------------------- 1 | office/dslr/backpack/frame_0009.jpg 0 2 | office/dslr/backpack/frame_0003.jpg 0 3 | office/dslr/backpack/frame_0011.jpg 0 4 | office/dslr/backpack/frame_0001.jpg 0 5 | office/dslr/backpack/frame_0006.jpg 0 6 | office/dslr/backpack/frame_0005.jpg 0 7 | office/dslr/backpack/frame_0008.jpg 0 8 | office/dslr/backpack/frame_0004.jpg 0 9 | office/dslr/backpack/frame_0012.jpg 0 10 | 
office/dslr/backpack/frame_0002.jpg 0 11 | office/dslr/backpack/frame_0010.jpg 0 12 | office/dslr/backpack/frame_0007.jpg 0 13 | office/dslr/bike/frame_0005.jpg 1 14 | office/dslr/bike/frame_0001.jpg 1 15 | office/dslr/bike/frame_0018.jpg 1 16 | office/dslr/bike/frame_0004.jpg 1 17 | office/dslr/bike/frame_0014.jpg 1 18 | office/dslr/bike/frame_0012.jpg 1 19 | office/dslr/bike/frame_0002.jpg 1 20 | office/dslr/bike/frame_0009.jpg 1 21 | office/dslr/bike/frame_0003.jpg 1 22 | office/dslr/bike/frame_0008.jpg 1 23 | office/dslr/bike/frame_0015.jpg 1 24 | office/dslr/bike/frame_0007.jpg 1 25 | office/dslr/bike/frame_0016.jpg 1 26 | office/dslr/bike/frame_0019.jpg 1 27 | office/dslr/bike/frame_0006.jpg 1 28 | office/dslr/bike/frame_0017.jpg 1 29 | office/dslr/bike/frame_0013.jpg 1 30 | office/dslr/bike/frame_0021.jpg 1 31 | office/dslr/bike/frame_0020.jpg 1 32 | office/dslr/bike/frame_0011.jpg 1 33 | office/dslr/bike/frame_0010.jpg 1 34 | office/dslr/bike_helmet/frame_0001.jpg 2 35 | office/dslr/bike_helmet/frame_0007.jpg 2 36 | office/dslr/bike_helmet/frame_0021.jpg 2 37 | office/dslr/bike_helmet/frame_0010.jpg 2 38 | office/dslr/bike_helmet/frame_0023.jpg 2 39 | office/dslr/bike_helmet/frame_0017.jpg 2 40 | office/dslr/bike_helmet/frame_0003.jpg 2 41 | office/dslr/bike_helmet/frame_0019.jpg 2 42 | office/dslr/bike_helmet/frame_0016.jpg 2 43 | office/dslr/bike_helmet/frame_0012.jpg 2 44 | office/dslr/bike_helmet/frame_0015.jpg 2 45 | office/dslr/bike_helmet/frame_0011.jpg 2 46 | office/dslr/bike_helmet/frame_0009.jpg 2 47 | office/dslr/bike_helmet/frame_0020.jpg 2 48 | office/dslr/bike_helmet/frame_0006.jpg 2 49 | office/dslr/bike_helmet/frame_0005.jpg 2 50 | office/dslr/bike_helmet/frame_0008.jpg 2 51 | office/dslr/bike_helmet/frame_0002.jpg 2 52 | office/dslr/bike_helmet/frame_0013.jpg 2 53 | office/dslr/bike_helmet/frame_0022.jpg 2 54 | office/dslr/bike_helmet/frame_0018.jpg 2 55 | office/dslr/bike_helmet/frame_0024.jpg 2 56 | office/dslr/bike_helmet/frame_0004.jpg 
2 57 | office/dslr/bike_helmet/frame_0014.jpg 2 58 | office/dslr/bookcase/frame_0010.jpg 3 59 | office/dslr/bookcase/frame_0012.jpg 3 60 | office/dslr/bookcase/frame_0002.jpg 3 61 | office/dslr/bookcase/frame_0007.jpg 3 62 | office/dslr/bookcase/frame_0001.jpg 3 63 | office/dslr/bookcase/frame_0005.jpg 3 64 | office/dslr/bookcase/frame_0009.jpg 3 65 | office/dslr/bookcase/frame_0011.jpg 3 66 | office/dslr/bookcase/frame_0004.jpg 3 67 | office/dslr/bookcase/frame_0006.jpg 3 68 | office/dslr/bookcase/frame_0003.jpg 3 69 | office/dslr/bookcase/frame_0008.jpg 3 70 | office/dslr/bottle/frame_0011.jpg 4 71 | office/dslr/bottle/frame_0008.jpg 4 72 | office/dslr/bottle/frame_0001.jpg 4 73 | office/dslr/bottle/frame_0016.jpg 4 74 | office/dslr/bottle/frame_0002.jpg 4 75 | office/dslr/bottle/frame_0003.jpg 4 76 | office/dslr/bottle/frame_0004.jpg 4 77 | office/dslr/bottle/frame_0014.jpg 4 78 | office/dslr/bottle/frame_0006.jpg 4 79 | office/dslr/bottle/frame_0013.jpg 4 80 | office/dslr/bottle/frame_0007.jpg 4 81 | office/dslr/bottle/frame_0012.jpg 4 82 | office/dslr/bottle/frame_0009.jpg 4 83 | office/dslr/bottle/frame_0005.jpg 4 84 | office/dslr/bottle/frame_0010.jpg 4 85 | office/dslr/bottle/frame_0015.jpg 4 86 | office/dslr/calculator/frame_0004.jpg 5 87 | office/dslr/calculator/frame_0005.jpg 5 88 | office/dslr/calculator/frame_0010.jpg 5 89 | office/dslr/calculator/frame_0009.jpg 5 90 | office/dslr/calculator/frame_0008.jpg 5 91 | office/dslr/calculator/frame_0011.jpg 5 92 | office/dslr/calculator/frame_0006.jpg 5 93 | office/dslr/calculator/frame_0007.jpg 5 94 | office/dslr/calculator/frame_0012.jpg 5 95 | office/dslr/calculator/frame_0003.jpg 5 96 | office/dslr/calculator/frame_0001.jpg 5 97 | office/dslr/calculator/frame_0002.jpg 5 98 | office/dslr/desk_chair/frame_0013.jpg 6 99 | office/dslr/desk_chair/frame_0006.jpg 6 100 | office/dslr/desk_chair/frame_0005.jpg 6 101 | office/dslr/desk_chair/frame_0002.jpg 6 102 | office/dslr/desk_chair/frame_0009.jpg 6 103 | 
office/dslr/desk_chair/frame_0008.jpg 6 104 | office/dslr/desk_chair/frame_0004.jpg 6 105 | office/dslr/desk_chair/frame_0007.jpg 6 106 | office/dslr/desk_chair/frame_0003.jpg 6 107 | office/dslr/desk_chair/frame_0011.jpg 6 108 | office/dslr/desk_chair/frame_0012.jpg 6 109 | office/dslr/desk_chair/frame_0010.jpg 6 110 | office/dslr/desk_chair/frame_0001.jpg 6 111 | office/dslr/desk_lamp/frame_0007.jpg 7 112 | office/dslr/desk_lamp/frame_0011.jpg 7 113 | office/dslr/desk_lamp/frame_0010.jpg 7 114 | office/dslr/desk_lamp/frame_0009.jpg 7 115 | office/dslr/desk_lamp/frame_0014.jpg 7 116 | office/dslr/desk_lamp/frame_0013.jpg 7 117 | office/dslr/desk_lamp/frame_0001.jpg 7 118 | office/dslr/desk_lamp/frame_0012.jpg 7 119 | office/dslr/desk_lamp/frame_0003.jpg 7 120 | office/dslr/desk_lamp/frame_0008.jpg 7 121 | office/dslr/desk_lamp/frame_0006.jpg 7 122 | office/dslr/desk_lamp/frame_0005.jpg 7 123 | office/dslr/desk_lamp/frame_0004.jpg 7 124 | office/dslr/desk_lamp/frame_0002.jpg 7 125 | office/dslr/desktop_computer/frame_0010.jpg 8 126 | office/dslr/desktop_computer/frame_0005.jpg 8 127 | office/dslr/desktop_computer/frame_0008.jpg 8 128 | office/dslr/desktop_computer/frame_0004.jpg 8 129 | office/dslr/desktop_computer/frame_0011.jpg 8 130 | office/dslr/desktop_computer/frame_0002.jpg 8 131 | office/dslr/desktop_computer/frame_0001.jpg 8 132 | office/dslr/desktop_computer/frame_0014.jpg 8 133 | office/dslr/desktop_computer/frame_0013.jpg 8 134 | office/dslr/desktop_computer/frame_0009.jpg 8 135 | office/dslr/desktop_computer/frame_0015.jpg 8 136 | office/dslr/desktop_computer/frame_0007.jpg 8 137 | office/dslr/desktop_computer/frame_0012.jpg 8 138 | office/dslr/desktop_computer/frame_0003.jpg 8 139 | office/dslr/desktop_computer/frame_0006.jpg 8 140 | office/dslr/file_cabinet/frame_0014.jpg 9 141 | office/dslr/file_cabinet/frame_0003.jpg 9 142 | office/dslr/file_cabinet/frame_0015.jpg 9 143 | office/dslr/file_cabinet/frame_0008.jpg 9 144 | 
office/dslr/file_cabinet/frame_0011.jpg 9 145 | office/dslr/file_cabinet/frame_0010.jpg 9 146 | office/dslr/file_cabinet/frame_0002.jpg 9 147 | office/dslr/file_cabinet/frame_0006.jpg 9 148 | office/dslr/file_cabinet/frame_0012.jpg 9 149 | office/dslr/file_cabinet/frame_0013.jpg 9 150 | office/dslr/file_cabinet/frame_0007.jpg 9 151 | office/dslr/file_cabinet/frame_0004.jpg 9 152 | office/dslr/file_cabinet/frame_0001.jpg 9 153 | office/dslr/file_cabinet/frame_0005.jpg 9 154 | office/dslr/file_cabinet/frame_0009.jpg 9 -------------------------------------------------------------------------------- /data/webcam_0-9_20-30_test.txt: -------------------------------------------------------------------------------- 1 | office/webcam/backpack/frame_0019.jpg 0 2 | office/webcam/backpack/frame_0004.jpg 0 3 | office/webcam/backpack/frame_0007.jpg 0 4 | office/webcam/backpack/frame_0024.jpg 0 5 | office/webcam/backpack/frame_0023.jpg 0 6 | office/webcam/backpack/frame_0027.jpg 0 7 | office/webcam/backpack/frame_0003.jpg 0 8 | office/webcam/backpack/frame_0022.jpg 0 9 | office/webcam/backpack/frame_0005.jpg 0 10 | office/webcam/backpack/frame_0017.jpg 0 11 | office/webcam/backpack/frame_0029.jpg 0 12 | office/webcam/backpack/frame_0016.jpg 0 13 | office/webcam/backpack/frame_0021.jpg 0 14 | office/webcam/backpack/frame_0012.jpg 0 15 | office/webcam/backpack/frame_0018.jpg 0 16 | office/webcam/backpack/frame_0028.jpg 0 17 | office/webcam/backpack/frame_0014.jpg 0 18 | office/webcam/backpack/frame_0026.jpg 0 19 | office/webcam/backpack/frame_0025.jpg 0 20 | office/webcam/backpack/frame_0020.jpg 0 21 | office/webcam/backpack/frame_0008.jpg 0 22 | office/webcam/backpack/frame_0015.jpg 0 23 | office/webcam/backpack/frame_0010.jpg 0 24 | office/webcam/backpack/frame_0009.jpg 0 25 | office/webcam/backpack/frame_0001.jpg 0 26 | office/webcam/backpack/frame_0011.jpg 0 27 | office/webcam/backpack/frame_0002.jpg 0 28 | office/webcam/backpack/frame_0006.jpg 0 29 | 
office/webcam/backpack/frame_0013.jpg 0 30 | office/webcam/bike/frame_0007.jpg 1 31 | office/webcam/bike/frame_0016.jpg 1 32 | office/webcam/bike/frame_0006.jpg 1 33 | office/webcam/bike/frame_0002.jpg 1 34 | office/webcam/bike/frame_0012.jpg 1 35 | office/webcam/bike/frame_0019.jpg 1 36 | office/webcam/bike/frame_0020.jpg 1 37 | office/webcam/bike/frame_0001.jpg 1 38 | office/webcam/bike/frame_0014.jpg 1 39 | office/webcam/bike/frame_0015.jpg 1 40 | office/webcam/bike/frame_0011.jpg 1 41 | office/webcam/bike/frame_0004.jpg 1 42 | office/webcam/bike/frame_0010.jpg 1 43 | office/webcam/bike/frame_0018.jpg 1 44 | office/webcam/bike/frame_0009.jpg 1 45 | office/webcam/bike/frame_0005.jpg 1 46 | office/webcam/bike/frame_0021.jpg 1 47 | office/webcam/bike/frame_0017.jpg 1 48 | office/webcam/bike/frame_0013.jpg 1 49 | office/webcam/bike/frame_0008.jpg 1 50 | office/webcam/bike/frame_0003.jpg 1 51 | office/webcam/bike_helmet/frame_0012.jpg 2 52 | office/webcam/bike_helmet/frame_0013.jpg 2 53 | office/webcam/bike_helmet/frame_0019.jpg 2 54 | office/webcam/bike_helmet/frame_0006.jpg 2 55 | office/webcam/bike_helmet/frame_0003.jpg 2 56 | office/webcam/bike_helmet/frame_0022.jpg 2 57 | office/webcam/bike_helmet/frame_0008.jpg 2 58 | office/webcam/bike_helmet/frame_0015.jpg 2 59 | office/webcam/bike_helmet/frame_0026.jpg 2 60 | office/webcam/bike_helmet/frame_0024.jpg 2 61 | office/webcam/bike_helmet/frame_0023.jpg 2 62 | office/webcam/bike_helmet/frame_0025.jpg 2 63 | office/webcam/bike_helmet/frame_0001.jpg 2 64 | office/webcam/bike_helmet/frame_0027.jpg 2 65 | office/webcam/bike_helmet/frame_0009.jpg 2 66 | office/webcam/bike_helmet/frame_0016.jpg 2 67 | office/webcam/bike_helmet/frame_0010.jpg 2 68 | office/webcam/bike_helmet/frame_0014.jpg 2 69 | office/webcam/bike_helmet/frame_0017.jpg 2 70 | office/webcam/bike_helmet/frame_0018.jpg 2 71 | office/webcam/bike_helmet/frame_0002.jpg 2 72 | office/webcam/bike_helmet/frame_0011.jpg 2 73 | 
office/webcam/bike_helmet/frame_0007.jpg 2 74 | office/webcam/bike_helmet/frame_0005.jpg 2 75 | office/webcam/bike_helmet/frame_0020.jpg 2 76 | office/webcam/bike_helmet/frame_0028.jpg 2 77 | office/webcam/bike_helmet/frame_0021.jpg 2 78 | office/webcam/bike_helmet/frame_0004.jpg 2 79 | office/webcam/bookcase/frame_0010.jpg 3 80 | office/webcam/bookcase/frame_0003.jpg 3 81 | office/webcam/bookcase/frame_0007.jpg 3 82 | office/webcam/bookcase/frame_0008.jpg 3 83 | office/webcam/bookcase/frame_0011.jpg 3 84 | office/webcam/bookcase/frame_0004.jpg 3 85 | office/webcam/bookcase/frame_0012.jpg 3 86 | office/webcam/bookcase/frame_0009.jpg 3 87 | office/webcam/bookcase/frame_0006.jpg 3 88 | office/webcam/bookcase/frame_0002.jpg 3 89 | office/webcam/bookcase/frame_0001.jpg 3 90 | office/webcam/bookcase/frame_0005.jpg 3 91 | office/webcam/bottle/frame_0012.jpg 4 92 | office/webcam/bottle/frame_0002.jpg 4 93 | office/webcam/bottle/frame_0006.jpg 4 94 | office/webcam/bottle/frame_0008.jpg 4 95 | office/webcam/bottle/frame_0013.jpg 4 96 | office/webcam/bottle/frame_0009.jpg 4 97 | office/webcam/bottle/frame_0010.jpg 4 98 | office/webcam/bottle/frame_0016.jpg 4 99 | office/webcam/bottle/frame_0003.jpg 4 100 | office/webcam/bottle/frame_0004.jpg 4 101 | office/webcam/bottle/frame_0007.jpg 4 102 | office/webcam/bottle/frame_0014.jpg 4 103 | office/webcam/bottle/frame_0001.jpg 4 104 | office/webcam/bottle/frame_0015.jpg 4 105 | office/webcam/bottle/frame_0005.jpg 4 106 | office/webcam/bottle/frame_0011.jpg 4 107 | office/webcam/calculator/frame_0022.jpg 5 108 | office/webcam/calculator/frame_0017.jpg 5 109 | office/webcam/calculator/frame_0020.jpg 5 110 | office/webcam/calculator/frame_0029.jpg 5 111 | office/webcam/calculator/frame_0025.jpg 5 112 | office/webcam/calculator/frame_0024.jpg 5 113 | office/webcam/calculator/frame_0023.jpg 5 114 | office/webcam/calculator/frame_0013.jpg 5 115 | office/webcam/calculator/frame_0011.jpg 5 116 | office/webcam/calculator/frame_0007.jpg 5 
117 | office/webcam/calculator/frame_0030.jpg 5 118 | office/webcam/calculator/frame_0015.jpg 5 119 | office/webcam/calculator/frame_0014.jpg 5 120 | office/webcam/calculator/frame_0003.jpg 5 121 | office/webcam/calculator/frame_0006.jpg 5 122 | office/webcam/calculator/frame_0018.jpg 5 123 | office/webcam/calculator/frame_0004.jpg 5 124 | office/webcam/calculator/frame_0010.jpg 5 125 | office/webcam/calculator/frame_0016.jpg 5 126 | office/webcam/calculator/frame_0005.jpg 5 127 | office/webcam/calculator/frame_0002.jpg 5 128 | office/webcam/calculator/frame_0026.jpg 5 129 | office/webcam/calculator/frame_0012.jpg 5 130 | office/webcam/calculator/frame_0001.jpg 5 131 | office/webcam/calculator/frame_0008.jpg 5 132 | office/webcam/calculator/frame_0009.jpg 5 133 | office/webcam/calculator/frame_0021.jpg 5 134 | office/webcam/calculator/frame_0027.jpg 5 135 | office/webcam/calculator/frame_0028.jpg 5 136 | office/webcam/calculator/frame_0031.jpg 5 137 | office/webcam/calculator/frame_0019.jpg 5 138 | office/webcam/desk_chair/frame_0008.jpg 6 139 | office/webcam/desk_chair/frame_0033.jpg 6 140 | office/webcam/desk_chair/frame_0007.jpg 6 141 | office/webcam/desk_chair/frame_0036.jpg 6 142 | office/webcam/desk_chair/frame_0013.jpg 6 143 | office/webcam/desk_chair/frame_0023.jpg 6 144 | office/webcam/desk_chair/frame_0017.jpg 6 145 | office/webcam/desk_chair/frame_0028.jpg 6 146 | office/webcam/desk_chair/frame_0011.jpg 6 147 | office/webcam/desk_chair/frame_0021.jpg 6 148 | office/webcam/desk_chair/frame_0005.jpg 6 149 | office/webcam/desk_chair/frame_0024.jpg 6 150 | office/webcam/desk_chair/frame_0004.jpg 6 151 | office/webcam/desk_chair/frame_0034.jpg 6 152 | office/webcam/desk_chair/frame_0038.jpg 6 153 | office/webcam/desk_chair/frame_0030.jpg 6 154 | office/webcam/desk_chair/frame_0003.jpg 6 155 | office/webcam/desk_chair/frame_0010.jpg 6 156 | office/webcam/desk_chair/frame_0001.jpg 6 157 | office/webcam/desk_chair/frame_0031.jpg 6 158 | 
office/webcam/desk_chair/frame_0022.jpg 6 159 | office/webcam/desk_chair/frame_0015.jpg 6 160 | office/webcam/desk_chair/frame_0029.jpg 6 161 | office/webcam/desk_chair/frame_0012.jpg 6 162 | office/webcam/desk_chair/frame_0016.jpg 6 163 | office/webcam/desk_chair/frame_0039.jpg 6 164 | office/webcam/desk_chair/frame_0002.jpg 6 165 | office/webcam/desk_chair/frame_0009.jpg 6 166 | office/webcam/desk_chair/frame_0037.jpg 6 167 | office/webcam/desk_chair/frame_0025.jpg 6 168 | office/webcam/desk_chair/frame_0014.jpg 6 169 | office/webcam/desk_chair/frame_0020.jpg 6 170 | office/webcam/desk_chair/frame_0027.jpg 6 171 | office/webcam/desk_chair/frame_0032.jpg 6 172 | office/webcam/desk_chair/frame_0035.jpg 6 173 | office/webcam/desk_chair/frame_0018.jpg 6 174 | office/webcam/desk_chair/frame_0006.jpg 6 175 | office/webcam/desk_chair/frame_0019.jpg 6 176 | office/webcam/desk_chair/frame_0040.jpg 6 177 | office/webcam/desk_chair/frame_0026.jpg 6 178 | office/webcam/desk_lamp/frame_0013.jpg 7 179 | office/webcam/desk_lamp/frame_0007.jpg 7 180 | office/webcam/desk_lamp/frame_0017.jpg 7 181 | office/webcam/desk_lamp/frame_0014.jpg 7 182 | office/webcam/desk_lamp/frame_0001.jpg 7 183 | office/webcam/desk_lamp/frame_0003.jpg 7 184 | office/webcam/desk_lamp/frame_0002.jpg 7 185 | office/webcam/desk_lamp/frame_0015.jpg 7 186 | office/webcam/desk_lamp/frame_0010.jpg 7 187 | office/webcam/desk_lamp/frame_0004.jpg 7 188 | office/webcam/desk_lamp/frame_0008.jpg 7 189 | office/webcam/desk_lamp/frame_0016.jpg 7 190 | office/webcam/desk_lamp/frame_0009.jpg 7 191 | office/webcam/desk_lamp/frame_0006.jpg 7 192 | office/webcam/desk_lamp/frame_0012.jpg 7 193 | office/webcam/desk_lamp/frame_0011.jpg 7 194 | office/webcam/desk_lamp/frame_0005.jpg 7 195 | office/webcam/desk_lamp/frame_0018.jpg 7 196 | office/webcam/desktop_computer/frame_0005.jpg 8 197 | office/webcam/desktop_computer/frame_0007.jpg 8 198 | office/webcam/desktop_computer/frame_0003.jpg 8 199 | 
office/webcam/desktop_computer/frame_0010.jpg 8 200 | office/webcam/desktop_computer/frame_0021.jpg 8 201 | office/webcam/desktop_computer/frame_0019.jpg 8 202 | office/webcam/desktop_computer/frame_0012.jpg 8 203 | office/webcam/desktop_computer/frame_0002.jpg 8 204 | office/webcam/desktop_computer/frame_0017.jpg 8 205 | office/webcam/desktop_computer/frame_0008.jpg 8 206 | office/webcam/desktop_computer/frame_0018.jpg 8 207 | office/webcam/desktop_computer/frame_0004.jpg 8 208 | office/webcam/desktop_computer/frame_0006.jpg 8 209 | office/webcam/desktop_computer/frame_0016.jpg 8 210 | office/webcam/desktop_computer/frame_0014.jpg 8 211 | office/webcam/desktop_computer/frame_0009.jpg 8 212 | office/webcam/desktop_computer/frame_0020.jpg 8 213 | office/webcam/desktop_computer/frame_0015.jpg 8 214 | office/webcam/desktop_computer/frame_0001.jpg 8 215 | office/webcam/desktop_computer/frame_0013.jpg 8 216 | office/webcam/desktop_computer/frame_0011.jpg 8 217 | office/webcam/file_cabinet/frame_0018.jpg 9 218 | office/webcam/file_cabinet/frame_0003.jpg 9 219 | office/webcam/file_cabinet/frame_0005.jpg 9 220 | office/webcam/file_cabinet/frame_0001.jpg 9 221 | office/webcam/file_cabinet/frame_0010.jpg 9 222 | office/webcam/file_cabinet/frame_0014.jpg 9 223 | office/webcam/file_cabinet/frame_0008.jpg 9 224 | office/webcam/file_cabinet/frame_0019.jpg 9 225 | office/webcam/file_cabinet/frame_0007.jpg 9 226 | office/webcam/file_cabinet/frame_0009.jpg 9 227 | office/webcam/file_cabinet/frame_0017.jpg 9 228 | office/webcam/file_cabinet/frame_0016.jpg 9 229 | office/webcam/file_cabinet/frame_0012.jpg 9 230 | office/webcam/file_cabinet/frame_0013.jpg 9 231 | office/webcam/file_cabinet/frame_0015.jpg 9 232 | office/webcam/file_cabinet/frame_0006.jpg 9 233 | office/webcam/file_cabinet/frame_0004.jpg 9 234 | office/webcam/file_cabinet/frame_0011.jpg 9 235 | office/webcam/file_cabinet/frame_0002.jpg 9 236 | office/webcam/phone/frame_0006.jpg 10 237 | 
office/webcam/phone/frame_0015.jpg 10 238 | office/webcam/phone/frame_0014.jpg 10 239 | office/webcam/phone/frame_0005.jpg 10 240 | office/webcam/phone/frame_0013.jpg 10 241 | office/webcam/phone/frame_0010.jpg 10 242 | office/webcam/phone/frame_0007.jpg 10 243 | office/webcam/phone/frame_0008.jpg 10 244 | office/webcam/phone/frame_0016.jpg 10 245 | office/webcam/phone/frame_0004.jpg 10 246 | office/webcam/phone/frame_0012.jpg 10 247 | office/webcam/phone/frame_0003.jpg 10 248 | office/webcam/phone/frame_0002.jpg 10 249 | office/webcam/phone/frame_0001.jpg 10 250 | office/webcam/phone/frame_0011.jpg 10 251 | office/webcam/phone/frame_0009.jpg 10 252 | office/webcam/printer/frame_0018.jpg 11 253 | office/webcam/printer/frame_0012.jpg 11 254 | office/webcam/printer/frame_0011.jpg 11 255 | office/webcam/printer/frame_0006.jpg 11 256 | office/webcam/printer/frame_0015.jpg 11 257 | office/webcam/printer/frame_0007.jpg 11 258 | office/webcam/printer/frame_0016.jpg 11 259 | office/webcam/printer/frame_0003.jpg 11 260 | office/webcam/printer/frame_0010.jpg 11 261 | office/webcam/printer/frame_0002.jpg 11 262 | office/webcam/printer/frame_0001.jpg 11 263 | office/webcam/printer/frame_0020.jpg 11 264 | office/webcam/printer/frame_0005.jpg 11 265 | office/webcam/printer/frame_0019.jpg 11 266 | office/webcam/printer/frame_0013.jpg 11 267 | office/webcam/printer/frame_0017.jpg 11 268 | office/webcam/printer/frame_0004.jpg 11 269 | office/webcam/printer/frame_0009.jpg 11 270 | office/webcam/printer/frame_0014.jpg 11 271 | office/webcam/printer/frame_0008.jpg 11 272 | office/webcam/projector/frame_0026.jpg 12 273 | office/webcam/projector/frame_0008.jpg 12 274 | office/webcam/projector/frame_0027.jpg 12 275 | office/webcam/projector/frame_0024.jpg 12 276 | office/webcam/projector/frame_0005.jpg 12 277 | office/webcam/projector/frame_0002.jpg 12 278 | office/webcam/projector/frame_0001.jpg 12 279 | office/webcam/projector/frame_0007.jpg 12 280 | 
office/webcam/projector/frame_0018.jpg 12 281 | office/webcam/projector/frame_0013.jpg 12 282 | office/webcam/projector/frame_0006.jpg 12 283 | office/webcam/projector/frame_0003.jpg 12 284 | office/webcam/projector/frame_0011.jpg 12 285 | office/webcam/projector/frame_0012.jpg 12 286 | office/webcam/projector/frame_0023.jpg 12 287 | office/webcam/projector/frame_0017.jpg 12 288 | office/webcam/projector/frame_0015.jpg 12 289 | office/webcam/projector/frame_0014.jpg 12 290 | office/webcam/projector/frame_0019.jpg 12 291 | office/webcam/projector/frame_0020.jpg 12 292 | office/webcam/projector/frame_0030.jpg 12 293 | office/webcam/projector/frame_0021.jpg 12 294 | office/webcam/projector/frame_0028.jpg 12 295 | office/webcam/projector/frame_0010.jpg 12 296 | office/webcam/projector/frame_0016.jpg 12 297 | office/webcam/projector/frame_0009.jpg 12 298 | office/webcam/projector/frame_0025.jpg 12 299 | office/webcam/projector/frame_0004.jpg 12 300 | office/webcam/projector/frame_0029.jpg 12 301 | office/webcam/projector/frame_0022.jpg 12 302 | office/webcam/punchers/frame_0006.jpg 13 303 | office/webcam/punchers/frame_0012.jpg 13 304 | office/webcam/punchers/frame_0002.jpg 13 305 | office/webcam/punchers/frame_0004.jpg 13 306 | office/webcam/punchers/frame_0010.jpg 13 307 | office/webcam/punchers/frame_0020.jpg 13 308 | office/webcam/punchers/frame_0026.jpg 13 309 | office/webcam/punchers/frame_0003.jpg 13 310 | office/webcam/punchers/frame_0018.jpg 13 311 | office/webcam/punchers/frame_0001.jpg 13 312 | office/webcam/punchers/frame_0011.jpg 13 313 | office/webcam/punchers/frame_0019.jpg 13 314 | office/webcam/punchers/frame_0025.jpg 13 315 | office/webcam/punchers/frame_0017.jpg 13 316 | office/webcam/punchers/frame_0008.jpg 13 317 | office/webcam/punchers/frame_0016.jpg 13 318 | office/webcam/punchers/frame_0023.jpg 13 319 | office/webcam/punchers/frame_0007.jpg 13 320 | office/webcam/punchers/frame_0013.jpg 13 321 | office/webcam/punchers/frame_0021.jpg 13 322 | 
office/webcam/punchers/frame_0027.jpg 13 323 | office/webcam/punchers/frame_0015.jpg 13 324 | office/webcam/punchers/frame_0022.jpg 13 325 | office/webcam/punchers/frame_0014.jpg 13 326 | office/webcam/punchers/frame_0005.jpg 13 327 | office/webcam/punchers/frame_0009.jpg 13 328 | office/webcam/punchers/frame_0024.jpg 13 329 | office/webcam/ring_binder/frame_0022.jpg 14 330 | office/webcam/ring_binder/frame_0025.jpg 14 331 | office/webcam/ring_binder/frame_0034.jpg 14 332 | office/webcam/ring_binder/frame_0021.jpg 14 333 | office/webcam/ring_binder/frame_0009.jpg 14 334 | office/webcam/ring_binder/frame_0003.jpg 14 335 | office/webcam/ring_binder/frame_0017.jpg 14 336 | office/webcam/ring_binder/frame_0033.jpg 14 337 | office/webcam/ring_binder/frame_0015.jpg 14 338 | office/webcam/ring_binder/frame_0035.jpg 14 339 | office/webcam/ring_binder/frame_0031.jpg 14 340 | office/webcam/ring_binder/frame_0004.jpg 14 341 | office/webcam/ring_binder/frame_0030.jpg 14 342 | office/webcam/ring_binder/frame_0020.jpg 14 343 | office/webcam/ring_binder/frame_0016.jpg 14 344 | office/webcam/ring_binder/frame_0013.jpg 14 345 | office/webcam/ring_binder/frame_0037.jpg 14 346 | office/webcam/ring_binder/frame_0006.jpg 14 347 | office/webcam/ring_binder/frame_0038.jpg 14 348 | office/webcam/ring_binder/frame_0008.jpg 14 349 | office/webcam/ring_binder/frame_0018.jpg 14 350 | office/webcam/ring_binder/frame_0011.jpg 14 351 | office/webcam/ring_binder/frame_0014.jpg 14 352 | office/webcam/ring_binder/frame_0007.jpg 14 353 | office/webcam/ring_binder/frame_0039.jpg 14 354 | office/webcam/ring_binder/frame_0023.jpg 14 355 | office/webcam/ring_binder/frame_0010.jpg 14 356 | office/webcam/ring_binder/frame_0029.jpg 14 357 | office/webcam/ring_binder/frame_0028.jpg 14 358 | office/webcam/ring_binder/frame_0005.jpg 14 359 | office/webcam/ring_binder/frame_0019.jpg 14 360 | office/webcam/ring_binder/frame_0026.jpg 14 361 | office/webcam/ring_binder/frame_0001.jpg 14 362 | 
office/webcam/ring_binder/frame_0024.jpg 14 363 | office/webcam/ring_binder/frame_0036.jpg 14 364 | office/webcam/ring_binder/frame_0012.jpg 14 365 | office/webcam/ring_binder/frame_0002.jpg 14 366 | office/webcam/ring_binder/frame_0032.jpg 14 367 | office/webcam/ring_binder/frame_0040.jpg 14 368 | office/webcam/ring_binder/frame_0027.jpg 14 369 | office/webcam/ruler/frame_0002.jpg 15 370 | office/webcam/ruler/frame_0011.jpg 15 371 | office/webcam/ruler/frame_0003.jpg 15 372 | office/webcam/ruler/frame_0008.jpg 15 373 | office/webcam/ruler/frame_0001.jpg 15 374 | office/webcam/ruler/frame_0009.jpg 15 375 | office/webcam/ruler/frame_0007.jpg 15 376 | office/webcam/ruler/frame_0010.jpg 15 377 | office/webcam/ruler/frame_0006.jpg 15 378 | office/webcam/ruler/frame_0004.jpg 15 379 | office/webcam/ruler/frame_0005.jpg 15 380 | office/webcam/scissors/frame_0008.jpg 16 381 | office/webcam/scissors/frame_0004.jpg 16 382 | office/webcam/scissors/frame_0016.jpg 16 383 | office/webcam/scissors/frame_0006.jpg 16 384 | office/webcam/scissors/frame_0019.jpg 16 385 | office/webcam/scissors/frame_0021.jpg 16 386 | office/webcam/scissors/frame_0003.jpg 16 387 | office/webcam/scissors/frame_0011.jpg 16 388 | office/webcam/scissors/frame_0024.jpg 16 389 | office/webcam/scissors/frame_0012.jpg 16 390 | office/webcam/scissors/frame_0005.jpg 16 391 | office/webcam/scissors/frame_0007.jpg 16 392 | office/webcam/scissors/frame_0025.jpg 16 393 | office/webcam/scissors/frame_0001.jpg 16 394 | office/webcam/scissors/frame_0002.jpg 16 395 | office/webcam/scissors/frame_0009.jpg 16 396 | office/webcam/scissors/frame_0023.jpg 16 397 | office/webcam/scissors/frame_0020.jpg 16 398 | office/webcam/scissors/frame_0013.jpg 16 399 | office/webcam/scissors/frame_0018.jpg 16 400 | office/webcam/scissors/frame_0017.jpg 16 401 | office/webcam/scissors/frame_0022.jpg 16 402 | office/webcam/scissors/frame_0014.jpg 16 403 | office/webcam/scissors/frame_0015.jpg 16 404 | office/webcam/scissors/frame_0010.jpg 
16 405 | office/webcam/speaker/frame_0026.jpg 17 406 | office/webcam/speaker/frame_0010.jpg 17 407 | office/webcam/speaker/frame_0022.jpg 17 408 | office/webcam/speaker/frame_0028.jpg 17 409 | office/webcam/speaker/frame_0008.jpg 17 410 | office/webcam/speaker/frame_0019.jpg 17 411 | office/webcam/speaker/frame_0004.jpg 17 412 | office/webcam/speaker/frame_0014.jpg 17 413 | office/webcam/speaker/frame_0011.jpg 17 414 | office/webcam/speaker/frame_0024.jpg 17 415 | office/webcam/speaker/frame_0012.jpg 17 416 | office/webcam/speaker/frame_0029.jpg 17 417 | office/webcam/speaker/frame_0020.jpg 17 418 | office/webcam/speaker/frame_0009.jpg 17 419 | office/webcam/speaker/frame_0005.jpg 17 420 | office/webcam/speaker/frame_0018.jpg 17 421 | office/webcam/speaker/frame_0023.jpg 17 422 | office/webcam/speaker/frame_0006.jpg 17 423 | office/webcam/speaker/frame_0013.jpg 17 424 | office/webcam/speaker/frame_0030.jpg 17 425 | office/webcam/speaker/frame_0007.jpg 17 426 | office/webcam/speaker/frame_0021.jpg 17 427 | office/webcam/speaker/frame_0025.jpg 17 428 | office/webcam/speaker/frame_0001.jpg 17 429 | office/webcam/speaker/frame_0015.jpg 17 430 | office/webcam/speaker/frame_0016.jpg 17 431 | office/webcam/speaker/frame_0017.jpg 17 432 | office/webcam/speaker/frame_0003.jpg 17 433 | office/webcam/speaker/frame_0002.jpg 17 434 | office/webcam/speaker/frame_0027.jpg 17 435 | office/webcam/stapler/frame_0022.jpg 18 436 | office/webcam/stapler/frame_0001.jpg 18 437 | office/webcam/stapler/frame_0002.jpg 18 438 | office/webcam/stapler/frame_0016.jpg 18 439 | office/webcam/stapler/frame_0010.jpg 18 440 | office/webcam/stapler/frame_0013.jpg 18 441 | office/webcam/stapler/frame_0005.jpg 18 442 | office/webcam/stapler/frame_0009.jpg 18 443 | office/webcam/stapler/frame_0018.jpg 18 444 | office/webcam/stapler/frame_0006.jpg 18 445 | office/webcam/stapler/frame_0004.jpg 18 446 | office/webcam/stapler/frame_0012.jpg 18 447 | office/webcam/stapler/frame_0024.jpg 18 448 | 
office/webcam/stapler/frame_0003.jpg 18 449 | office/webcam/stapler/frame_0020.jpg 18 450 | office/webcam/stapler/frame_0019.jpg 18 451 | office/webcam/stapler/frame_0023.jpg 18 452 | office/webcam/stapler/frame_0014.jpg 18 453 | office/webcam/stapler/frame_0017.jpg 18 454 | office/webcam/stapler/frame_0011.jpg 18 455 | office/webcam/stapler/frame_0021.jpg 18 456 | office/webcam/stapler/frame_0015.jpg 18 457 | office/webcam/stapler/frame_0007.jpg 18 458 | office/webcam/stapler/frame_0008.jpg 18 459 | office/webcam/tape_dispenser/frame_0018.jpg 19 460 | office/webcam/tape_dispenser/frame_0005.jpg 19 461 | office/webcam/tape_dispenser/frame_0010.jpg 19 462 | office/webcam/tape_dispenser/frame_0001.jpg 19 463 | office/webcam/tape_dispenser/frame_0009.jpg 19 464 | office/webcam/tape_dispenser/frame_0004.jpg 19 465 | office/webcam/tape_dispenser/frame_0022.jpg 19 466 | office/webcam/tape_dispenser/frame_0006.jpg 19 467 | office/webcam/tape_dispenser/frame_0007.jpg 19 468 | office/webcam/tape_dispenser/frame_0015.jpg 19 469 | office/webcam/tape_dispenser/frame_0013.jpg 19 470 | office/webcam/tape_dispenser/frame_0020.jpg 19 471 | office/webcam/tape_dispenser/frame_0003.jpg 19 472 | office/webcam/tape_dispenser/frame_0008.jpg 19 473 | office/webcam/tape_dispenser/frame_0021.jpg 19 474 | office/webcam/tape_dispenser/frame_0019.jpg 19 475 | office/webcam/tape_dispenser/frame_0014.jpg 19 476 | office/webcam/tape_dispenser/frame_0002.jpg 19 477 | office/webcam/tape_dispenser/frame_0011.jpg 19 478 | office/webcam/tape_dispenser/frame_0016.jpg 19 479 | office/webcam/tape_dispenser/frame_0012.jpg 19 480 | office/webcam/tape_dispenser/frame_0023.jpg 19 481 | office/webcam/tape_dispenser/frame_0017.jpg 19 482 | office/webcam/trash_can/frame_0015.jpg 20 483 | office/webcam/trash_can/frame_0017.jpg 20 484 | office/webcam/trash_can/frame_0018.jpg 20 485 | office/webcam/trash_can/frame_0006.jpg 20 486 | office/webcam/trash_can/frame_0007.jpg 20 487 | 
office/webcam/trash_can/frame_0002.jpg 20 488 | office/webcam/trash_can/frame_0010.jpg 20 489 | office/webcam/trash_can/frame_0005.jpg 20 490 | office/webcam/trash_can/frame_0009.jpg 20 491 | office/webcam/trash_can/frame_0014.jpg 20 492 | office/webcam/trash_can/frame_0013.jpg 20 493 | office/webcam/trash_can/frame_0004.jpg 20 494 | office/webcam/trash_can/frame_0003.jpg 20 495 | office/webcam/trash_can/frame_0011.jpg 20 496 | office/webcam/trash_can/frame_0021.jpg 20 497 | office/webcam/trash_can/frame_0016.jpg 20 498 | office/webcam/trash_can/frame_0001.jpg 20 499 | office/webcam/trash_can/frame_0008.jpg 20 500 | office/webcam/trash_can/frame_0020.jpg 20 501 | office/webcam/trash_can/frame_0012.jpg 20 502 | office/webcam/trash_can/frame_0019.jpg 20 -------------------------------------------------------------------------------- /data/webcam_0-9_train_all.txt: -------------------------------------------------------------------------------- 1 | office/webcam/backpack/frame_0019.jpg 0 2 | office/webcam/backpack/frame_0004.jpg 0 3 | office/webcam/backpack/frame_0007.jpg 0 4 | office/webcam/backpack/frame_0024.jpg 0 5 | office/webcam/backpack/frame_0023.jpg 0 6 | office/webcam/backpack/frame_0027.jpg 0 7 | office/webcam/backpack/frame_0003.jpg 0 8 | office/webcam/backpack/frame_0022.jpg 0 9 | office/webcam/backpack/frame_0005.jpg 0 10 | office/webcam/backpack/frame_0017.jpg 0 11 | office/webcam/backpack/frame_0029.jpg 0 12 | office/webcam/backpack/frame_0016.jpg 0 13 | office/webcam/backpack/frame_0021.jpg 0 14 | office/webcam/backpack/frame_0012.jpg 0 15 | office/webcam/backpack/frame_0018.jpg 0 16 | office/webcam/backpack/frame_0028.jpg 0 17 | office/webcam/backpack/frame_0014.jpg 0 18 | office/webcam/backpack/frame_0026.jpg 0 19 | office/webcam/backpack/frame_0025.jpg 0 20 | office/webcam/backpack/frame_0020.jpg 0 21 | office/webcam/backpack/frame_0008.jpg 0 22 | office/webcam/backpack/frame_0015.jpg 0 23 | office/webcam/backpack/frame_0010.jpg 0 24 | 
office/webcam/backpack/frame_0009.jpg 0 25 | office/webcam/backpack/frame_0001.jpg 0 26 | office/webcam/backpack/frame_0011.jpg 0 27 | office/webcam/backpack/frame_0002.jpg 0 28 | office/webcam/backpack/frame_0006.jpg 0 29 | office/webcam/backpack/frame_0013.jpg 0 30 | office/webcam/bike/frame_0007.jpg 1 31 | office/webcam/bike/frame_0016.jpg 1 32 | office/webcam/bike/frame_0006.jpg 1 33 | office/webcam/bike/frame_0002.jpg 1 34 | office/webcam/bike/frame_0012.jpg 1 35 | office/webcam/bike/frame_0019.jpg 1 36 | office/webcam/bike/frame_0020.jpg 1 37 | office/webcam/bike/frame_0001.jpg 1 38 | office/webcam/bike/frame_0014.jpg 1 39 | office/webcam/bike/frame_0015.jpg 1 40 | office/webcam/bike/frame_0011.jpg 1 41 | office/webcam/bike/frame_0004.jpg 1 42 | office/webcam/bike/frame_0010.jpg 1 43 | office/webcam/bike/frame_0018.jpg 1 44 | office/webcam/bike/frame_0009.jpg 1 45 | office/webcam/bike/frame_0005.jpg 1 46 | office/webcam/bike/frame_0021.jpg 1 47 | office/webcam/bike/frame_0017.jpg 1 48 | office/webcam/bike/frame_0013.jpg 1 49 | office/webcam/bike/frame_0008.jpg 1 50 | office/webcam/bike/frame_0003.jpg 1 51 | office/webcam/bike_helmet/frame_0012.jpg 2 52 | office/webcam/bike_helmet/frame_0013.jpg 2 53 | office/webcam/bike_helmet/frame_0019.jpg 2 54 | office/webcam/bike_helmet/frame_0006.jpg 2 55 | office/webcam/bike_helmet/frame_0003.jpg 2 56 | office/webcam/bike_helmet/frame_0022.jpg 2 57 | office/webcam/bike_helmet/frame_0008.jpg 2 58 | office/webcam/bike_helmet/frame_0015.jpg 2 59 | office/webcam/bike_helmet/frame_0026.jpg 2 60 | office/webcam/bike_helmet/frame_0024.jpg 2 61 | office/webcam/bike_helmet/frame_0023.jpg 2 62 | office/webcam/bike_helmet/frame_0025.jpg 2 63 | office/webcam/bike_helmet/frame_0001.jpg 2 64 | office/webcam/bike_helmet/frame_0027.jpg 2 65 | office/webcam/bike_helmet/frame_0009.jpg 2 66 | office/webcam/bike_helmet/frame_0016.jpg 2 67 | office/webcam/bike_helmet/frame_0010.jpg 2 68 | office/webcam/bike_helmet/frame_0014.jpg 2 69 | 
office/webcam/bike_helmet/frame_0017.jpg 2 70 | office/webcam/bike_helmet/frame_0018.jpg 2 71 | office/webcam/bike_helmet/frame_0002.jpg 2 72 | office/webcam/bike_helmet/frame_0011.jpg 2 73 | office/webcam/bike_helmet/frame_0007.jpg 2 74 | office/webcam/bike_helmet/frame_0005.jpg 2 75 | office/webcam/bike_helmet/frame_0020.jpg 2 76 | office/webcam/bike_helmet/frame_0028.jpg 2 77 | office/webcam/bike_helmet/frame_0021.jpg 2 78 | office/webcam/bike_helmet/frame_0004.jpg 2 79 | office/webcam/bookcase/frame_0010.jpg 3 80 | office/webcam/bookcase/frame_0003.jpg 3 81 | office/webcam/bookcase/frame_0007.jpg 3 82 | office/webcam/bookcase/frame_0008.jpg 3 83 | office/webcam/bookcase/frame_0011.jpg 3 84 | office/webcam/bookcase/frame_0004.jpg 3 85 | office/webcam/bookcase/frame_0012.jpg 3 86 | office/webcam/bookcase/frame_0009.jpg 3 87 | office/webcam/bookcase/frame_0006.jpg 3 88 | office/webcam/bookcase/frame_0002.jpg 3 89 | office/webcam/bookcase/frame_0001.jpg 3 90 | office/webcam/bookcase/frame_0005.jpg 3 91 | office/webcam/bottle/frame_0012.jpg 4 92 | office/webcam/bottle/frame_0002.jpg 4 93 | office/webcam/bottle/frame_0006.jpg 4 94 | office/webcam/bottle/frame_0008.jpg 4 95 | office/webcam/bottle/frame_0013.jpg 4 96 | office/webcam/bottle/frame_0009.jpg 4 97 | office/webcam/bottle/frame_0010.jpg 4 98 | office/webcam/bottle/frame_0016.jpg 4 99 | office/webcam/bottle/frame_0003.jpg 4 100 | office/webcam/bottle/frame_0004.jpg 4 101 | office/webcam/bottle/frame_0007.jpg 4 102 | office/webcam/bottle/frame_0014.jpg 4 103 | office/webcam/bottle/frame_0001.jpg 4 104 | office/webcam/bottle/frame_0015.jpg 4 105 | office/webcam/bottle/frame_0005.jpg 4 106 | office/webcam/bottle/frame_0011.jpg 4 107 | office/webcam/calculator/frame_0022.jpg 5 108 | office/webcam/calculator/frame_0017.jpg 5 109 | office/webcam/calculator/frame_0020.jpg 5 110 | office/webcam/calculator/frame_0029.jpg 5 111 | office/webcam/calculator/frame_0025.jpg 5 112 | office/webcam/calculator/frame_0024.jpg 5 
113 | office/webcam/calculator/frame_0023.jpg 5 114 | office/webcam/calculator/frame_0013.jpg 5 115 | office/webcam/calculator/frame_0011.jpg 5 116 | office/webcam/calculator/frame_0007.jpg 5 117 | office/webcam/calculator/frame_0030.jpg 5 118 | office/webcam/calculator/frame_0015.jpg 5 119 | office/webcam/calculator/frame_0014.jpg 5 120 | office/webcam/calculator/frame_0003.jpg 5 121 | office/webcam/calculator/frame_0006.jpg 5 122 | office/webcam/calculator/frame_0018.jpg 5 123 | office/webcam/calculator/frame_0004.jpg 5 124 | office/webcam/calculator/frame_0010.jpg 5 125 | office/webcam/calculator/frame_0016.jpg 5 126 | office/webcam/calculator/frame_0005.jpg 5 127 | office/webcam/calculator/frame_0002.jpg 5 128 | office/webcam/calculator/frame_0026.jpg 5 129 | office/webcam/calculator/frame_0012.jpg 5 130 | office/webcam/calculator/frame_0001.jpg 5 131 | office/webcam/calculator/frame_0008.jpg 5 132 | office/webcam/calculator/frame_0009.jpg 5 133 | office/webcam/calculator/frame_0021.jpg 5 134 | office/webcam/calculator/frame_0027.jpg 5 135 | office/webcam/calculator/frame_0028.jpg 5 136 | office/webcam/calculator/frame_0031.jpg 5 137 | office/webcam/calculator/frame_0019.jpg 5 138 | office/webcam/desk_chair/frame_0008.jpg 6 139 | office/webcam/desk_chair/frame_0033.jpg 6 140 | office/webcam/desk_chair/frame_0007.jpg 6 141 | office/webcam/desk_chair/frame_0036.jpg 6 142 | office/webcam/desk_chair/frame_0013.jpg 6 143 | office/webcam/desk_chair/frame_0023.jpg 6 144 | office/webcam/desk_chair/frame_0017.jpg 6 145 | office/webcam/desk_chair/frame_0028.jpg 6 146 | office/webcam/desk_chair/frame_0011.jpg 6 147 | office/webcam/desk_chair/frame_0021.jpg 6 148 | office/webcam/desk_chair/frame_0005.jpg 6 149 | office/webcam/desk_chair/frame_0024.jpg 6 150 | office/webcam/desk_chair/frame_0004.jpg 6 151 | office/webcam/desk_chair/frame_0034.jpg 6 152 | office/webcam/desk_chair/frame_0038.jpg 6 153 | office/webcam/desk_chair/frame_0030.jpg 6 154 | 
office/webcam/desk_chair/frame_0003.jpg 6 155 | office/webcam/desk_chair/frame_0010.jpg 6 156 | office/webcam/desk_chair/frame_0001.jpg 6 157 | office/webcam/desk_chair/frame_0031.jpg 6 158 | office/webcam/desk_chair/frame_0022.jpg 6 159 | office/webcam/desk_chair/frame_0015.jpg 6 160 | office/webcam/desk_chair/frame_0029.jpg 6 161 | office/webcam/desk_chair/frame_0012.jpg 6 162 | office/webcam/desk_chair/frame_0016.jpg 6 163 | office/webcam/desk_chair/frame_0039.jpg 6 164 | office/webcam/desk_chair/frame_0002.jpg 6 165 | office/webcam/desk_chair/frame_0009.jpg 6 166 | office/webcam/desk_chair/frame_0037.jpg 6 167 | office/webcam/desk_chair/frame_0025.jpg 6 168 | office/webcam/desk_chair/frame_0014.jpg 6 169 | office/webcam/desk_chair/frame_0020.jpg 6 170 | office/webcam/desk_chair/frame_0027.jpg 6 171 | office/webcam/desk_chair/frame_0032.jpg 6 172 | office/webcam/desk_chair/frame_0035.jpg 6 173 | office/webcam/desk_chair/frame_0018.jpg 6 174 | office/webcam/desk_chair/frame_0006.jpg 6 175 | office/webcam/desk_chair/frame_0019.jpg 6 176 | office/webcam/desk_chair/frame_0040.jpg 6 177 | office/webcam/desk_chair/frame_0026.jpg 6 178 | office/webcam/desk_lamp/frame_0013.jpg 7 179 | office/webcam/desk_lamp/frame_0007.jpg 7 180 | office/webcam/desk_lamp/frame_0017.jpg 7 181 | office/webcam/desk_lamp/frame_0014.jpg 7 182 | office/webcam/desk_lamp/frame_0001.jpg 7 183 | office/webcam/desk_lamp/frame_0003.jpg 7 184 | office/webcam/desk_lamp/frame_0002.jpg 7 185 | office/webcam/desk_lamp/frame_0015.jpg 7 186 | office/webcam/desk_lamp/frame_0010.jpg 7 187 | office/webcam/desk_lamp/frame_0004.jpg 7 188 | office/webcam/desk_lamp/frame_0008.jpg 7 189 | office/webcam/desk_lamp/frame_0016.jpg 7 190 | office/webcam/desk_lamp/frame_0009.jpg 7 191 | office/webcam/desk_lamp/frame_0006.jpg 7 192 | office/webcam/desk_lamp/frame_0012.jpg 7 193 | office/webcam/desk_lamp/frame_0011.jpg 7 194 | office/webcam/desk_lamp/frame_0005.jpg 7 195 | office/webcam/desk_lamp/frame_0018.jpg 7 196 | 
office/webcam/desktop_computer/frame_0005.jpg 8 197 | office/webcam/desktop_computer/frame_0007.jpg 8 198 | office/webcam/desktop_computer/frame_0003.jpg 8 199 | office/webcam/desktop_computer/frame_0010.jpg 8 200 | office/webcam/desktop_computer/frame_0021.jpg 8 201 | office/webcam/desktop_computer/frame_0019.jpg 8 202 | office/webcam/desktop_computer/frame_0012.jpg 8 203 | office/webcam/desktop_computer/frame_0002.jpg 8 204 | office/webcam/desktop_computer/frame_0017.jpg 8 205 | office/webcam/desktop_computer/frame_0008.jpg 8 206 | office/webcam/desktop_computer/frame_0018.jpg 8 207 | office/webcam/desktop_computer/frame_0004.jpg 8 208 | office/webcam/desktop_computer/frame_0006.jpg 8 209 | office/webcam/desktop_computer/frame_0016.jpg 8 210 | office/webcam/desktop_computer/frame_0014.jpg 8 211 | office/webcam/desktop_computer/frame_0009.jpg 8 212 | office/webcam/desktop_computer/frame_0020.jpg 8 213 | office/webcam/desktop_computer/frame_0015.jpg 8 214 | office/webcam/desktop_computer/frame_0001.jpg 8 215 | office/webcam/desktop_computer/frame_0013.jpg 8 216 | office/webcam/desktop_computer/frame_0011.jpg 8 217 | office/webcam/file_cabinet/frame_0018.jpg 9 218 | office/webcam/file_cabinet/frame_0003.jpg 9 219 | office/webcam/file_cabinet/frame_0005.jpg 9 220 | office/webcam/file_cabinet/frame_0001.jpg 9 221 | office/webcam/file_cabinet/frame_0010.jpg 9 222 | office/webcam/file_cabinet/frame_0014.jpg 9 223 | office/webcam/file_cabinet/frame_0008.jpg 9 224 | office/webcam/file_cabinet/frame_0019.jpg 9 225 | office/webcam/file_cabinet/frame_0007.jpg 9 226 | office/webcam/file_cabinet/frame_0009.jpg 9 227 | office/webcam/file_cabinet/frame_0017.jpg 9 228 | office/webcam/file_cabinet/frame_0016.jpg 9 229 | office/webcam/file_cabinet/frame_0012.jpg 9 230 | office/webcam/file_cabinet/frame_0013.jpg 9 231 | office/webcam/file_cabinet/frame_0015.jpg 9 232 | office/webcam/file_cabinet/frame_0006.jpg 9 233 | office/webcam/file_cabinet/frame_0004.jpg 9 234 | 
office/webcam/file_cabinet/frame_0011.jpg 9 235 | office/webcam/file_cabinet/frame_0002.jpg 9 -------------------------------------------------------------------------------- /image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/silvia1993/ROS/bceef4d9dc505f55322a4c25fb8071f49e7a5671/image.jpg -------------------------------------------------------------------------------- /requirements_ROS.txt: -------------------------------------------------------------------------------- 1 | torch==1.0.1.post2 2 | torchvision==0.2.1 3 | tensorlayer==1.11.0 4 | tensorflow-gpu==1.14.0 5 | scipy==1.1.0 6 | Pillow==6.2.0 7 | numpy==1.15.4 8 | -------------------------------------------------------------------------------- /train_resnet50_office31.sh: -------------------------------------------------------------------------------- 1 | 2 | ##########################train of Office-31 dataset using ResNet-50######################################################### 3 | 4 | 5 | python3 code/train.py --source dslr_0-9 --target webcam_0-9_20-30 --epochs_step1 80 --epochs_step2 80 --n_classes 10 --n_classes_target 21 --use_weight_net_first_part --weight_class_unknown 2 --weight_center_loss 0.1 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 6 | 7 | python3 code/train.py --source webcam_0-9 --target dslr_0-9_20-30 --epochs_step1 80 --epochs_step2 80 --n_classes 10 --n_classes_target 21 --use_weight_net_first_part --weight_class_unknown 2 --weight_center_loss 0.1 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 8 | 9 | python3 code/train.py --source dslr_0-9 --target amazon_0-9_20-30 --epochs_step1 80 --epochs_step2 80 --n_classes 10 --n_classes_target 21 --use_weight_net_first_part --weight_class_unknown 2 --weight_center_loss 0.1 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 10 | 11 | python3 code/train.py --source webcam_0-9 --target amazon_0-9_20-30 --epochs_step1 80 
--epochs_step2 80 --n_classes 10 --n_classes_target 21 --use_weight_net_first_part --weight_class_unknown 2 --weight_center_loss 0.1 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 12 | 13 | python3 code/train.py --source amazon_0-9 --target dslr_0-9_20-30 --epochs_step1 80 --epochs_step2 80 --n_classes 10 --n_classes_target 21 --use_weight_net_first_part --weight_class_unknown 2 --weight_center_loss 0.1 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 14 | 15 | python3 code/train.py --source amazon_0-9 --target webcam_0-9_20-30 --epochs_step1 80 --epochs_step2 80 --n_classes 10 --n_classes_target 21 --use_weight_net_first_part --weight_class_unknown 2 --weight_center_loss 0.1 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 16 | -------------------------------------------------------------------------------- /train_resnet50_officehome.sh: -------------------------------------------------------------------------------- 1 | 2 | ##########################train of Office-Home dataset using ResNet50######################################################### 3 | 4 | 5 | python3 code/train.py --source product_0-24 --target clipart_0-64 --epochs_step1 150 --epochs_step2 45 --use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 6 | 7 | python3 code/train.py --source product_0-24 --target art_0-64 --epochs_step1 150 --epochs_step2 45 --use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 8 | 9 | python3 code/train.py --source product_0-24 --target real_world_0-64 --epochs_step1 150 --epochs_step2 45 --use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 10 | 11 | 12 | 13 | python3 code/train.py --source art_0-24 --target clipart_0-64 --epochs_step1 150 --epochs_step2 45 
--use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 14 | 15 | python3 code/train.py --source art_0-24 --target product_0-64 --epochs_step1 150 --epochs_step2 45 --use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 16 | 17 | python3 code/train.py --source art_0-24 --target real_world_0-64 --epochs_step1 150 --epochs_step2 45 --use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 18 | 19 | 20 | 21 | python3 code/train.py --source clipart_0-24 --target art_0-64 --epochs_step1 150 --epochs_step2 45 --use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 22 | 23 | python3 code/train.py --source clipart_0-24 --target product_0-64 --epochs_step1 150 --epochs_step2 45 --use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 24 | 25 | python3 code/train.py --source clipart_0-24 --target real_world_0-64 --epochs_step1 150 --epochs_step2 45 --use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 26 | 27 | 28 | 29 | python3 code/train.py --source real_world_0-24 --target art_0-64 --epochs_step1 150 --epochs_step2 45 --use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 30 | 31 | python3 code/train.py --source real_world_0-24 --target product_0-64 --epochs_step1 150 --epochs_step2 45 --use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 32 | 33 | python3 code/train.py --source real_world_0-24 
--target clipart_0-64 --epochs_step1 150 --epochs_step2 45 --use_weight_net_first_part --weight_center_loss 0.001 --weight_class_unknown 2 --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 34 | -------------------------------------------------------------------------------- /train_vgg_office31.sh: -------------------------------------------------------------------------------- 1 | 2 | ##########################train of Office-31 dataset using VGGNet######################################################### 3 | 4 | 5 | python3 code/train.py --source amazon_0-9 --target webcam_0-9_20-30 --epochs_step1 100 --epochs_step2 200 --n_classes 10 --n_classes_target 21 --weight_class_unknown 1.5 --weight_center_loss 0.1 --use_VGG --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 6 | 7 | python3 code/train.py --source webcam_0-9 --target dslr_0-9_20-30 --epochs_step1 100 --epochs_step2 200 --n_classes 10 --n_classes_target 21 --weight_class_unknown 1.5 --weight_center_loss 0.1 --use_VGG --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 8 | 9 | python3 code/train.py --source webcam_0-9 --target amazon_0-9_20-30 --epochs_step1 100 --epochs_step2 200 --n_classes 10 --n_classes_target 21 --weight_class_unknown 1.5 --weight_center_loss 0.1 --use_VGG --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 10 | 11 | python3 code/train.py --source dslr_0-9 --target amazon_0-9_20-30 --epochs_step1 100 --epochs_step2 200 --n_classes 10 --n_classes_target 21 --weight_class_unknown 1.5 --weight_center_loss 0.1 --use_VGG --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 12 | 13 | python3 code/train.py --source dslr_0-9 --target webcam_0-9_20-30 --epochs_step1 100 --epochs_step2 200 --n_classes 10 --n_classes_target 21 --weight_class_unknown 1.5 --weight_center_loss 0.1 --use_VGG --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 14 | 15 | python3 code/train.py --source amazon_0-9 --target dslr_0-9_20-30 --epochs_step1 100 
--epochs_step2 200 --n_classes 10 --n_classes_target 21 --weight_class_unknown 1.5 --weight_center_loss 0.1 --use_VGG --folder_txt_files /.../ROS/data/ --folder_dataset /.../ROS/ 16 | 17 | --------------------------------------------------------------------------------
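Each list file above has one `relative/path.jpg label` entry per line. In the target test list `webcam_0-9_20-30_test.txt`, the shared classes keep labels 0-9. The target-private classes appear after them with labels 10-20: `phone` is labelled 10 and `trash_can` is labelled 20. This matches `--n_classes 10 --n_classes_target 21` in the Office-31 scripts.

The minimal Python sketch below reads such a list and separates known from unknown samples. It is not taken from `code/data.py`. It assumes that every label `>= n_classes` is target-private. The helper `to_open_set_label`, which collapses all private classes into a single "unknown" index, is a hypothetical illustration of open-set evaluation rather than the repository's code.

```python
from collections import Counter
from typing import Iterable, List, Tuple

Sample = Tuple[str, int]


def parse_list(lines: Iterable[str]) -> List[Sample]:
    """Parse 'relative/path.jpg label' lines into (path, label) pairs."""
    samples = []
    for line in lines:
        line = line.strip()
        if not line:  # skip blank lines
            continue
        path, label = line.rsplit(maxsplit=1)  # the label is the last token
        samples.append((path, int(label)))
    return samples


def split_known_unknown(samples: List[Sample], n_classes: int) -> Tuple[List[Sample], List[Sample]]:
    """Assumption: labels < n_classes are shared (known) classes, the rest are target-private."""
    known = [s for s in samples if s[1] < n_classes]
    unknown = [s for s in samples if s[1] >= n_classes]
    return known, unknown


def to_open_set_label(label: int, n_classes: int) -> int:
    """Hypothetical helper: map every target-private class to one 'unknown' index (= n_classes)."""
    return label if label < n_classes else n_classes


if __name__ == "__main__":
    # Three entries copied from webcam_0-9_20-30_test.txt above.
    excerpt = [
        "office/webcam/file_cabinet/frame_0002.jpg 9",
        "office/webcam/phone/frame_0006.jpg 10",
        "office/webcam/trash_can/frame_0019.jpg 20",
    ]
    known, unknown = split_known_unknown(parse_list(excerpt), n_classes=10)
    print(f"known: {len(known)}, unknown: {len(unknown)}")  # known: 1, unknown: 2
    print(Counter(to_open_set_label(lbl, 10) for _, lbl in known + unknown))
```

In practice, the lines would come from `open(folder_txt_files + "webcam_0-9_20-30_test.txt")`. Because `rsplit(maxsplit=1)` takes only the last token as the label, it stays correct even if a path ever contained spaces.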
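The six commands in `train_resnet50_office31.sh` differ only in `--source` and `--target`. Everything else is repeated, including the `/.../` placeholder that the README asks you to replace with the directory where ROS is saved.

The minimal Python sketch below builds the same argument lists from one table of pairs. The function and variable names are hypothetical. The flag values are copied from the script. It assumes it is run from inside the ROS folder, as the README instructs, so that `code/train.py` resolves. It only prints the commands and does not launch training.

```python
import shlex
from typing import List

# Source -> target pairs, in the order they appear in train_resnet50_office31.sh.
OFFICE31_PAIRS = [
    ("dslr_0-9", "webcam_0-9_20-30"),
    ("webcam_0-9", "dslr_0-9_20-30"),
    ("dslr_0-9", "amazon_0-9_20-30"),
    ("webcam_0-9", "amazon_0-9_20-30"),
    ("amazon_0-9", "dslr_0-9_20-30"),
    ("amazon_0-9", "webcam_0-9_20-30"),
]


def office31_resnet50_cmd(source: str, target: str, parent_dir: str) -> List[str]:
    """Hypothetical helper: argv for one ResNet-50 Office-31 run.

    parent_dir plays the role of "/.../" in the script, i.e. the directory
    that contains the ROS folder.
    """
    ros_root = parent_dir.rstrip("/") + "/ROS/"
    return [
        "python3", "code/train.py",
        "--source", source, "--target", target,
        "--epochs_step1", "80", "--epochs_step2", "80",
        "--n_classes", "10", "--n_classes_target", "21",
        "--use_weight_net_first_part",
        "--weight_class_unknown", "2",
        "--weight_center_loss", "0.1",
        "--folder_txt_files", ros_root + "data/",
        "--folder_dataset", ros_root,
    ]


if __name__ == "__main__":
    for source, target in OFFICE31_PAIRS:
        cmd = office31_resnet50_cmd(source, target, "/path_in_which_you_save_ROS")
        # shlex.quote (rather than shlex.join) keeps this compatible with Python 3.6.
        print(" ".join(shlex.quote(arg) for arg in cmd))
```

To launch a run, each list can be passed to `subprocess.run(cmd, check=True)`. Building argv lists instead of one shell string avoids quoting problems if the dataset path contains spaces. The same pattern would cover the Office-Home and VGG scripts, with their own epoch counts and weights.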