├── .gitignore ├── LICENSE ├── code ├── data_gen │ ├── custom_augs.py │ ├── generate_data.py │ ├── im_folder.py │ └── readme.md ├── explain │ ├── data_output.py │ ├── datasets.py │ ├── hypothesis.py │ ├── main.py │ ├── offline_teachers.py │ ├── readme.md │ └── utils.py ├── readme.md └── teaching_app │ ├── Procfile │ ├── application.py │ ├── config.py │ ├── data │ ├── random_image.strat │ ├── random_image_with_explain.strat │ ├── readme.md │ ├── settings.json │ ├── strategy_0.strat │ ├── strategy_1.strat │ └── teaching_images.json │ ├── readme.md │ ├── requirements.txt │ ├── runtime.txt │ ├── static │ ├── style.css │ ├── tutorial_0.jpg │ ├── tutorial_1.jpg │ ├── tutorial_2.jpg │ └── tutorial_3.jpg │ ├── templates │ ├── dashboard.html │ ├── debug.html │ ├── index.html │ ├── results.html │ ├── teaching.html │ └── tutorial.html │ └── utils.py ├── data ├── .gitignore └── readme.md ├── readme.md └── results ├── create_plots.py ├── experiments ├── butterflies_crop │ ├── explain_1vall.strat │ ├── random.strat │ ├── random_feedback.strat │ ├── results.json │ ├── settings.json │ ├── strict_1vall.strat │ └── teaching_images.json ├── chinese_chars │ ├── explain_1vall.strat │ ├── random.strat │ ├── random_feedback.strat │ ├── results.json │ ├── settings.json │ ├── strict_1vall.strat │ └── teaching_images.json ├── chinese_chars_crowd │ ├── explain_1vall.strat │ ├── results.json │ ├── settings.json │ ├── strict_1vall.strat │ └── teaching_images.json └── oct │ ├── explain_1vall.strat │ ├── random.strat │ ├── random_feedback.strat │ ├── results.json │ ├── settings.json │ ├── strict_1vall.strat │ └── teaching_images.json └── readme.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.png 3 | *.npz 4 | *.zip 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 
Oisin Mac Aodha 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /code/data_gen/custom_augs.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps, ImageEnhance 6 | try: 7 | import accimage 8 | except ImportError: 9 | accimage = None 10 | import numpy as np 11 | import numbers 12 | import types 13 | import collections 14 | import warnings 15 | from torchvision.transforms import Scale, CenterCrop, RandomCrop 16 | 17 | 18 | class RandomSizedCrop(object): 19 | """Crop the given PIL.Image to random size and aspect ratio. 20 | 21 | A crop of random size of (0.08 to 1.0) of the original size and a random 22 | aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. 
This crop 23 | is finally resized to given size. 24 | This is popularly used to train the Inception networks. 25 | 26 | Args: 27 | size: size of the smaller edge 28 | interpolation: Default: PIL.Image.BILINEAR 29 | """ 30 | 31 | def __init__(self, size, interpolation=Image.BILINEAR): 32 | self.size = size 33 | self.interpolation = interpolation 34 | 35 | def __call__(self, img): 36 | for attempt in range(10): 37 | area = img.size[0] * img.size[1] 38 | target_area = random.uniform(0.80, 1.0) * area 39 | aspect_ratio = random.uniform(7. / 8, 8. / 7) 40 | 41 | w = int(round(math.sqrt(target_area * aspect_ratio))) 42 | h = int(round(math.sqrt(target_area / aspect_ratio))) 43 | 44 | if random.random() < 0.5: 45 | w, h = h, w 46 | 47 | if w <= img.size[0] and h <= img.size[1]: 48 | x1 = random.randint(0, img.size[0] - w) 49 | y1 = random.randint(0, img.size[1] - h) 50 | 51 | img = img.crop((x1, y1, x1 + w, y1 + h)) 52 | assert(img.size == (w, h)) 53 | 54 | return img.resize((self.size, self.size), self.interpolation) 55 | 56 | # Fallback 57 | #scale = Scale(self.size, interpolation=self.interpolation) 58 | #crop = CenterCrop(self.size) 59 | #return crop(scale(img)) 60 | rcrop = RandomCrop(self.size) 61 | return rcrop(img) 62 | 63 | 64 | def _is_pil_image(img): 65 | if accimage is not None: 66 | return isinstance(img, (Image.Image, accimage.Image)) 67 | else: 68 | return isinstance(img, Image.Image) 69 | 70 | def adjust_brightness(img, brightness_factor): 71 | """Adjust brightness of an Image. 72 | Args: 73 | img (PIL Image): PIL Image to be adjusted. 74 | brightness_factor (float): How much to adjust the brightness. Can be 75 | any non negative number. 0 gives a black image, 1 gives the 76 | original image while 2 increases the brightness by a factor of 2. 77 | Returns: 78 | PIL Image: Brightness adjusted image. 79 | """ 80 | if not _is_pil_image(img): 81 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 82 | 83 | enhancer = ImageEnhance.Brightness(img) 84 | img = enhancer.enhance(brightness_factor) 85 | return img 86 | 87 | 88 | def adjust_contrast(img, contrast_factor): 89 | """Adjust contrast of an Image. 90 | Args: 91 | img (PIL Image): PIL Image to be adjusted. 92 | contrast_factor (float): How much to adjust the contrast. Can be any 93 | non negative number. 0 gives a solid gray image, 1 gives the 94 | original image while 2 increases the contrast by a factor of 2. 95 | Returns: 96 | PIL Image: Contrast adjusted image. 97 | """ 98 | if not _is_pil_image(img): 99 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 100 | 101 | enhancer = ImageEnhance.Contrast(img) 102 | img = enhancer.enhance(contrast_factor) 103 | return img 104 | 105 | 106 | def adjust_saturation(img, saturation_factor): 107 | """Adjust color saturation of an image. 108 | Args: 109 | img (PIL Image): PIL Image to be adjusted. 110 | saturation_factor (float): How much to adjust the saturation. 0 will 111 | give a black and white image, 1 will give the original image while 112 | 2 will enhance the saturation by a factor of 2. 113 | Returns: 114 | PIL Image: Saturation adjusted image. 115 | """ 116 | if not _is_pil_image(img): 117 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 118 | 119 | enhancer = ImageEnhance.Color(img) 120 | img = enhancer.enhance(saturation_factor) 121 | return img 122 | 123 | 124 | def adjust_hue(img, hue_factor): 125 | """Adjust hue of an image. 126 | The image hue is adjusted by converting the image to HSV and 127 | cyclically shifting the intensities in the hue channel (H). 128 | The image is then converted back to original image mode. 129 | `hue_factor` is the amount of shift in H channel and must be in the 130 | interval `[-0.5, 0.5]`. 131 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 132 | Args: 133 | img (PIL Image): PIL Image to be adjusted. 
134 | hue_factor (float): How much to shift the hue channel. Should be in 135 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 136 | HSV space in positive and negative direction respectively. 137 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 138 | with complementary colors while 0 gives the original image. 139 | Returns: 140 | PIL Image: Hue adjusted image. 141 | """ 142 | if not(-0.5 <= hue_factor <= 0.5): 143 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 144 | 145 | if not _is_pil_image(img): 146 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 147 | 148 | input_mode = img.mode 149 | if input_mode in {'L', '1', 'I', 'F'}: 150 | return img 151 | 152 | h, s, v = img.convert('HSV').split() 153 | 154 | np_h = np.array(h, dtype=np.uint8) 155 | # uint8 addition take cares of rotation across boundaries 156 | with np.errstate(over='ignore'): 157 | np_h += np.uint8(hue_factor * 255) 158 | h = Image.fromarray(np_h, 'L') 159 | 160 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 161 | return img 162 | 163 | 164 | def adjust_gamma(img, gamma, gain=1): 165 | """Perform gamma correction on an image. 166 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 167 | based on the following equation: 168 | I_out = 255 * gain * ((I_in / 255) ** gamma) 169 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 170 | Args: 171 | img (PIL Image): PIL Image to be adjusted. 172 | gamma (float): Non negative real number. gamma larger than 1 make the 173 | shadows darker, while gamma smaller than 1 make dark regions 174 | lighter. 175 | gain (float): The constant multiplier. 176 | """ 177 | if not _is_pil_image(img): 178 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 179 | 180 | if gamma < 0: 181 | raise ValueError('Gamma should be a non-negative real number') 182 | 183 | input_mode = img.mode 184 | img = img.convert('RGB') 185 | 186 | np_img = np.array(img, dtype=np.float32) 187 | np_img = 255 * gain * ((np_img / 255) ** gamma) 188 | np_img = np.uint8(np.clip(np_img, 0, 255)) 189 | 190 | img = Image.fromarray(np_img, 'RGB').convert(input_mode) 191 | return img 192 | 193 | class Compose(object): 194 | """Composes several transforms together. 195 | Args: 196 | transforms (list of ``Transform`` objects): list of transforms to compose. 197 | Example: 198 | >>> transforms.Compose([ 199 | >>> transforms.CenterCrop(10), 200 | >>> transforms.ToTensor(), 201 | >>> ]) 202 | """ 203 | 204 | def __init__(self, transforms): 205 | self.transforms = transforms 206 | 207 | def __call__(self, img): 208 | for t in self.transforms: 209 | img = t(img) 210 | return img 211 | 212 | 213 | class Lambda(object): 214 | """Apply a user-defined lambda as a transform. 215 | Args: 216 | lambd (function): Lambda/function to be used for transform. 217 | """ 218 | 219 | def __init__(self, lambd): 220 | assert isinstance(lambd, types.LambdaType) 221 | self.lambd = lambd 222 | 223 | def __call__(self, img): 224 | return self.lambd(img) 225 | 226 | 227 | class ColorJitter(object): 228 | """Randomly change the brightness, contrast and saturation of an image. 229 | Args: 230 | brightness (float): How much to jitter brightness. brightness_factor 231 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 232 | contrast (float): How much to jitter contrast. contrast_factor 233 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 234 | saturation (float): How much to jitter saturation. saturation_factor 235 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 236 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 237 | [-hue, hue]. Should be >=0 and <= 0.5. 
238 | """ 239 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 240 | self.brightness = brightness 241 | self.contrast = contrast 242 | self.saturation = saturation 243 | self.hue = hue 244 | 245 | @staticmethod 246 | def get_params(brightness, contrast, saturation, hue): 247 | """Get a randomized transform to be applied on image. 248 | Arguments are same as that of __init__. 249 | Returns: 250 | Transform which randomly adjusts brightness, contrast and 251 | saturation in a random order. 252 | """ 253 | transforms = [] 254 | if brightness > 0: 255 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 256 | transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) 257 | 258 | if contrast > 0: 259 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 260 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) 261 | 262 | if saturation > 0: 263 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 264 | transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) 265 | 266 | if hue > 0: 267 | hue_factor = np.random.uniform(-hue, hue) 268 | transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) 269 | 270 | np.random.shuffle(transforms) 271 | transform = Compose(transforms) 272 | 273 | return transform 274 | 275 | def __call__(self, img): 276 | """ 277 | Args: 278 | img (PIL Image): Input image. 279 | Returns: 280 | PIL Image: Color jittered image. 
281 | """ 282 | transform = self.get_params(self.brightness, self.contrast, 283 | self.saturation, self.hue) 284 | return transform(img) -------------------------------------------------------------------------------- /code/data_gen/generate_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torchvision import datasets, transforms 7 | from torch.autograd import Variable 8 | import im_folder 9 | import matplotlib.pyplot as plt 10 | from sklearn import decomposition 11 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 12 | import numpy as np 13 | import os 14 | from scipy.ndimage import zoom 15 | import torchvision 16 | from scipy.spatial.distance import squareform, pdist 17 | from sklearn.preprocessing import normalize 18 | from scipy.ndimage.filters import gaussian_filter 19 | import custom_augs 20 | import cv2 21 | 22 | 23 | def gen_mask(ip_mask, im, blur=False): 24 | # assume im in [0,1] and RGB 25 | mask = ip_mask.copy() 26 | mask -= np.min(mask) 27 | mask /= np.max(mask) 28 | 29 | mask = np.exp(3*mask) # scale it 30 | mask /= np.max(mask) 31 | 32 | if blur: 33 | mask = gaussian_filter(mask.copy(), 5) 34 | mask = mask*0.99 + 0.01 35 | 36 | mask = np.tile(mask[..., np.newaxis], (1, 1, 3)) 37 | 38 | if im.dtype == np.uint8: 39 | im = im.astype(np.float)/255 40 | 41 | if len(im.shape) == 3: 42 | im = im[:,:,:3] 43 | if len(im.shape) == 2: 44 | im = np.tile(im[..., np.newaxis], (1, 1, 3)) 45 | 46 | op = im*mask 47 | op[op>1] = 1 48 | op[op<0] = 0 49 | return op 50 | 51 | 52 | def plot_proj(feats_op, gt, classes, title_txt, fid): 53 | cols = ['r', 'g', 'b', 'c', 'm', 'y', 'k'] 54 | plt.figure(fid) 55 | for cc in np.unique(gt): 56 | inds = np.where(gt==cc)[0] 57 | plt.plot(feats_op[inds, 0], feats_op[inds, 1], cols[cc]+'.', label=classes[cc]) 58 | 59 
| plt.legend() 60 | plt.title(title_txt) 61 | plt.axis('equal') 62 | plt.show() 63 | 64 | 65 | class FTNet(torch.nn.Module): 66 | def __init__(self, backbone, num_classes, compute_bias): 67 | super(FTNet, self).__init__() 68 | op_chns = 64 69 | self.backbone = backbone 70 | self.conv1 = nn.Conv2d(128, op_chns, kernel_size=3, padding=1) 71 | self.fc = torch.nn.Linear(op_chns, num_classes, bias=compute_bias) 72 | 73 | def forward(self, data): 74 | x = self.backbone(data) 75 | x_feat = F.relu(self.conv1(x)) 76 | x = F.adaptive_avg_pool2d(x_feat, 1) 77 | x_flat = x.view(x.size(0), -1) 78 | cls_op = self.fc(x_flat) 79 | return x_feat, x_flat, cls_op 80 | 81 | 82 | def cam_mapper(act, cam_weight, cam_bias): 83 | cam_map = np.zeros((act.shape[0], act.shape[-1],act.shape[-1], num_classes)) 84 | for cc in range(cam_weight.shape[1]): 85 | cam_map[:,:,:,cc] = (act*cam_weight[:, cc][..., np.newaxis, np.newaxis] + cam_bias[cc]).sum(1) 86 | cam_map[cam_map<0] = 0 87 | return cam_map 88 | 89 | 90 | def resize_cam(cam, orig_size, crop_size): 91 | # assumes square 92 | diff = (orig_size - crop_size) / 2 93 | resize_fact = crop_size / cam.shape[0] 94 | cam_op = np.ones((orig_size, orig_size, cam.shape[-1]))*cam.min() 95 | for cc in range(cam.shape[-1]): 96 | zm = zoom(cam[:,:,cc], (resize_fact, resize_fact), order=1) 97 | cam_op[diff:-diff,diff:-diff, cc] = zm.copy() 98 | cam_op[:,:,cc] = gaussian_filter(cam_op[:,:,cc], sigma=1.5) 99 | return cam_op 100 | 101 | 102 | # Training settings 103 | lr = 0.0002 104 | weight_decay = 1e-4 105 | momentum = 0.9 106 | batch_size = 64 107 | test_batch_size = 64 108 | log_interval = 100 109 | compute_bias = False 110 | 111 | dataset = 'chinese_chars' # 'oct', 'butterflies_crop', 'chinese_chars' 112 | base_dir = '../../data/' 113 | save_op = False 114 | save_debug_ims = False 115 | 116 | if save_op == False: 117 | print('***\nNot saving outputs\n***') 118 | 119 | if dataset == 'oct': 120 | root_dir = 'oct/images/' 121 | explain_dir = 
'oct/explanations/' 122 | op_file_name = 'oct' 123 | orig_size = 144 124 | crop_size = 128 125 | epochs = 60 126 | elif dataset == 'chinese_chars': 127 | root_dir = 'chinese_chars/images/' 128 | explain_dir = 'chinese_chars/explanations/' 129 | op_file_name = 'chinese_chars' 130 | orig_size = 144 131 | crop_size = 128 132 | epochs = 60 133 | elif dataset == 'butterflies_crop': 134 | root_dir = 'butterflies_crop/images/' 135 | explain_dir = 'butterflies_crop/explanations/' 136 | op_file_name = 'butterflies_crop' 137 | orig_size = 144 # assumes square 138 | crop_size = 128 139 | epochs = 60 140 | 141 | 142 | is_cuda = True 143 | kwargs = {'num_workers': 6, 'pin_memory': True} if is_cuda else {} 144 | plt.close('all') 145 | 146 | mu_data = [0.485, 0.456, 0.406] 147 | std_data = [0.229, 0.224, 0.225] 148 | 149 | train_transform = transforms.Compose([ 150 | transforms.RandomHorizontalFlip(), 151 | custom_augs.RandomSizedCrop(crop_size), 152 | custom_augs.ColorJitter(brightness=0.4,contrast=0.4,saturation=0.4,hue=0.25), 153 | transforms.ToTensor(), 154 | transforms.Normalize(mean=mu_data, std=std_data)]) 155 | 156 | test_transform = transforms.Compose([ 157 | transforms.CenterCrop(128), 158 | transforms.ToTensor(), 159 | transforms.Normalize(mean=mu_data, std=std_data)]) 160 | 161 | # Note this is currently loading all files into both train and test 162 | train_dataset = im_folder.ImageFolder(root=base_dir+root_dir, transform=train_transform) 163 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs) 164 | all_dataset = im_folder.ImageFolder(root=base_dir+root_dir, transform=test_transform) 165 | all_loader = torch.utils.data.DataLoader(all_dataset, batch_size=batch_size, shuffle=False, **kwargs) 166 | 167 | # get dataset details 168 | data, target = next(iter(all_loader)) 169 | num_channels = int(data.size()[1]) 170 | class_names = all_loader.dataset.classes 171 | num_classes = len(class_names) 172 | gt_labels = 
np.asarray(all_loader.dataset.class_labels) 173 | imgs = all_loader.dataset.imgs 174 | imgs_files = [ii[len(base_dir):] for ii in imgs] 175 | explain_files = [explain_dir + ii[len(base_dir+root_dir):] for ii in imgs] 176 | print('class names', class_names) 177 | 178 | for cc in class_names: 179 | if not os.path.isdir(base_dir + explain_dir + cc): 180 | os.makedirs(base_dir + explain_dir + cc) 181 | 182 | # use resnet BB 183 | resnet = torchvision.models.resnet18(pretrained=False) 184 | resnetbb = nn.Sequential(*list(resnet.children())[:-4]) 185 | model = FTNet(resnetbb, num_classes, compute_bias) 186 | 187 | if is_cuda: 188 | model.cuda() 189 | 190 | 191 | # train CNN 192 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay) 193 | l1_loss_weight = nn.L1Loss() 194 | l1_target_weight = torch.autograd.Variable(torch.zeros(model.fc.weight.size()).cuda()) 195 | 196 | for epoch in range(1, epochs + 1): 197 | 198 | if epoch == (epochs/2): 199 | print('dropping learning rate to ', lr / 10.0) 200 | for param_group in optimizer.param_groups: 201 | param_group['lr'] = lr / 10.0 202 | 203 | model.train() 204 | for batch_idx, (data, target) in enumerate(train_loader): 205 | if is_cuda: 206 | data, target = data.cuda(), target.cuda() 207 | data, target = Variable(data), Variable(target) 208 | 209 | optimizer.zero_grad() 210 | _, _, output = model(data) 211 | #weight_loss = 1000*l1_loss_weight(model.fc.weight, l1_target_weight) 212 | loss = F.cross_entropy(output, target)# + weight_loss 213 | 214 | loss.backward() 215 | optimizer.step() 216 | 217 | if batch_idx % log_interval == 0: 218 | print('Train: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 219 | epoch, batch_idx * len(data), len(train_loader.dataset), 220 | 100. 
* batch_idx / len(train_loader), loss.data[0])) 221 | 222 | # running test 223 | model.eval() 224 | test_loss = 0 225 | correct = 0 226 | for data, target in all_loader: 227 | if is_cuda: 228 | data, target = data.cuda(), target.cuda() 229 | data, target = Variable(data, volatile=True), Variable(target) 230 | _, _, output = model(data) 231 | test_loss += F.cross_entropy(output, target, size_average=False).data[0] # sum up batch loss 232 | pred = output.data.max(1)[1] # get the index of the max log-probability 233 | correct += pred.eq(target.data).cpu().sum() 234 | 235 | test_loss /= len(all_loader.dataset) 236 | print('Test: Loss: {:.4f}, Acc: {}/{} ({:.0f}%)\n'.format( 237 | test_loss, correct, len(all_loader.dataset), 238 | 100. * correct / len(all_loader.dataset))) 239 | 240 | 241 | model.eval() 242 | cam_weight = model.fc.weight.cpu().data.numpy().T 243 | if compute_bias: 244 | cam_bias = model.fc.bias.cpu().data.numpy() 245 | else: 246 | cam_bias = np.zeros(num_classes) 247 | 248 | 249 | # extract features 250 | feats_op = [] 251 | cam_op = [] 252 | logits = [] 253 | cnt = 0 254 | for ii, (data, target) in enumerate(all_loader): 255 | if is_cuda: 256 | data, target = data.cuda(), target.cuda() 257 | data, target = Variable(data, volatile=True), Variable(target) 258 | feats, feats_flat, output = model(data) 259 | 260 | cam = cam_mapper(feats.cpu().data.numpy(), cam_weight, cam_bias) 261 | cam_op.append(cam) 262 | feats_op.append(feats_flat.cpu().data.numpy()) 263 | logits.append(output.cpu().data.numpy()) 264 | 265 | for bb in range(cam.shape[0]): 266 | 267 | # save explanation images 268 | cam_rs = resize_cam(cam[bb, :], orig_size, crop_size) 269 | if save_op: 270 | im = plt.imread(imgs[cnt]) 271 | expl_im = gen_mask(cam_rs[:,:,gt_labels[cnt]], im) 272 | exp_op = base_dir + explain_dir + class_names[gt_labels[cnt]] + '/' + os.path.basename(imgs[cnt]) 273 | cv2.imwrite(exp_op, cv2.cvtColor((expl_im*255).astype(np.uint8), cv2.COLOR_RGB2BGR)) 274 | 275 | cnt += 1 
276 | 277 | feats_op = np.vstack((feats_op)) 278 | cam_op = np.vstack((cam_op)) 279 | logits = np.vstack((logits)) 280 | pred_labels = logits.argmax(1) 281 | 282 | 283 | # visualize results 284 | if save_debug_ims: 285 | if not os.path.isdir('im_heat'): 286 | os.makedirs('im_heat') 287 | if not os.path.isdir('im_expl'): 288 | os.makedirs('im_expl') 289 | 290 | print('saving some ims') 291 | for jj in range(30): 292 | plt.close('all') 293 | im_id = np.random.randint(len(imgs)) 294 | im = plt.imread(imgs[im_id]) 295 | cam = resize_cam(cam_op[im_id,:], orig_size, crop_size) 296 | 297 | plt.figure(3) 298 | plt.gcf().suptitle('%s, GT: %s, Pred: %s' % (os.path.basename(imgs[im_id]), class_names[gt_labels[im_id]], class_names[pred_labels[im_id]])) 299 | plt.subplot(np.ceil(0.1 + num_classes/2.0), 2, 1) 300 | plt.imshow(im, cmap='gray', interpolation='bilinear');plt.axis('off') 301 | plt.title('ip im') 302 | for ii in range(num_classes): 303 | plt.subplot(np.ceil(0.01 + num_classes/2.0), 2, ii+2) 304 | plt.imshow(cam[:,:,ii], vmin=0, vmax=cam.max(), interpolation='bilinear') 305 | plt.axis('off');plt.title('*'*int(gt_labels[im_id]==ii) + class_names[ii]) 306 | plt.savefig('im_heat/' + str(jj).zfill(3) + '.png') 307 | 308 | expl_im = gen_mask(cam[:,:,gt_labels[im_id]], im) 309 | plt.figure(4) 310 | plt.gcf().suptitle('%s, GT: %s, Pred: %s' % (os.path.basename(imgs[im_id]), class_names[gt_labels[im_id]], class_names[pred_labels[im_id]])) 311 | plt.subplot(1, 2, 1) 312 | plt.imshow(im, cmap='gray', vmin=0, vmax=1) 313 | plt.title('ip im') 314 | plt.axis('off') 315 | plt.subplot(1, 2, 2) 316 | plt.imshow(expl_im, cmap='gray', vmin=0, vmax=1) 317 | plt.axis('off') 318 | plt.title('expl') 319 | plt.savefig('im_expl/' + str(jj).zfill(3) + '.png') 320 | 321 | 322 | # plot 323 | pca = decomposition.PCA(n_components=2) 324 | pca.fit(feats_op, gt_labels) 325 | feats_pca = pca.transform(feats_op) 326 | plot_proj(feats_pca, gt_labels, class_names, 'PCA', 0) 327 | 328 | # PW distance 
329 | sorted_inds = np.argsort(gt_labels) 330 | feats_op_norm = normalize(feats_op) 331 | dist = squareform(pdist(feats_op_norm[sorted_inds, :], 'cosine')) 332 | print(class_names) 333 | 334 | plt.figure(1) 335 | plt.imshow(dist) 336 | plt.show() 337 | 338 | # save 339 | if save_op: 340 | print('saving results ' + op_file_name) 341 | np.savez(base_dir + op_file_name, im_files=imgs_files, explain_files=explain_files, 342 | X=feats_op, Y=gt_labels, Y_pred=pred_labels, interp=cam_op, 343 | class_names=class_names) 344 | else: 345 | print('not saving output') 346 | -------------------------------------------------------------------------------- /code/data_gen/im_folder.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 10 | ] 11 | 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 15 | 16 | 17 | def find_classes(dir): 18 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 19 | classes.sort() 20 | class_to_idx = {classes[i]: i for i in range(len(classes))} 21 | return classes, class_to_idx 22 | 23 | 24 | def make_dataset(dir, class_to_idx): 25 | images = [] 26 | class_labels = [] 27 | dir = os.path.expanduser(dir) 28 | for target in os.listdir(dir): 29 | d = os.path.join(dir, target) 30 | if not os.path.isdir(d): 31 | continue 32 | 33 | for root, _, fnames in sorted(os.walk(d)): 34 | for fname in fnames: 35 | if is_image_file(fname): 36 | path = os.path.join(root, fname) 37 | images.append(path) 38 | class_labels.append(class_to_idx[target]) 39 | 40 | return images, class_labels 41 | 42 | 43 | def pil_loader(path): 44 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 45 | with 
open(path, 'rb') as f: 46 | with Image.open(f) as img: 47 | # if img.mode is 'L': 48 | # return img.convert('L') 49 | # else: 50 | # return img.convert('RGB') 51 | return img.convert('RGB') 52 | 53 | 54 | def accimage_loader(path): 55 | import accimage 56 | try: 57 | return accimage.Image(path) 58 | except IOError: 59 | # Potentially a decoding problem, fall back to PIL.Image 60 | return pil_loader(path) 61 | 62 | 63 | def default_loader(path): 64 | from torchvision import get_image_backend 65 | if get_image_backend() == 'accimage': 66 | return accimage_loader(path) 67 | else: 68 | return pil_loader(path) 69 | 70 | 71 | class ImageFolder(data.Dataset): 72 | """A generic data loader where the images are arranged in this way: :: 73 | 74 | root/dog/xxx.png 75 | root/dog/xxy.png 76 | root/dog/xxz.png 77 | 78 | root/cat/123.png 79 | root/cat/nsdf3.png 80 | root/cat/asd932_.png 81 | 82 | Args: 83 | root (string): Root directory path. 84 | transform (callable, optional): A function/transform that takes in an PIL image 85 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 86 | target_transform (callable, optional): A function/transform that takes in the 87 | target and transforms it. 88 | loader (callable, optional): A function to load an image given its path. 89 | 90 | Attributes: 91 | classes (list): List of the class names. 92 | class_to_idx (dict): Dict with items (class_name, class_index). 
93 | imgs (list): List of (image path, class_index) tuples 94 | """ 95 | 96 | def __init__(self, root, transform=None, target_transform=None, 97 | loader=default_loader): 98 | classes, class_to_idx = find_classes(root) 99 | imgs, class_labels = make_dataset(root, class_to_idx) 100 | if len(imgs) == 0: 101 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 102 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 103 | 104 | self.root = root 105 | self.imgs = imgs 106 | self.class_labels = class_labels 107 | self.classes = classes 108 | self.class_to_idx = class_to_idx 109 | self.transform = transform 110 | self.target_transform = target_transform 111 | self.loader = loader 112 | 113 | def __getitem__(self, index): 114 | """ 115 | Args: 116 | index (int): Index 117 | 118 | Returns: 119 | tuple: (image, target) where target is class_index of the target class. 120 | """ 121 | path = self.imgs[index] 122 | target = self.class_labels[index] 123 | img = self.loader(path) 124 | if self.transform is not None: 125 | img = self.transform(img) 126 | if self.target_transform is not None: 127 | target = self.target_transform(target) 128 | 129 | return img, target 130 | 131 | def __len__(self): 132 | return len(self.imgs) 133 | -------------------------------------------------------------------------------- /code/data_gen/readme.md: -------------------------------------------------------------------------------- 1 | # Teaching Categories to Human Learners with Visual Explanations - Data Pre-processing 2 | This directory contains scripts for pre-processing the datasets and outputting explanations. 3 | Run `generate_data.py` to generate the features and explanations for the teaching experiments. 4 | 5 | ## Requirements 6 | Code was developed using pytorch 0.1.12. 
7 | 8 | ## Reference 9 | If you find our work useful in your research please consider citing our paper: 10 | ``` 11 | @inproceedings{explainteachcvpr18, 12 | title = {Teaching Categories to Human Learners with Visual Explanations}, 13 | author = {Mac Aodha, Oisin and Su, Shihan and Chen, Yuxin and Perona, Pietro and Yue, Yisong}, 14 | booktitle = {CVPR}, 15 | year = {2018} 16 | } 17 | ``` 18 | -------------------------------------------------------------------------------- /code/explain/data_output.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import numpy as np 4 | 5 | 6 | def save_teaching_sequence(teacher, alg_name, op_file_name): 7 | # saves the teaching sequences so they can be used by the webapp 8 | results = {} 9 | num_train = len(teacher.teaching_exs) 10 | results['num_train'] = num_train 11 | 12 | if 'rand' not in alg_name: 13 | if 'strict' in alg_name: 14 | results['im_ids'] = teacher.teaching_exs 15 | results['display_explain_image'] = [0 for ii in range(num_train)] 16 | if 'explain' in alg_name: 17 | results['im_ids'] = teacher.teaching_exs 18 | results['display_explain_image'] = [1 for ii in range(num_train)] 19 | 20 | with open(op_file_name, 'w') as js: 21 | json.dump(results, js) 22 | 23 | 24 | def save_teaching_images(dataset_train, dataset_test, op_file_name, url_root): 25 | 26 | teaching_ims = [] 27 | for ii in range(len(dataset_train['im_files'])): 28 | im_data = {} 29 | im_data['image_url'] = url_root + dataset_train['im_files'][ii] 30 | im_data['explain_url'] = url_root + dataset_train['explain_files'][ii] 31 | im_data['class_label'] = dataset_train['Y'][ii] 32 | teaching_ims.append(im_data) 33 | 34 | # puts test images at the end 35 | for ii in range(len(dataset_test['im_files'])): 36 | im_data = {} 37 | im_data['image_url'] = url_root + dataset_test['im_files'][ii] 38 | im_data['class_label'] = dataset_test['Y'][ii] 39 | teaching_ims.append(im_data) 40 | 41 | with 
open(op_file_name, 'w') as js: 42 | json.dump(teaching_ims, js) 43 | 44 | 45 | def save_settings(dataset_train, dataset_test, experiment_id, num_random_test_ims, scale, op_file_name): 46 | settings = {} 47 | settings['experiment_id'] = experiment_id 48 | settings['train_indices'] = range(len(dataset_train['im_files'])) 49 | settings['test_indices'] = [len(dataset_train['im_files'])+tt for tt in range(len(dataset_test['im_files']))] 50 | settings['test_sequence'] = [-1]*num_random_test_ims 51 | settings['class_names'] = [cc.replace('_', ' ') for cc in dataset_train['class_names']] 52 | settings['scale'] = scale 53 | 54 | with open(op_file_name, 'w') as js: 55 | json.dump(settings, js) 56 | -------------------------------------------------------------------------------- /code/explain/datasets.py: -------------------------------------------------------------------------------- 1 | from sklearn import datasets 2 | from sklearn.datasets import make_blobs 3 | import numpy as np 4 | from sklearn.decomposition import PCA 5 | from sklearn.utils import resample 6 | from sklearn.model_selection import train_test_split 7 | import offline_teachers as teach 8 | import utils as ut 9 | 10 | 11 | def load_datasets(dataset_name, dataset_dir, do_pca, pca_dims, add_bias, remove_mean, density_sigma, interp_sigma): 12 | print dataset_name 13 | 14 | im_files = None 15 | explain_files = None 16 | class_names = None 17 | explain_interp = None # for the explanation 1.0 means easy to interpret and 0.0 means hard 18 | 19 | if dataset_name == 'iris': 20 | iris = datasets.load_iris() 21 | X = iris.data 22 | Y = iris.target 23 | elif dataset_name == 'wine': 24 | wine = datasets.load_wine() 25 | X = wine.data 26 | Y = wine.target 27 | elif dataset_name == 'breast_cancer': 28 | bc = datasets.load_breast_cancer() 29 | X = bc.data 30 | Y = bc.target 31 | elif dataset_name == '2d_outlier': 32 | num_exs = 100 33 | sig = 0.005 34 | pt = 0.3 35 | cls1 = np.random.multivariate_normal([pt, pt], [[sig, 
0],[0,sig]], int(num_exs*0.8)) 36 | cls2 = np.random.multivariate_normal([-pt, -pt], [[sig, 0],[0,sig]], int(num_exs*0.8)) 37 | # add "noise" 38 | cls1n = np.random.multivariate_normal([pt, pt], [[sig*10, 0],[0,sig*10]], int(num_exs*0.2)) 39 | cls2n = np.random.multivariate_normal([-pt, -pt], [[sig*10, 0],[0,sig*10]], int(num_exs*0.2)) 40 | X = np.vstack((cls1, cls1n, cls2, cls2n)) 41 | Y = np.ones(X.shape[0]).astype(np.int) 42 | Y[:int(num_exs*0.8)+int(num_exs*0.2)] = 0 43 | elif dataset_name == '3blobs': 44 | num_exs = 80 45 | cls1 = np.random.multivariate_normal([1.0, -1.0], [[0.12, 0],[0,0.12]], num_exs) 46 | cls2 = np.random.multivariate_normal([-1.0, -1.0], [[0.12, 0],[0,0.12]], num_exs) 47 | cls3 = np.random.multivariate_normal([-1.0, 1.0], [[0.12, 0],[0,0.12]], num_exs) 48 | X = np.vstack((cls1,cls2, cls3)) 49 | Y = np.ones(X.shape[0]).astype(np.int) 50 | Y[:num_exs] = 0 51 | elif dataset_name == 'blobs_2_class': 52 | X, Y = make_blobs(n_samples=200, centers=2, random_state=0) 53 | elif dataset_name == 'blobs_3_class': 54 | X, Y = make_blobs(n_samples=300, centers=3, random_state=0) 55 | else: 56 | X, Y, im_files, explain_files, class_names, explain_interp = load_data(dataset_dir, dataset_name, interp_sigma) 57 | 58 | if im_files is None: 59 | im_files = np.asarray(['']*X.shape[0]) 60 | if explain_files is None: 61 | explain_files = np.asarray(['']*X.shape[0]) 62 | if class_names is None: 63 | class_names = np.asarray(['']*np.unique(Y).shape[0]) 64 | if explain_interp is None: 65 | explain_interp = np.ones(X.shape[0]) 66 | 67 | # standardize 68 | if remove_mean: 69 | X = X - X.mean(0) 70 | X = X / X.std(0) 71 | 72 | # do PCA 73 | if do_pca and X.shape[1] > 2: 74 | pca = PCA(n_components=2) 75 | pca.fit(X) 76 | X = pca.transform(X) 77 | X = X - X.mean(0) 78 | X = X / X.std(0) 79 | 80 | # add 1 for bias (intercept) term 81 | if add_bias: 82 | X = np.hstack((X, np.ones(X.shape[0])[..., np.newaxis])) 83 | 84 | # balance datasets - same number of examples per 
class 85 | X, Y, im_files, explain_files, explain_interp = balance_data(X, Y, im_files, explain_files, explain_interp) 86 | 87 | # train test split 88 | dataset_train, dataset_test = make_train_test_split(X, Y, im_files, explain_files, class_names, explain_interp) 89 | 90 | # density of points 91 | dataset_train['X_density'] = ut.compute_density(dataset_train['X'], dataset_train['Y'], density_sigma, True) 92 | 93 | print 'train split' 94 | print dataset_train['X'].shape[0], 'instances' 95 | print dataset_train['X'].shape[1], 'features' 96 | print np.unique(dataset_train['Y']).shape[0], 'classes' 97 | 98 | return dataset_train, dataset_test 99 | 100 | 101 | def load_data(dataset_dir, dataset_name, interp_sigma): 102 | data = np.load(dataset_dir + dataset_name + '.npz') 103 | X = data['X'] 104 | Y = data['Y'] 105 | im_files = data['im_files'] 106 | explain_files = data['explain_files'] 107 | class_names = data['class_names'].tolist() 108 | 109 | # compute interpretability 110 | if 'interp' not in data.keys(): 111 | # does not exist so set them all the same 112 | explain_interp = np.ones(X.shape[0]) 113 | elif len(data['interp'].shape) == 1: 114 | # already computed 115 | explain_interp = data['interp'] 116 | explain_interp = 1.0 / (1.0 + np.exp(-interp_sigma*(explain_interp+0.0000001))) 117 | else: 118 | # not computed, generate it from explanation images 119 | print 'computing interpretability' 120 | explain_interp = ut.compute_interpretability(data['interp'], data['Y'], data['Y_pred'], interp_sigma) 121 | 122 | return X, Y, im_files, explain_files, class_names, explain_interp 123 | 124 | 125 | def make_train_test_split(X, Y, im_files, explain_files, class_names, explain_interp): 126 | # split_data = [X_train, X_test, Y_trains, ...] 
127 | split_data = train_test_split(X, Y, im_files, explain_files, explain_interp, test_size=0.2, random_state=0) 128 | 129 | datasets = [] 130 | for dd in range(2): 131 | dataset = {} 132 | dataset['X'] = split_data[dd+0] 133 | dataset['Y'] = split_data[dd+2] 134 | dataset['im_files'] = split_data[dd+4] 135 | dataset['explain_files'] = split_data[dd+6] 136 | dataset['explain_interp'] = split_data[dd+8] 137 | dataset['class_names'] = class_names 138 | datasets.append(dataset) 139 | 140 | return datasets[0], datasets[1] 141 | 142 | 143 | def balance_data(X, Y, im_files, explain_files, explain_interp): 144 | # ensure there is an equal number of examples per class 145 | 146 | # shuffle 147 | X,Y,im_files,explain_files,explain_interp = resample(X,Y,im_files,explain_files,explain_interp,replace=False,random_state=0) 148 | min_cnt = X.shape[0] 149 | for cc in np.unique(Y): 150 | if (Y==cc).sum() < min_cnt: 151 | min_cnt = (Y==cc).sum() 152 | 153 | inds = [] 154 | for cc in np.unique(Y): 155 | inds.extend(np.where(Y==cc)[0][:min_cnt]) 156 | 157 | X = X[inds, :] 158 | Y = Y[inds] 159 | im_files = im_files[inds] 160 | explain_files = explain_files[inds] 161 | explain_interp = explain_interp[inds] 162 | 163 | return X, Y, im_files, explain_files, explain_interp 164 | 165 | 166 | def remove_exs(dataset, hyps, err_hyp, alpha, split_name, one_v_all): 167 | # only keep examples that we can predict with the best hypothesis 168 | if one_v_all: 169 | if np.unique(dataset['Y'].shape[0]) == 2: 170 | # binary 171 | optimal_index = np.argmin(err_hyp[0]) 172 | _, pred_class = teach.user_model_binary(hyps[optimal_index], dataset['X'], dataset['Y'], alpha) 173 | inds = np.where(dataset['Y'] == pred_class)[0] 174 | else: 175 | # multi class 176 | correctly_predicted = np.zeros(dataset['Y'].shape[0]) 177 | for cc in range(len(err_hyp)): 178 | optimal_index = np.argmin(err_hyp[cc]) 179 | Y_bin = np.zeros(dataset['Y'].shape[0]).astype(np.int) 180 | Y_bin[np.where(dataset['Y']==cc)[0]] = 1 181 
| _, pred_class = teach.user_model_binary(hyps[optimal_index], dataset['X'], Y_bin, alpha) 182 | correctly_predicted[np.where(Y_bin == pred_class)[0]] += 1 183 | inds = np.where(correctly_predicted == len(err_hyp))[0] 184 | else: 185 | optimal_index = np.argmin(err_hyp) 186 | _, pred_class = teach.user_model(hyps[optimal_index], dataset['X'], dataset['Y'], alpha) 187 | inds = np.where(dataset['Y'] == pred_class)[0] 188 | print dataset['X'].shape[0] - inds.shape[0], split_name, 'examples removed' 189 | 190 | # remove the examples 191 | dataset['X'] = dataset['X'][inds, :] 192 | dataset['Y'] = dataset['Y'][inds] 193 | dataset['im_files'] = dataset['im_files'][inds] 194 | dataset['explain_files'] = dataset['explain_files'][inds] 195 | dataset['explain_interp'] = dataset['explain_interp'][inds] 196 | cls_un, cls_cnt = np.unique(dataset['Y'], return_counts=True) 197 | if 'X_density' in dataset.keys(): 198 | dataset['X_density'] = dataset['X_density'][inds] 199 | 200 | print '\n', split_name 201 | for cc in range(len(cls_cnt)): 202 | print cls_un[cc], dataset['class_names'][cls_un[cc]].ljust(30), '\t', cls_cnt[cc] 203 | 204 | return dataset 205 | 206 | -------------------------------------------------------------------------------- /code/explain/hypothesis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.cluster import KMeans 3 | import random 4 | from sklearn import svm 5 | import offline_teachers as teach 6 | import itertools 7 | 8 | 9 | def compute_hyps_error_one_vs_all(hyps, X, Y, alpha): 10 | # compute err(h, h*) - list of length C with vectors of H 11 | 12 | err_hyps = [] 13 | if np.unique(Y).shape[0] == 2: 14 | # if only two classes don't need to do both 15 | err_hyps.append(compute_hyps_error(hyps, X, Y, alpha, True)) 16 | else: 17 | # multi class 18 | for cc in np.unique(Y): 19 | Y_bin = np.zeros(Y.shape[0]).astype(np.int) 20 | Y_bin[np.where(Y==cc)[0]] = 1 21 | err = 
compute_hyps_error(hyps, X, Y_bin, alpha, True) 22 | err_hyps.append(err) 23 | 24 | return err_hyps 25 | 26 | 27 | def compute_hyps_error(hyps, X, Y, alpha, one_v_all=False): 28 | # compute err(h, h*) - vector of length H 29 | err_hyp = np.zeros(len(hyps)) 30 | for hh in range(len(hyps)): 31 | if one_v_all: 32 | _, pred_class = teach.user_model_binary(hyps[hh], X, Y, alpha) 33 | else: 34 | _, pred_class = teach.user_model(hyps[hh], X, Y, alpha) 35 | err_hyp[hh] = (Y != pred_class).sum() / float(Y.shape[0]) 36 | 37 | return err_hyp 38 | 39 | 40 | def cluster_hyps(X, Y, num_hyps, alpha, clf): 41 | # this is for the multi class hypothesis case - number of combinations explodes 42 | 43 | # generate hypotheses by clustering data 44 | num_classes = np.unique(Y).shape[0] 45 | 46 | # all possible combinations of two classes 47 | class_combs = [] 48 | for ii in range(np.minimum(num_classes-1,2)): 49 | class_combs.extend(itertools.combinations(range(num_classes), ii + 1)) 50 | 51 | classifiers = [] 52 | for cc in class_combs: 53 | inds = [] 54 | for ii in cc: 55 | inds.append(np.where(Y==ii)[0]) 56 | inds = np.hstack(inds) 57 | 58 | # fit SVM 59 | tmp_labels = np.zeros(Y.shape[0]).astype(np.int) 60 | tmp_labels[inds] = 1 61 | clf.fit(X, tmp_labels) 62 | classifiers.append(clf.coef_[0,:].copy()) 63 | classifiers = np.vstack(classifiers) 64 | 65 | # cant use all possible combinations, just use a subset 66 | all_combs = list(itertools.permutations(range(len(classifiers)), num_classes)) 67 | subset_combs = random.sample(range(len(all_combs)), np.minimum(num_hyps,len(all_combs))) 68 | print len(all_combs), 'possible combinations of PW classifers' 69 | 70 | # copy hypothesis 71 | hyps = [] 72 | for ss in subset_combs: 73 | inds = all_combs[ss] 74 | hyps.append(classifiers[inds, :].copy()) 75 | 76 | # add teacher i.e. 
best hypothesis - trained on all data 77 | clf.fit(X, Y) 78 | hyps.append(clf.coef_.copy()) 79 | 80 | return hyps 81 | 82 | 83 | def cluster_hyps_one_v_all(X, Y, alpha, clf): 84 | # generate hypotheses by clustering data for 1 versus all 85 | 86 | hyps = [] 87 | num_classes = np.unique(Y).shape[0] 88 | clusters_per_class = 2 89 | 90 | # 1) sub classes against the rest 91 | cinds = 0 92 | Y_hal = np.zeros(Y.shape[0]).astype(np.int) 93 | for cc in np.unique(Y): 94 | inds = np.where(Y==cc)[0] 95 | kmeans = KMeans(n_clusters=clusters_per_class).fit(X[inds, :]) 96 | Y_hal[inds] = kmeans.labels_.copy()+cinds 97 | cinds += clusters_per_class 98 | 99 | for cc in np.unique(Y_hal): 100 | inds = np.where(Y_hal==cc)[0] 101 | tmp_labels = np.zeros(Y_hal.shape[0]).astype(np.int) 102 | tmp_labels[inds] = 1 103 | clf.fit(X, tmp_labels) 104 | hyps.append(clf.coef_[0,:].copy()) 105 | 106 | # 2) each class against the rest - GT 107 | for cc in np.unique(Y): 108 | inds = np.where(Y==cc)[0] 109 | tmp_labels = np.zeros(Y.shape[0]).astype(np.int) 110 | tmp_labels[inds] = 1 111 | clf.fit(X, tmp_labels) 112 | hyps.append(clf.coef_[0,:].copy()) 113 | 114 | # 3) pairs of classes against the rest 115 | combs = list(itertools.combinations(range(num_classes), 2)) 116 | for cc in combs: 117 | inds = [] 118 | for cur_class in cc: 119 | inds.append(np.where(Y==cur_class)) 120 | tmp_labels = np.zeros(Y.shape[0]).astype(np.int) 121 | tmp_labels[np.hstack(inds)] = 1 122 | clf.fit(X, tmp_labels) 123 | hyps.append(clf.coef_[0,:].copy()) 124 | 125 | return hyps 126 | 127 | 128 | def sparse_hyps(X, Y, num_hyps, one_v_all, clf, fit_gt=False): 129 | # sparse hypotheses with small number of -1 or 1 entries 130 | num_non_zero = 2 131 | num_classes = np.unique(Y).shape[0] 132 | hyps = [] 133 | for hh in range(num_hyps): 134 | if one_v_all or (num_classes == 2): 135 | w = np.zeros((X.shape[1])) 136 | inds = random.sample(range(X.shape[1]), num_non_zero) 137 | w[inds] = np.random.choice([-1,1], num_non_zero) 
138 | else: 139 | w = np.zeros((num_classes, X.shape[1])) 140 | for cc in range(num_classes): 141 | inds = random.sample(range(X.shape[1]), num_non_zero) 142 | w[cc, inds] = np.random.choice([-1,1], num_non_zero) 143 | hyps.append(w) 144 | 145 | # add GT 146 | if fit_gt: 147 | clf.fit(X, Y) 148 | if one_v_all: 149 | for cc in range(num_classes): 150 | hyps.append(clf.coef_[cc, :].copy()) 151 | else: 152 | hyps.append(clf.coef_.copy()) 153 | 154 | return hyps 155 | 156 | 157 | def random_hyps(X, Y, num_hyps, alpha, one_v_all, clf, fit_gt=False): 158 | # generate random set of hypotheses 159 | num_classes = np.unique(Y).shape[0] 160 | hyps = [] 161 | for hh in range(num_hyps): 162 | if one_v_all: 163 | hyp = np.random.randn(X.shape[1]) 164 | elif num_classes == 2: 165 | hh = np.random.randn(X.shape[1]) 166 | hyp = np.vstack((-hh, hh)) 167 | else: 168 | hyp = np.random.randn(num_classes, X.shape[1]) 169 | hyps.append(hyp) 170 | 171 | # add GT 172 | if fit_gt: 173 | clf.fit(X, Y) 174 | if one_v_all: 175 | for cc in range(num_classes): 176 | hyps.append(clf.coef_[cc, :].copy()) 177 | elif num_classes == 2: 178 | hyps.append(np.vstack((-clf.coef_.copy(), clf.coef_.copy()))) 179 | else: 180 | hyps.append(clf.coef_.copy()) 181 | 182 | return hyps 183 | 184 | 185 | def generate_hyps(dataset, alpha, num_hyps, hyp_type, one_v_all): 186 | # generates the hypothesis space 187 | # if one_v_all is True we create D dim hypothesis otherwise we do CxD 188 | X = dataset['X'] 189 | Y = dataset['Y'] 190 | num_classes = np.unique(Y).shape[0] 191 | clf = svm.LinearSVC(fit_intercept=False, penalty='l1', loss='squared_hinge', dual=False) 192 | 193 | # if only 2D we will add negative versions of hyps later for visualization 194 | # <= 3 as we might have bias term 195 | if X.shape[1] <= 3 and num_classes == 2: 196 | print '2D dataset -> generating less hypotheses' 197 | num_hyps /= 2 198 | 199 | if hyp_type == 'cluster': 200 | if one_v_all: 201 | hyps = cluster_hyps_one_v_all(X, Y, alpha, 
clf) 202 | else: 203 | hyps = cluster_hyps(X, Y, num_hyps, alpha, clf) 204 | 205 | if hyp_type == 'cluster_rand': 206 | if one_v_all: 207 | hyps = cluster_hyps_one_v_all(X, Y, alpha, clf) 208 | else: 209 | hyps = cluster_hyps(X, Y, num_hyps, alpha, clf) 210 | 211 | if num_hyps - len(hyps) > 0: 212 | hyps.extend(random_hyps(X, Y, num_hyps - len(hyps), alpha, one_v_all, clf, False)) 213 | 214 | elif hyp_type == 'rand': 215 | hyps = random_hyps(X, Y, num_hyps, alpha, one_v_all, clf, True) 216 | elif hyp_type == 'sparse': 217 | hyps = sparse_hyps(X, Y, num_hyps, one_v_all, clf, True) 218 | 219 | # if 2D data add negative versions of hypothesis - makes visualization easier 220 | if X.shape[1] <= 3 and num_classes == 2: 221 | hyps_opposite = [] 222 | for hh in range(len(hyps)): 223 | hyps_opposite.append(hyps[hh].copy()*-1) 224 | hyps.extend(hyps_opposite) 225 | 226 | # create prior 227 | prior_h = np.ones(len(hyps)) / float(len(hyps)) 228 | 229 | return hyps, prior_h 230 | -------------------------------------------------------------------------------- /code/explain/main.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import os 4 | import offline_teachers as teach 5 | import datasets as ds 6 | import utils as ut 7 | import data_output as op 8 | import hypothesis as hp 9 | 10 | 11 | plt.close('all') 12 | dataset_root = '../../data/' 13 | datasets = ['blobs_2_class', '2d_outlier', 'blobs_3_class', '3blobs', 14 | 'iris', 'breast_cancer', 'wine', 15 | 'oct', 'butterflies_crop', 'chinese_chars', 'chinese_chars_crowd'] 16 | dataset_name = datasets[7] 17 | 18 | experiment_id = 0 19 | num_teaching_itrs = 20 20 | num_random_test_ims = 20 21 | num_init_hyps = 100 22 | density_sigma = 1.0 23 | interp_sigma = 1.0 24 | alpha = 0.5 25 | image_scale = 2.0 26 | hyp_type = 'cluster_rand' # rand, cluster, cluster_rand, sparse 27 | dataset_dir = dataset_root + dataset_name + '/' 28 | 
url_root = '' # set this to the location of the images on the web 29 | 30 | save_ops = False 31 | add_bias = True 32 | remove_mean = True 33 | do_pca = False 34 | pca_dims = 2 35 | 36 | 37 | op_dir = 'output/' + str(experiment_id) +'/' 38 | if save_ops: 39 | print 'saving output to', op_dir 40 | if not os.path.isdir(op_dir): 41 | os.makedirs(op_dir) 42 | 43 | # load data 44 | dataset_train, dataset_test = ds.load_datasets(dataset_name, dataset_dir, do_pca, pca_dims, add_bias, remove_mean, density_sigma, interp_sigma) 45 | if len(np.unique(dataset_train['Y'])) > 2: 46 | one_v_all = True # multi class 47 | else: 48 | one_v_all = False # binary 49 | 50 | # generate set of hypotheses 51 | hyps, prior_h = hp.generate_hyps(dataset_train, alpha, num_init_hyps, hyp_type, one_v_all) 52 | print len(hyps), hyp_type, 'hypotheses\n' 53 | 54 | # remove examples that are inconsistent with best hypothesis 55 | if one_v_all: 56 | err_hyp = hp.compute_hyps_error_one_vs_all(hyps, dataset_train['X'], dataset_train['Y'], alpha) 57 | else: 58 | err_hyp = hp.compute_hyps_error(hyps, dataset_train['X'], dataset_train['Y'], alpha) 59 | dataset_train = ds.remove_exs(dataset_train, hyps, err_hyp, alpha, 'train', one_v_all) 60 | 61 | # re compute hypothesis errors - after removing inconsistent examples 62 | if one_v_all: 63 | err_hyp = hp.compute_hyps_error_one_vs_all(hyps, dataset_train['X'], dataset_train['Y'], alpha) 64 | err_hyp_test = hp.compute_hyps_error_one_vs_all(hyps, dataset_test['X'], dataset_test['Y'], alpha) 65 | else: 66 | err_hyp = hp.compute_hyps_error(hyps, dataset_train['X'], dataset_train['Y'], alpha) 67 | err_hyp_test = hp.compute_hyps_error(hyps, dataset_test['X'], dataset_test['Y'], alpha) 68 | 69 | # compute the likelihood for each datapoint according to each hypothesis 70 | if one_v_all: 71 | likelihood = ut.compute_likelihood_one_vs_all(hyps, dataset_train['X'], dataset_train['Y'], alpha) 72 | else: 73 | likelihood = ut.compute_likelihood(hyps, dataset_train['X'], 
dataset_train['Y'], alpha) 74 | 75 | # teachers 76 | teachers = {} 77 | if one_v_all: 78 | teachers['rand_1vall'] = teach.RandomImageTeacherOneVsAll(dataset_train, alpha, prior_h) 79 | teachers['strict_1vall'] = teach.StrictTeacherOneVsAll(dataset_train, alpha, prior_h) 80 | teachers['explain_1vall'] = teach.ExplainTeacherOneVsAll(dataset_train, alpha, prior_h) 81 | else: 82 | teachers['random'] = teach.RandomImageTeacher(dataset_train, alpha, prior_h) 83 | teachers['strict'] = teach.StrictTeacher(dataset_train, alpha, prior_h) 84 | teachers['explain'] = teach.ExplainTeacher(dataset_train, alpha, prior_h) 85 | 86 | # run teaching 87 | for alg_name in teachers.keys(): 88 | print alg_name 89 | teachers[alg_name].run_teaching(num_teaching_itrs, dataset_train, likelihood, hyps, err_hyp, err_hyp_test) 90 | 91 | # plot in 2D 92 | fig_id = 0 93 | if (dataset_train['X'].shape[1] <= 3): 94 | for alg_name in teachers.keys(): 95 | if one_v_all: 96 | ut.plot_2D_data(dataset_train['X'], dataset_train['Y'], alpha, hyps, teachers[alg_name].teaching_exs, teachers[alg_name].posterior(), alg_name, fig_id, one_v_all, np.argmin(err_hyp)) 97 | else: 98 | ut.plot_2D_data_hyper(dataset_train['X'], dataset_train['Y'], alpha, hyps, teachers[alg_name].teaching_exs, teachers[alg_name].posterior(), alg_name, fig_id, one_v_all, np.argmin(err_hyp)) 99 | fig_id += 1 100 | 101 | plt.figure(fig_id) 102 | plt.title('learners expected error - train') 103 | for alg_name in teachers.keys(): 104 | exp_err = teachers[alg_name].exp_err 105 | plt.plot(np.arange(len(exp_err))+1, exp_err, label=alg_name) 106 | plt.legend() 107 | if save_ops: 108 | plt.savefig(op_dir + 'eer.pdf') 109 | 110 | plt.figure(fig_id+1) 111 | plt.title('learners expected error - test') 112 | for alg_name in teachers.keys(): 113 | exp_err_test = teachers[alg_name].exp_err_test 114 | plt.plot(np.arange(len(exp_err_test))+1, exp_err_test, label=alg_name) 115 | plt.legend() 116 | if save_ops: 117 | plt.savefig(op_dir + 'eer_test.pdf') 
118 | 119 | plt.figure(fig_id+2) 120 | plt.title('example difficulty') 121 | for alg_name in teachers.keys(): 122 | difficulty = teachers[alg_name].difficulty 123 | plt.plot(np.arange(len(difficulty))+1, difficulty, label=alg_name) 124 | plt.legend() 125 | plt.show() 126 | 127 | # save strategy files 128 | if not save_ops: 129 | print '\nnot saving outputs' 130 | else: 131 | print '\nsaving outputs' 132 | for alg_name in teachers.keys(): 133 | op.save_teaching_sequence(teachers[alg_name], alg_name, op_dir + alg_name + '.strat') 134 | 135 | op.save_teaching_images(dataset_train, dataset_test, op_dir + 'teaching_images.json', url_root) 136 | op.save_settings(dataset_train, dataset_test, experiment_id, num_random_test_ims, image_scale, op_dir + 'settings.json') 137 | np.savez(op_dir + 'params.npz', dataset_train=dataset_train, dataset_test=dataset_test, hyps=hyps, teachers=teachers) 138 | -------------------------------------------------------------------------------- /code/explain/offline_teachers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.stats import entropy 3 | 4 | 5 | def user_model_binary(w, x, y, alpha): 6 | # binary user model - w is D and X is NxD 7 | # prob is probability that the hyp agrees with the datapoint 8 | # need to make prob = 1.0 / (1.0 + np.exp(-z*2*(2*y-1))) to be same as softmax 9 | if len(w.shape) == 2: 10 | z = alpha*np.dot(x, w[1,:]) 11 | else: 12 | z = alpha*np.dot(x, w) 13 | pred_class = (z>0).astype(np.int) # will be 0 or 1 14 | prob = 1.0 / (1.0 + np.exp(-z*(2*y-1))) # make y={-1,1} 15 | return prob, pred_class 16 | 17 | 18 | def user_model(w, x, y, alpha): 19 | # multi-class user model - w is CxD and X is NxD 20 | # prob is probability that the hyp agrees with the datapoint 21 | z = alpha*np.dot(x, w.T) 22 | pred_class = np.argmax(z,1) 23 | z_norm = np.exp(z) / np.exp(z).sum(1)[..., np.newaxis] 24 | 25 | prob = z_norm[np.arange(x.shape[0]), pred_class] # 
pred_class == y 26 | inds = np.where(pred_class != y)[0] 27 | prob[inds] = 1.0 - prob[inds] # pred_class != y 28 | return prob, pred_class 29 | 30 | 31 | def teaching_stats(cur_post, pred, err_hyp, err_hyp_test): 32 | 33 | cur_post_norm = cur_post/cur_post.sum() 34 | exp_err = (cur_post_norm*err_hyp).sum() 35 | exp_err_test = (cur_post_norm*err_hyp_test).sum() 36 | ent = entropy(cur_post_norm) 37 | 38 | z = (cur_post_norm*pred).sum() + 0.0000000001 # add small noise 39 | difficulty = -(z*np.log2(z) + (1-z)*np.log2(1-z)) 40 | 41 | return exp_err, exp_err_test, ent, difficulty 42 | 43 | 44 | def teaching_stats_one_vs_all(cur_post, pred, err_hyp, err_hyp_test): 45 | 46 | exp_err = np.empty(cur_post.shape) 47 | exp_err_test = np.empty(cur_post.shape) 48 | entropy = np.empty(cur_post.shape) 49 | difficulty = np.empty(cur_post.shape) 50 | for cc in range(cur_post.shape[0]): 51 | exp_err[cc, :], exp_err_test[cc, :], entropy[cc, :], difficulty[cc, :] = teaching_stats(cur_post[cc, :], pred[cc, :], err_hyp[cc], err_hyp_test[cc]) 52 | 53 | return exp_err.mean(), exp_err_test.mean(), entropy.mean(), difficulty.mean() 54 | 55 | 56 | class StrictTeacher: 57 | # Singla et al. 
Near-Optimally Teaching the Crowd to Classify 58 | # https://arxiv.org/abs/1402.2092 59 | 60 | def __init__(self, dataset, alpha, prior_h): 61 | self.initialize(dataset['X'], dataset['Y'], alpha, prior_h) 62 | 63 | def initialize(self, X, Y, alpha, prior_h): 64 | self.teaching_exs = [] 65 | self.unseen_exs = np.arange(X.shape[0]) 66 | self.prior_h = prior_h 67 | self.cur_post = prior_h.copy() 68 | self.alpha = alpha 69 | self.exp_err = [] 70 | self.exp_err_test = [] 71 | self.hyp_entropy = [] 72 | self.difficulty = [] 73 | 74 | def posterior(self): 75 | return self.cur_post/self.cur_post.sum() 76 | 77 | def run_teaching(self, num_teaching_itrs, dataset, likelihood, hyps, err_hyp, err_hyp_test): 78 | for tt in range(num_teaching_itrs): 79 | self.teaching_iteration(dataset['X'], dataset['Y'], likelihood, hyps, err_hyp, err_hyp_test) 80 | 81 | def teaching_iteration(self, X, Y, likelihood, hyps, err_hyp, err_hyp_test): 82 | 83 | # this is eqivalent to looping over h and x 84 | # comes from separating P(h|(A U x)) into P(h|A)P(h|x) 85 | err = -np.dot(self.cur_post*err_hyp, likelihood) 86 | selected_ind = self.unseen_exs[np.argmax(err[self.unseen_exs])] 87 | 88 | # update the posterior with the selected example 89 | self.cur_post *= likelihood[:, selected_ind] 90 | 91 | # get predictions for each hyp for selected example 92 | pred = np.zeros(len(hyps)) 93 | for hh in range(len(hyps)): 94 | pred[hh], _ = user_model(hyps[hh], X[selected_ind,:][np.newaxis, ...], Y[selected_ind], self.alpha) 95 | 96 | # bookkeeping and compute stats 97 | print len(self.teaching_exs), '\t', Y[selected_ind], '\t', selected_ind, '\t', round(err[self.unseen_exs].max(),4) 98 | ee, ee_test, ent, diff = teaching_stats(self.cur_post, pred, err_hyp, err_hyp_test) 99 | self.exp_err.append(ee) 100 | self.exp_err_test.append(ee_test) 101 | self.hyp_entropy.append(ent) 102 | self.difficulty.append(diff) 103 | self.teaching_exs.append(selected_ind) 104 | self.unseen_exs = 
np.setdiff1d(np.arange(X.shape[0]), self.teaching_exs) 105 | 106 | def teaching_iteration_slow(self, X, Y, hyps, err_hyp): 107 | eer = np.zeros(self.unseen_exs.shape[0]) 108 | for ii, ex in enumerate(self.unseen_exs): 109 | cur_post_delta = np.ones(len(hyps)) 110 | for hh in range(len(hyps)): 111 | 112 | # can store a H*X matrix where it will be 1 where hyp gets it correct and y_p else where 113 | y_p, pred_class = user_model(hyps[hh], X[ex,:][np.newaxis, ...], Y[ex], self.alpha) 114 | if pred_class != Y[ex]: 115 | cur_post_delta[hh] *= y_p 116 | eer[ii] += (self.prior_h[hh] - (self.cur_post[hh]*cur_post_delta[hh]))*err_hyp[hh] 117 | #eer[ii] += -(self.cur_post[hh]*cur_post_delta[hh])*err_hyp[hh] # dont need to subtract prior 118 | 119 | # recompute the posterior of the selected example 120 | selected_ind = self.unseen_exs[np.argmax(eer)] 121 | pred = np.zeros(len(hyps)) 122 | for hh in range(len(hyps)): 123 | pred[hh], pred_class = user_model(hyps[hh], X[selected_ind,:][np.newaxis, ...], Y[selected_ind], self.alpha) 124 | if pred_class != Y[selected_ind]: 125 | self.cur_post[hh] *= pred[hh] 126 | 127 | # bookkeeping and compute stats 128 | print len(self.teaching_exs), '\t', selected_ind, '\t', round(eer.max(),4) 129 | ee, ee_test, ent, diff = teaching_stats(self.cur_post, pred, err_hyp, err_hyp_test) 130 | self.exp_err.append(ee) 131 | self.exp_err_test.append(ee_test) 132 | self.hyp_entropy.append(ent) 133 | self.difficulty.append(diff) 134 | self.teaching_exs.append(selected_ind) 135 | self.unseen_exs = np.setdiff1d(np.arange(X.shape[0]), self.teaching_exs) 136 | 137 | 138 | class StrictTeacherOneVsAll: 139 | # 1 vs all version of 140 | # Singla et al. 
Near-Optimally Teaching the Crowd to Classify 141 | # https://arxiv.org/abs/1402.2092 142 | 143 | def __init__(self, dataset, alpha, prior_h): 144 | self.initialize(dataset['X'], dataset['Y'], alpha, prior_h) 145 | 146 | def initialize(self, X, Y, alpha, prior_h): 147 | self.teaching_exs = [] 148 | self.unseen_exs = np.arange(X.shape[0]) 149 | self.classes = np.unique(Y) # TODO need to update this for binary 150 | self.prior_h = np.tile(prior_h, (len(self.classes), 1)) 151 | self.cur_post = np.tile(prior_h.copy(), (len(self.classes), 1)) 152 | self.alpha = alpha 153 | self.exp_err = [] 154 | self.exp_err_test = [] 155 | self.hyp_entropy = [] 156 | self.difficulty = [] 157 | 158 | def posterior(self): 159 | return self.cur_post/self.cur_post.sum(1)[..., np.newaxis] 160 | 161 | def run_teaching(self, num_teaching_itrs, dataset, likelihood, hyps, err_hyp, err_hyp_test): 162 | for tt in range(num_teaching_itrs): 163 | self.teaching_iteration(dataset['X'], dataset['Y'], likelihood, hyps, err_hyp, err_hyp_test) 164 | 165 | def teaching_iteration(self, X, Y, likelihood, hyps, err_hyp, err_hyp_test): 166 | 167 | err = np.empty((len(self.classes), X.shape[0])) 168 | for cc in self.classes: 169 | # this is eqivalent to looping over h and x 170 | # comes from separating P(h|(A U x)) into P(h|A)P(h|x) 171 | err[cc, :] = -np.dot(self.cur_post[cc, :]*err_hyp[cc], likelihood[cc]) 172 | 173 | if len(self.classes) > 2: 174 | err = err.sum(0) # could try other methods for combining, min, max, ... 
175 | selected_ind = self.unseen_exs[np.argmax(err[self.unseen_exs])] 176 | 177 | # update the posterior with the selected example 178 | for cc in self.classes: 179 | self.cur_post[cc, :] *= likelihood[cc][:, selected_ind] 180 | 181 | # get predictions for each hyp for selected example 182 | pred = np.zeros((len(self.classes), len(hyps))) 183 | for cc in self.classes: 184 | Y_bin = int(Y[selected_ind] == cc) 185 | for hh in range(len(hyps)): 186 | pred[cc, hh], _ = user_model_binary(hyps[hh], X[selected_ind,:][np.newaxis, ...], Y_bin, self.alpha) 187 | 188 | # bookkeeping and compute stats 189 | print len(self.teaching_exs), '\t', Y[selected_ind], '\t', selected_ind, '\t', round(err[self.unseen_exs].max(),4) 190 | ee, ee_test, ent, diff = teaching_stats_one_vs_all(self.cur_post, pred, err_hyp, err_hyp_test) 191 | self.exp_err.append(ee) 192 | self.exp_err_test.append(ee_test) 193 | self.hyp_entropy.append(ent) 194 | self.difficulty.append(diff) 195 | self.teaching_exs.append(selected_ind) 196 | self.unseen_exs = np.setdiff1d(np.arange(X.shape[0]), self.teaching_exs) 197 | 198 | 199 | class ExplainTeacher: 200 | def __init__(self, dataset, alpha, prior_h): 201 | self.initialize(dataset['X'], dataset['Y'], alpha, prior_h) 202 | 203 | def initialize(self, X, Y, alpha, prior_h): 204 | self.teaching_exs = [] 205 | self.unseen_exs = np.arange(X.shape[0]) 206 | self.prior_h = prior_h 207 | self.cur_post = prior_h.copy() 208 | self.alpha = alpha 209 | self.exp_err = [] 210 | self.exp_err_test = [] 211 | self.hyp_entropy = [] 212 | self.difficulty = [] 213 | 214 | def posterior(self): 215 | return self.cur_post/self.cur_post.sum() 216 | 217 | def run_teaching(self, num_teaching_itrs, dataset, likelihood, hyps, err_hyp, err_hyp_test): 218 | for tt in range(num_teaching_itrs): 219 | self.teaching_iteration(dataset['X'], dataset['Y'], dataset['X_density'], dataset['explain_interp'], likelihood, hyps, err_hyp, err_hyp_test) 220 | 221 | def teaching_iteration(self, X, Y, 
X_density, interpretability, likelihood, hyps, err_hyp, err_hyp_test): 222 | # X_density is how representative points are - dont want to select outliers 223 | # interpretability is how easy it is for user to make sense of explanation 224 | 225 | # this is eqivalent to looping over h and x 226 | # comes from separating P(h|(A U x)) into P(h|A)P(h|x) 227 | 228 | # err is negative, we want to find max. To increase it we multiply by smaller numbers 229 | # this has the effect of discounting less the relevant ones 230 | err = -np.dot(self.cur_post*err_hyp, likelihood) 231 | err = err*X_density*interpretability 232 | selected_ind = self.unseen_exs[np.argmax(err[self.unseen_exs])] 233 | 234 | # update the posterior with the selected example 235 | self.cur_post *= likelihood[:, selected_ind]*X_density[selected_ind]*interpretability[selected_ind] 236 | #self.cur_post = self.cur_post / self.cur_post.sum() # don't need to renormalize 237 | 238 | # get predictions for each hyp for selected example 239 | pred = np.zeros(len(hyps)) 240 | for hh in range(len(hyps)): 241 | pred[hh], _ = user_model(hyps[hh], X[selected_ind,:][np.newaxis, ...], Y[selected_ind], self.alpha) 242 | 243 | # bookkeeping and compute stats 244 | print len(self.teaching_exs), '\t', Y[selected_ind], '\t', selected_ind, '\t', round(err[self.unseen_exs].max(),4) 245 | ee, ee_test, ent, diff = teaching_stats(self.cur_post, pred, err_hyp, err_hyp_test) 246 | self.exp_err.append(ee) 247 | self.exp_err_test.append(ee_test) 248 | self.hyp_entropy.append(ent) 249 | self.difficulty.append(diff) 250 | self.teaching_exs.append(selected_ind) 251 | self.unseen_exs = np.setdiff1d(np.arange(X.shape[0]), self.teaching_exs) 252 | 253 | 254 | class ExplainTeacherOneVsAll: 255 | # 1 vs all version 256 | 257 | def __init__(self, dataset, alpha, prior_h): 258 | self.initialize(dataset['X'], dataset['Y'], alpha, prior_h) 259 | 260 | def initialize(self, X, Y, alpha, prior_h): 261 | self.teaching_exs = [] 262 | self.unseen_exs = 
np.arange(X.shape[0]) 263 | self.classes = np.unique(Y) 264 | self.prior_h = np.tile(prior_h, (len(self.classes), 1)) 265 | self.cur_post = np.tile(prior_h.copy(), (len(self.classes), 1)) 266 | self.alpha = alpha 267 | self.exp_err = [] 268 | self.exp_err_test = [] 269 | self.hyp_entropy = [] 270 | self.difficulty = [] 271 | 272 | def posterior(self): 273 | return self.cur_post/self.cur_post.sum(1)[..., np.newaxis] 274 | 275 | def run_teaching(self, num_teaching_itrs, dataset, likelihood, hyps, err_hyp, err_hyp_test): 276 | for tt in range(num_teaching_itrs): 277 | self.teaching_iteration(dataset['X'], dataset['Y'], dataset['X_density'], dataset['explain_interp'], likelihood, hyps, err_hyp, err_hyp_test) 278 | 279 | def teaching_iteration(self, X, Y, X_density, interpretability, likelihood, hyps, err_hyp, err_hyp_test): 280 | # X_density is how representative points are - dont want to select outliers 281 | # interpretability is how easy it is for user to make sense of explanation 282 | 283 | err = np.empty((len(self.classes), X.shape[0])) 284 | for cc in self.classes: 285 | # this is eqivalent to looping over h and x 286 | # comes from separating P(h|(A U x)) into P(h|A)P(h|x) 287 | err[cc, :] = -np.dot(self.cur_post[cc, :]*err_hyp[cc], likelihood[cc]) 288 | # TODO should interpretability be per class or just for GT? 289 | err[cc, :] = err[cc, :]*X_density*interpretability 290 | 291 | if len(self.classes) > 2: 292 | err = err.sum(0) # could try other methods for combining, min, max, ... 
293 | selected_ind = self.unseen_exs[np.argmax(err[self.unseen_exs])] 294 | 295 | # update the posterior with the selected example 296 | for cc in self.classes: 297 | #self.cur_post[cc, :] *= likelihood[cc][:, selected_ind] 298 | self.cur_post[cc, :] *= likelihood[cc][:, selected_ind]*X_density[selected_ind]*interpretability[selected_ind] 299 | 300 | # get predictions for each hyp for selected example 301 | pred = np.zeros((len(self.classes), len(hyps))) 302 | for cc in self.classes: 303 | Y_bin = int(Y[selected_ind] == cc) 304 | for hh in range(len(hyps)): 305 | pred[cc, hh], _ = user_model_binary(hyps[hh], X[selected_ind,:][np.newaxis, ...], Y_bin, self.alpha) 306 | 307 | # bookkeeping and compute stats 308 | print len(self.teaching_exs), '\t', Y[selected_ind], '\t', selected_ind, '\t', round(err[self.unseen_exs].max(),4) 309 | ee, ee_test, ent, diff = teaching_stats_one_vs_all(self.cur_post, pred, err_hyp, err_hyp_test) 310 | self.exp_err.append(ee) 311 | self.exp_err_test.append(ee_test) 312 | self.hyp_entropy.append(ent) 313 | self.difficulty.append(diff) 314 | self.teaching_exs.append(selected_ind) 315 | self.unseen_exs = np.setdiff1d(np.arange(X.shape[0]), self.teaching_exs) 316 | 317 | 318 | class RandomImageTeacher: 319 | # assumes CxD hypotheses 320 | 321 | def __init__(self, dataset, alpha, prior_h): 322 | self.initialize(alpha, prior_h) 323 | 324 | def initialize(self, alpha, prior_h): 325 | self.teaching_exs = [] 326 | self.alpha = alpha 327 | self.cur_post = prior_h.copy() 328 | self.exp_err = [] 329 | self.exp_err_test = [] 330 | self.hyp_entropy = [] 331 | self.difficulty = [] 332 | 333 | def posterior(self): 334 | return self.cur_post/self.cur_post.sum() 335 | 336 | def run_teaching(self, num_teaching_itrs, dataset, likelihood, hyps, err_hyp, err_hyp_test): 337 | X = dataset['X'] 338 | Y = dataset['Y'] 339 | self.teaching_exs = np.random.choice(X.shape[0], num_teaching_itrs, replace=False) 340 | 341 | for teaching_ex in self.teaching_exs: 342 | 343 
| # compute the posterior of the selected example 344 | pred = np.zeros(len(hyps)) 345 | for hh in range(len(hyps)): 346 | pred[hh], pred_class = user_model(hyps[hh], X[teaching_ex,:][np.newaxis, ...], Y[teaching_ex], self.alpha) 347 | if pred_class != Y[teaching_ex]: 348 | self.cur_post[hh] *= pred[hh] 349 | 350 | # bookkeeping and compute stats 351 | ee, ee_test, ent, diff = teaching_stats(self.cur_post, pred, err_hyp, err_hyp_test) 352 | self.exp_err.append(ee) 353 | self.exp_err_test.append(ee_test) 354 | self.hyp_entropy.append(ent) 355 | self.difficulty.append(diff) 356 | 357 | 358 | class RandomImageTeacherOneVsAll: 359 | # assumes 1xD hypotheses 360 | 361 | def __init__(self, dataset, alpha, prior_h): 362 | self.initialize(dataset['X'], dataset['Y'], alpha, prior_h) 363 | 364 | def initialize(self, X, Y, alpha, prior_h): 365 | self.teaching_exs = [] 366 | self.unseen_exs = np.arange(X.shape[0]) 367 | self.alpha = alpha 368 | self.classes = np.unique(Y) 369 | self.prior_h = np.tile(prior_h, (len(self.classes), 1)) 370 | self.cur_post = np.tile(prior_h.copy(), (len(self.classes), 1)) 371 | self.exp_err = [] 372 | self.exp_err_test = [] 373 | self.hyp_entropy = [] 374 | self.difficulty = [] 375 | 376 | def posterior(self): 377 | return self.cur_post/self.cur_post.sum(1)[..., np.newaxis] 378 | 379 | def run_teaching(self, num_teaching_itrs, dataset, likelihood, hyps, err_hyp, err_hyp_test): 380 | for tt in range(num_teaching_itrs): 381 | self.teaching_iteration(dataset['X'], dataset['Y'], likelihood, hyps, err_hyp, err_hyp_test) 382 | 383 | def teaching_iteration(self, X, Y, likelihood, hyps, err_hyp, err_hyp_test): 384 | 385 | selected_ind = np.random.choice(self.unseen_exs) 386 | 387 | # update the posterior with the selected example 388 | for cc in self.classes: 389 | self.cur_post[cc, :] *= likelihood[cc][:, selected_ind] 390 | 391 | # get predictions for each hyp for selected example 392 | pred = np.zeros((len(self.classes), len(hyps))) 393 | for cc in 
self.classes: 394 | Y_bin = int(Y[selected_ind] == cc) 395 | for hh in range(len(hyps)): 396 | pred[cc, hh], _ = user_model_binary(hyps[hh], X[selected_ind,:][np.newaxis, ...], Y_bin, self.alpha) 397 | 398 | # bookkeeping and compute stats 399 | print len(self.teaching_exs), '\t', Y[selected_ind], '\t', selected_ind 400 | ee, ee_test, ent, diff = teaching_stats_one_vs_all(self.cur_post, pred, err_hyp, err_hyp_test) 401 | self.exp_err.append(ee) 402 | self.exp_err_test.append(ee_test) 403 | self.hyp_entropy.append(ent) 404 | self.difficulty.append(diff) 405 | self.teaching_exs.append(selected_ind) 406 | self.unseen_exs = np.setdiff1d(np.arange(X.shape[0]), self.teaching_exs) 407 | -------------------------------------------------------------------------------- /code/explain/readme.md: -------------------------------------------------------------------------------- 1 | # Teaching Categories to Human Learners with Visual Explanations - Main Code 2 | Code for generating teaching sequences in multi-class setting. Random, Strict, and Explain are implemented. 3 | 4 | ## Notes 5 | Run `main.py` to execute code. 6 | It will output teaching strategy files that can be used in the web interface. 7 | For 2D binary datasets it will plot the hypothesis space. 8 | 9 | 10 | The following are good settings for visualization: 11 | ``` 12 | dataset_name = datasets[1] 13 | do_pca = True 14 | pca_dims = 2 15 | num_init_hyps = 10 16 | num_teaching_itrs = 5 17 | hyp_type = 'rand' 18 | ``` 19 | 20 | When comparing the multi-class model to the binary one make sure to scale alpha appropriately i.e. for the binary case `alpha = alpha/2.0`. 21 | 22 | `hyps` is a list of hypothesis where each entry is a CxD matrix. For a binary classification problem a hypothesis = np.vstack((-w, w)), where w is the linear weight vector. 23 | `X` is the NxD feature matrix. 24 | `Y` is the N vector of labels. Labels go from 0 to number of classes. 
25 | `explain_interp` is the N vector of explanation interpretability. 1.0 means easy to interpret and 0 means hard. Setting these to all ones will be the same as using vanilla strict. 26 | 27 | 28 | ## Reference 29 | If you find our work useful in your research please consider citing our paper: 30 | ``` 31 | @inproceedings{explainteachcvpr18, 32 | title = {Teaching Categories to Human Learners with Visual Explanations}, 33 | author = {Mac Aodha, Oisin and Su, Shihan and Chen, Yuxin and Perona, Pietro and Yue, Yisong}, 34 | booktitle = {CVPR}, 35 | year = {2018} 36 | } 37 | ``` 38 | -------------------------------------------------------------------------------- /code/explain/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import offline_teachers as teach 4 | import random 5 | from scipy.spatial.distance import pdist, squareform 6 | from scipy.stats import entropy 7 | 8 | 9 | def compute_interpretability(explains, Y, Y_pred, interp_sigma): 10 | # WARNING might choose examples with no explanations that will be biased to being picked 11 | 12 | # want to be high if its a good explanation 13 | ent = np.zeros(Y.shape[0]) 14 | for ii in range(Y.shape[0]): 15 | explain_pred = explains[ii, :, :, Y[ii]].copy() 16 | explain_pred -= explain_pred.min() 17 | #ent[ii] = entropy(explain_pred.ravel()) 18 | explain_pred /= explain_pred.max() 19 | aa = explain_pred.ravel() + 0.0000001 20 | ent[ii] = -(np.log(aa)*aa).mean() 21 | 22 | # if the predctions from CNN don't match the GT we should discourage showing this example 23 | for ii in range(Y.shape[0]): 24 | if Y_pred[ii] != Y[ii]: 25 | ent[ii] = ent.max() 26 | 27 | # remove the class mean entropy - to prevent bias towards some classes 28 | min_ent = ent.min() 29 | for cc in np.unique(Y): 30 | inds = np.where(Y==cc)[0] 31 | mu = ent[inds].mean() 32 | ent[inds] -= mu 33 | 34 | # put in the range [0,1], 0 is easiest, 1 is hardest 
35 | ent -= ent.min() 36 | ent /= ent.max() 37 | 38 | # low entropy means discounting more 39 | ent = 1.0 / (1.0 + np.exp(-interp_sigma*(ent + 0.0000001))) 40 | return ent 41 | 42 | 43 | def compute_likelihood_one_vs_all(hyps, X, Y, alpha): 44 | likelihoods = [] 45 | if np.unique(Y).shape[0] == 2: 46 | # binary 47 | ll = compute_likelihood(hyps, X, Y, alpha, True) 48 | likelihoods.append(ll) 49 | else: 50 | # multi class 51 | for cc in np.unique(Y): 52 | Y_bin = np.zeros(Y.shape[0]).astype(np.int) 53 | Y_bin[np.where(Y==cc)[0]] = 1 54 | ll = compute_likelihood(hyps, X, Y_bin, alpha, True) 55 | likelihoods.append(ll) 56 | return likelihoods 57 | 58 | 59 | def compute_likelihood(hyps, X, Y, alpha, one_v_all=False): 60 | # compute P(y|h,x) - size HxN 61 | # is set to one where h(x) = y i.e. correct guess 62 | likelihood = np.ones((len(hyps), X.shape[0])) 63 | likelihood_opp = np.ones((len(hyps), X.shape[0])) 64 | 65 | for hh in range(len(hyps)): 66 | if one_v_all: 67 | # assumes that hyps[hh] is a D dim vector 68 | prob_agree, pred_class = teach.user_model_binary(hyps[hh], X, Y, alpha) 69 | else: 70 | # assumes that hyps[hh] is a CxD dim maxtrix 71 | prob_agree, pred_class = teach.user_model(hyps[hh], X, Y, alpha) 72 | inds = np.where(pred_class != Y)[0] 73 | likelihood[hh, inds] = prob_agree[inds] 74 | 75 | return likelihood 76 | 77 | 78 | def compute_density(X, Y, sigma, per_class=True): 79 | # compute the density of the datapoints 80 | dist = squareform(pdist(X)**2) 81 | if per_class: 82 | dens = np.zeros((X.shape[0])) 83 | for cc in np.unique(Y): 84 | inds = np.where(Y==cc)[0] 85 | dens[inds] = dist[inds, :][:, inds].mean(1) 86 | else: 87 | dens = dist.mean(1) 88 | dens = 1.0 / (1.0 + np.exp(-sigma*dens)) 89 | 90 | return dens 91 | 92 | 93 | def plot_2D_data(X, Y, alpha, hyps, random_exs, post, title_txt, fig_id, one_v_all, best_ind): 94 | plt.figure(fig_id) 95 | plt.title(title_txt) 96 | 97 | # plot hyper-planes 98 | l_weight_range = (0.5, 10) 99 | delta = 1.0 
100 | xx = np.linspace(X[:,0].min()-delta, X[:,0].max()+delta) 101 | for hh in range(len(hyps)): 102 | if one_v_all: 103 | ww = hyps[hh] 104 | else: 105 | ww = hyps[hh][1,:] # for binary this is positive class 106 | 107 | 108 | if ww.shape[0] == 3: 109 | # with intercept i.e. ww[2] 110 | m = -ww[0] / ww[1] 111 | yy = m * xx - (ww[2]) / ww[1] 112 | else: 113 | # no intercept 114 | yy = (-ww[0] / ww[1])*xx 115 | 116 | plt.plot(xx,yy, 'g') 117 | 118 | # plot datapoints and text labels 119 | for ii, ll in enumerate(random_exs): 120 | plt.text(X[ll, 0]+0.1, X[ll, 1]+0.1, ii) 121 | plt.plot(X[ll, 0], X[ll, 1], 'yo') 122 | 123 | cols = ['r.', 'b.', 'c.', 'm.', 'k.'] 124 | for ii, yy in enumerate(np.unique(Y)): 125 | plt.plot(X[Y==yy,0], X[Y==yy,1], cols[ii]) 126 | 127 | plt.axis('equal') 128 | delta = 0.5 129 | plt.axis([X[:,0].min()-delta,X[:,0].max()+delta, X[:,1].min()-delta, X[:,1].max()+delta]) 130 | plt.show() 131 | 132 | 133 | def plot_2D_data_hyper(X, Y, alpha, hyps, random_exs, post, title_txt, fig_id, one_v_all, best_ind): 134 | # TODO this doesnt work for 1 v all need to plot hyper plans separately per class 135 | 136 | # this plots the data points X, the labels Y, along with the different 137 | # hypotheses hyp and their associated posterior weights 138 | # currently only works for binary classes 139 | # also best to use rand hypotheses for 2D datasets 140 | plt.figure(fig_id) 141 | plt.title(title_txt) 142 | 143 | # plot hyper-planes 144 | l_weight_range = (0.5, 10) 145 | delta = 1.0 146 | xx = np.linspace(X[:,0].min()-delta, X[:,0].max()+delta) 147 | for hh in range(len(hyps)): 148 | if one_v_all: 149 | print 'WARNING this is only implemented for binary' 150 | ww = hyps[hh] 151 | else: 152 | ww = hyps[hh][1,:] # for binary this is positive class 153 | 154 | if ww.shape[0] == 3: 155 | # with intercept i.e. 
ww[2] 156 | m = -ww[0] / ww[1] 157 | yy = m * xx - (ww[2]) / ww[1] 158 | else: 159 | # no intercept 160 | yy = (-ww[0] / ww[1])*xx 161 | 162 | lw = post[hh]*(l_weight_range[1]-l_weight_range[0]) + l_weight_range[0] 163 | if hh == best_ind: 164 | plt.plot(xx,yy, 'r', linewidth=lw) # optimal hypothesis 165 | else: 166 | plt.plot(xx,yy, 'g', linewidth=lw) # regular hypothesis 167 | 168 | # plot datapoints and text labels 169 | for ii, ll in enumerate(random_exs): 170 | plt.text(X[ll, 0]+0.1, X[ll, 1]+0.1, ii) 171 | plt.plot(X[ll, 0], X[ll, 1], 'yo') 172 | 173 | cols = ['r.', 'b.', 'k.', 'c.', 'm.'] 174 | for ii, yy in enumerate(np.unique(Y)): 175 | plt.plot(X[Y==yy,0], X[Y==yy,1], cols[ii]) 176 | 177 | plt.axis('equal') 178 | delta = 0.5 179 | plt.axis([X[:,0].min()-delta,X[:,0].max()+delta, X[:,1].min()-delta, X[:,1].max()+delta]) 180 | plt.show() 181 | -------------------------------------------------------------------------------- /code/readme.md: -------------------------------------------------------------------------------- 1 | # Teaching Categories to Human Learners with Visual Explanations 2 | Code for recreating the results in our CVPR 2018 paper. 3 | 4 | `explain` is the main code for the teaching algorithms. 5 | `data_gen` is the pre-processing scripts to generate the visual explanations. 6 | `teaching_app` is the Flask based web app that was used to get results on MTurk. 7 | 8 | 9 | ## Reference 10 | If you find our work useful in your research please consider citing our paper. 
11 | ``` 12 | @inproceedings{explainteachcvpr18, 13 | title = {Teaching Categories to Human Learners with Visual Explanations}, 14 | author = {Mac Aodha, Oisin and Su, Shihan and Chen, Yuxin and Perona, Pietro and Yue, Yisong}, 15 | booktitle = {CVPR}, 16 | year = {2018} 17 | } 18 | ``` 19 | -------------------------------------------------------------------------------- /code/teaching_app/Procfile: -------------------------------------------------------------------------------- 1 | web: gunicorn app:app -------------------------------------------------------------------------------- /code/teaching_app/application.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from flask import Flask, Response, session, redirect, url_for, request, render_template 3 | import uuid 4 | import random 5 | import numpy as np 6 | import glob 7 | import os 8 | import utils as ut 9 | import datetime 10 | from pymongo import MongoClient 11 | from bson.json_util import dumps 12 | import config 13 | 14 | application = Flask(__name__) 15 | application.secret_key = config.SECRET_KEY 16 | 17 | # database - hosted on mlab 18 | if config.MONGO_DB_STR is not '': 19 | client = MongoClient(config.MONGO_DB_STR) 20 | db = client.get_default_database() 21 | else: 22 | db = MongoClient().database 23 | 24 | # load image data 25 | images_file = 'data/teaching_images.json' 26 | images = ut.load_ims(images_file) 27 | 28 | # load other settings 29 | settings_file = 'data/settings.json' 30 | class_names, train_indices, test_indices, test_sequence, experiment_id, scale = ut.load_settings(settings_file) 31 | 32 | # load strategy data i.e. 
image sequences 33 | strat_files = glob.glob('data/*.strat') 34 | strats = ut.load_strats(strat_files, test_sequence) 35 | 36 | # tutorial images 37 | tutorial_images = ['tutorial_0.jpg', 'tutorial_1.jpg', 'tutorial_2.jpg', 'tutorial_3.jpg'] 38 | 39 | def initalize_session(): 40 | # create new user session when user visits home page 41 | session.clear() 42 | session['name'] = str(uuid.uuid4()) 43 | session['response'] = [] 44 | session['time'] = [] 45 | session['gt_label'] = [] 46 | session['strategy'] = random.choice(strats.keys()) 47 | # different users will see options in different order - order will be consistent for the length of session 48 | session['button_order'] = random.sample(range(len(class_names)), len(class_names)) 49 | session['current_id'] = 0 50 | session['experiment_id'] = experiment_id 51 | 52 | num_ims = strats[session['strategy']]['num_train'] + strats[session['strategy']]['num_test'] 53 | session['num_ims'] = num_ims 54 | session['num_train'] = strats[session['strategy']]['num_train'] 55 | session['num_test'] = strats[session['strategy']]['num_test'] 56 | 57 | if 'random' in session['strategy']: 58 | # random strategies 59 | session['image_id'] = random.sample(train_indices, session['num_train']) + list(strats[session['strategy']]['test_sequence']) 60 | session['is_train'] = [0]*num_ims 61 | session['display_explain_image'] = [0]*num_ims 62 | 63 | for ii in range(session['num_train']): 64 | session['is_train'][ii] = 1 65 | 66 | if strats[session['strategy']]['display_explain_image']: 67 | session['display_explain_image'][ii] = 1 68 | 69 | else: 70 | session['is_train'] = list(strats[session['strategy']]['is_train']) 71 | session['image_id'] = list(strats[session['strategy']]['image_id']) 72 | session['display_explain_image'] = list(strats[session['strategy']]['display_explain_image']) 73 | 74 | 75 | # can have random images in the test set by specifying the index as -1 76 | valid_remain_test = list(set(test_indices) - set(session['image_id'])) 
77 | valid_remain_test = random.sample(valid_remain_test, len(valid_remain_test)) 78 | vv = 0 79 | for ii in range(num_ims): 80 | if session['image_id'][ii] == -1 and session['is_train'][ii] == 0: 81 | session['image_id'][ii] = valid_remain_test[vv] 82 | vv += 1 83 | 84 | # add labels to session 85 | for ii in range(num_ims): 86 | session['gt_label'].append(images[session['image_id'][ii]]['class_label']) 87 | 88 | session.modified = True 89 | 90 | 91 | @application.route('/') 92 | def index(): 93 | # Create new session when users visits the homepage 94 | initalize_session() 95 | print(session['strategy']) 96 | 97 | params = {} 98 | params['num_ims'] = session['num_ims'] 99 | params['class_names'] = class_names 100 | 101 | return render_template('index.html', params=params) 102 | 103 | 104 | @application.route('/tutorial/') 105 | def tutorial(im_id): 106 | params = {} 107 | params['im_id'] = int(im_id) 108 | params['tutorial_images'] = tutorial_images 109 | 110 | return render_template('tutorial.html', params=params) 111 | 112 | 113 | @application.route('/debug') 114 | def debug(): 115 | params = {} 116 | params['images'] = images 117 | params['class_names'] = class_names 118 | params['strats'] = strats 119 | 120 | return render_template('debug.html', params=params) 121 | 122 | @application.route('/dashboard') 123 | def dashboard(): 124 | # Display results per strategy - for this experiment 125 | user_data = list(db.user_results.find({'experiment_id':experiment_id})) 126 | 127 | if len(user_data) == 0: 128 | return 'No results file exists yet.' 
129 | else: 130 | strat_names = [uu['strategy'] for uu in user_data] 131 | test_scores = [uu['score'] for uu in user_data] 132 | 133 | params = {} 134 | params['num_turkers'] = len(user_data) 135 | params['strat_names'] = strats.keys() 136 | params['test_scores'] = [0]*len(strats.keys()) 137 | params['users_per_strat'] = [0]*len(strats.keys()) 138 | params['experiment_id'] = experiment_id 139 | 140 | # compute the per strategy average results 141 | for jj, ss in enumerate(strats.keys()): 142 | for ii, rr in enumerate(strat_names): 143 | if ss == rr: 144 | params['test_scores'][jj] += test_scores[ii] 145 | params['users_per_strat'][jj] += 1 146 | 147 | if params['users_per_strat'][jj] > 0: 148 | params['test_scores'][jj] /= params['users_per_strat'][jj] 149 | 150 | return render_template('dashboard.html', params=params) 151 | 152 | 153 | @application.route('/user_data') 154 | def user_data(): 155 | # Display user results - for this experiment 156 | user_data = list(db.user_results.find({'experiment_id':experiment_id})) 157 | 158 | if len(user_data) == 0: 159 | return 'No results file exists yet.' 
160 | else: 161 | return dumps(user_data) 162 | 163 | 164 | def save_results(session): 165 | # Write data for current session to database 166 | result = {} 167 | result['completion_time'] = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 168 | result['strategy'] = session['strategy'] 169 | result['image_id'] = session['image_id'] 170 | result['response'] = session['response'] 171 | result['gt_label'] = session['gt_label'] 172 | result['is_train'] = session['is_train'] 173 | result['time'] = session['time'] 174 | result['mturk_code'] = session['name'] 175 | result['score'] = session['score'] 176 | result['experiment_id'] = session['experiment_id'] 177 | 178 | # write to DB 179 | db.user_results.insert_one(result) 180 | 181 | 182 | @application.route('/teaching', methods=['GET','POST']) 183 | def teaching(): 184 | # Main function that handles the teaching logic 185 | 186 | # session not initialized - so do it 187 | if 'response' not in session.keys(): 188 | initalize_session() 189 | 190 | # capture user button presses 191 | if request.method == 'POST': 192 | session['response'].append(int(request.form['action'])) 193 | session['time'].append(datetime.datetime.now().strftime("%H:%M:%S")) 194 | session['current_id'] += 1 195 | session.modified = True 196 | 197 | # if they are finished, send users to results screen and save results to db 198 | if len(session['response']) == session['num_ims']: 199 | 200 | # compute score on test set 201 | num_corr = 0.0 202 | for rr in range(session['num_train'], session['num_ims']): 203 | if session['response'][rr] == session['gt_label'][rr]: 204 | num_corr += 1.0 205 | 206 | params = {} 207 | params['mturk_code'] = session['name'] 208 | params['score'] = round(100*num_corr / float(session['num_test']),2) 209 | session['score'] = params['score'] 210 | session.modified = True 211 | 212 | save_results(session) 213 | 214 | return render_template('results.html', params=params) 215 | 216 | # select next image to show 217 | current_id 
= session['current_id'] 218 | image_id = session['image_id'][current_id] 219 | 220 | params = {} 221 | params['image'] = images[image_id]['image_url'] 222 | params['explain_image'] = images[image_id]['explain_url'] 223 | params['label'] = images[image_id]['class_label'] 224 | params['display_explain_image'] = session['display_explain_image'][current_id] and session['is_train'][current_id] 225 | params['len_resp'] = len(session['response']) 226 | params['strategy'] = session['strategy'] 227 | params['is_train'] = session['is_train'][current_id] 228 | params['total_num_ims'] = session['num_ims'] 229 | params['num_train'] = session['num_train'] 230 | params['class_names'] = class_names 231 | params['training_finished'] = session['num_train'] == session['current_id'] 232 | params['button_order'] = session['button_order'] 233 | params['scale'] = scale 234 | 235 | params['train_feedback'] = False 236 | if len(session['response']) == (session['num_train']//2): 237 | params['train_feedback'] = True 238 | 239 | return render_template('teaching.html', params=params) 240 | 241 | 242 | if __name__ == '__main__': 243 | #application.run(debug=True, use_reloader=True) 244 | application.run() 245 | -------------------------------------------------------------------------------- /code/teaching_app/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration parameters 3 | """ 4 | 5 | MONGO_DB_STR = '' # CHANGE_THIS 6 | SECRET_KEY = 'CHANGE_THIS' # CHANGE_THIS 7 | -------------------------------------------------------------------------------- /code/teaching_app/data/random_image.strat: -------------------------------------------------------------------------------- 1 | { 2 | "num_train": 3, 3 | "display_explain_image": 0 4 | } 5 | -------------------------------------------------------------------------------- /code/teaching_app/data/random_image_with_explain.strat: 
-------------------------------------------------------------------------------- 1 | { 2 | "num_train": 3, 3 | "display_explain_image": 1 4 | } 5 | -------------------------------------------------------------------------------- /code/teaching_app/data/readme.md: -------------------------------------------------------------------------------- 1 | Currently `teaching_images.json` just contains dummy data. Replace the files in this folder with your own to specify the teaching settings. 2 | -------------------------------------------------------------------------------- /code/teaching_app/data/settings.json: -------------------------------------------------------------------------------- 1 | {"train_indices": [0, 2, 4], "class_names": ["Blackbird", "Robin", "Cardinal"], "test_indices": [1, 3, 5], "test_sequence": [-1, -1], "experiment_id": 1, "scale":1.0} 2 | -------------------------------------------------------------------------------- /code/teaching_app/data/strategy_0.strat: -------------------------------------------------------------------------------- 1 | { 2 | "num_train": 3, 3 | "display_explain_image": [0, 0, 0], 4 | "im_ids": [3, 2, 1] 5 | } 6 | -------------------------------------------------------------------------------- /code/teaching_app/data/strategy_1.strat: -------------------------------------------------------------------------------- 1 | { 2 | "num_train": 3, 3 | "display_explain_image": [1, 1, 1], 4 | "im_ids": [2, 1, 0] 5 | } 6 | -------------------------------------------------------------------------------- /code/teaching_app/data/teaching_images.json: -------------------------------------------------------------------------------- 1 | [{"class_label": 0, "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a9/Common_Blackbird.jpg/320px-Common_Blackbird.jpg", "explain_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/13/NIE_1905_Bird_-_topography.jpg/320px-NIE_1905_Bird_-_topography.jpg"}, {"class_label": 0, 
"image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/0/09/Blackbird_2.jpg/320px-Blackbird_2.jpg", "explain_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/13/NIE_1905_Bird_-_topography.jpg/320px-NIE_1905_Bird_-_topography.jpg"}, {"class_label": 1, "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/3/32/American_robin.jpg/183px-American_robin.jpg", "explain_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/13/NIE_1905_Bird_-_topography.jpg/320px-NIE_1905_Bird_-_topography.jpg"}, {"class_label": 1, "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b7/American_Robin_0026.jpg/320px-American_Robin_0026.jpg", "explain_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/13/NIE_1905_Bird_-_topography.jpg/320px-NIE_1905_Bird_-_topography.jpg"}, {"class_label": 2, "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d9/Cardinal_side_view.JPG/320px-Cardinal_side_view.JPG", "explain_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/13/NIE_1905_Bird_-_topography.jpg/320px-NIE_1905_Bird_-_topography.jpg"}, {"class_label": 2, "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/5f/Northern_Cardinal_Broadside.jpg/317px-Northern_Cardinal_Broadside.jpg", "explain_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/1/13/NIE_1905_Bird_-_topography.jpg/320px-NIE_1905_Bird_-_topography.jpg"}] 2 | -------------------------------------------------------------------------------- /code/teaching_app/readme.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | This is a simple Flask based web app that will display either random images to user or images in a sequence specified in the configuration files. When a user visits the site, they will be assigned randomly to one of several possible specified teaching strategies and will be shown images and potentially visual exanations as feedback. 
3 | 4 | 5 | ### Data Format 6 | There are three types of files that are used to configure the images that are shown to the users. These files should be placed in `data/`. 7 | 8 | 1) List of images - teaching_images.json 9 | This is a list of dictionaries where each entry contains the image URL, the URL of an explanation image (can be blank), and a class label from 0 to the number of classes minus one (i.e. this can be multi-class). 10 | Images can be hosted on your own website or on an external service such as Amazon S3. 11 | Here is an example entry for one image: 12 | ``` 13 | { 14 | "class_label": 0, 15 | "image_url": "https://blaablaablaa.com/image.jpg", 16 | "explain_url": "https://blaablaablaa.com/explain_image.jpg" 17 | } 18 | ``` 19 | 20 | 2) Settings - settings.json 21 | This contains a list of class names, the indices of the training images (train_indices) and testing images (test_indices), and the indices of the sequence of test images that will be shown (test_sequence) - these indices refer to positions in the image list in teaching_images.json. If there is a `-1` in test_sequence, a random image from the unseen test set will be chosen for that location. There is also a `scale` value which allows images to be displayed bigger in the user interface. The number of class names should be the same as the number of unique classes in teaching_images.json. You can also specify an experiment id number (experiment_id) to keep track of results. 22 | 23 | 3) Teaching sequence files e.g. strategy_0.strat, strategy_1.strat, ... 24 | These files can have any name as long as they end with .strat. 25 | `im_ids` is the list of images that will be displayed during training, selected from teaching_images.json, and `display_explain_image` has one entry per training image (1 to show its explanation image, 0 otherwise) - test images are specified in settings.json to ensure that all strategies use the same test images. 26 | For example, the following strategy file would display three images (2, 1, and 0), with an explanation image after each.
27 | ``` 28 | { 29 | "num_train": 3, 30 | "display_explain_image": [1, 1, 1], 31 | "im_ids": [2, 1, 0] 32 | } 33 | ``` 34 | 35 | You can also specify random strategy files e.g. random_image.strat and random_image_with_explain.strat (any strategy whose name contains `random` is treated as a random strategy). These contain two entries - the number of training images, and whether to display the explanation image: 36 | ``` 37 | { 38 | "num_train": 3, 39 | "display_explain_image": 1 40 | } 41 | ``` 42 | 43 | ### General Notes 44 | User data is only saved if the user completes the task. 45 | It is a good idea to keep all images the same size where possible. 46 | Change `config.py` to point to your database and set the secret key. 47 | 48 | 49 | ## Deployment 50 | For testing, the app can be deployed locally, but to run experiments on the web it must be accessible online. 51 | 52 | 53 | ### Set up Database on mlab 54 | We need to store our results in a database. 55 | Create an account at `https://mlab.com` 56 | When logged in go to MongoDB Deployments, click Create new, choose Sandbox (free), click Continue (bottom right), select US region, give it a name. 57 | Once created, you need to create a new user account for the database. To do this click on the database name, Users, and add database user. 58 | Finally set the MongoDB URI in `config.py`. 59 | 60 | 61 | ### Locally 62 | Assuming you have Flask installed on your machine (it comes bundled in the Anaconda installation), just download the repository and run `python application.py`. 63 | As configured, there must be a database installed and a secret key set in `config.py` or the app will throw an error. 64 | 65 | 66 | ### Heroku 67 | Alternatively, you can deploy the app on the web so it's accessible to others using a service such as Heroku. 68 | First, register for a free Heroku account at `https://www.heroku.com/`. 69 | Download the Heroku CLI for your machine - `https://devcenter.heroku.com/articles/heroku-cli`. 70 | Download the code from this repository and save it on your computer.
71 | 72 | Create a new Heroku webapp - this will be the website that hosts our application 73 | ``` 74 | heroku login 75 | heroku create <app-name> e.g. heroku create machine-teaching-demo (you won't be able to use this app name as I already have it) 76 | ``` 77 | 78 | Create database - we use the database we created on mlab 79 | `heroku addons:create mongolab:sandbox --app <app-name>` 80 | 81 | Create a new git repo - Heroku uses git to deploy the app 82 | ``` 83 | cd my-project/ 84 | git init 85 | heroku git:remote -a <app-name> e.g. heroku git:remote -a machine-teaching-demo 86 | ``` 87 | 88 | Deploy - add the code to the repository and push to Heroku 89 | ``` 90 | git add . 91 | git commit -am "first commit" 92 | git push heroku master 93 | ``` 94 | 95 | For existing repositories, simply add the Heroku remote 96 | ``` 97 | heroku git:remote -a machine-teaching-demo 98 | ``` 99 | 100 | View app at 101 | ``` 102 | https://<app-name>.herokuapp.com/ e.g. https://your-machine-teaching-demo.herokuapp.com/ 103 | ``` 104 | 105 | ### AWS Elastic Beanstalk 106 | Create Access Key 107 | AWS Console -> My Security Credentials 108 | 109 | Configure CLI 110 | http://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html 111 | 112 | Flask details 113 | http://docs.aws.amazon.com/elasticbeanstalk/latest/dg/create-deploy-python-flask.html 114 | 115 | ``` 116 | eb init -p python2.7 image-quiz # create app 117 | eb create image-quiz-production # create environment 118 | eb deploy # update app every time you make changes 119 | eb terminate image-quiz-production # deletes the environment 120 | ``` 121 | 122 | ### View Results 123 | Go to `https://<app-url>/debug` to see if data loaded correctly. 124 | 125 | Go to `https://<app-url>/user_data`, copy the text and save it in `results.txt`.
126 | ``` 127 | import json 128 | results_file = 'results.txt' 129 | with open(results_file) as f: 130 | user_data = json.load(f) 131 | 132 | print len(user_data), 'users completed the task' 133 | print user_data[0]['mturk_code'] 134 | print user_data[0]['strategy'] 135 | print user_data[0]['response'] 136 | print user_data[0]['gt_label'] 137 | ``` 138 | 139 | Can also go to `https:///dashboard` to view the test set summary live. 140 | 141 | Alternatively, you can load the results directly from the database: 142 | ``` 143 | from pymongo import MongoClient 144 | import config 145 | 146 | client = MongoClient(config.MONGO_DB_STR) 147 | db = client.get_default_database() 148 | experiment_of_interest = 1 149 | 150 | user_data = list(db.user_results.find()) 151 | 152 | # This will delete all the data in the DB - warning only do this if everything is backed up 153 | # db.user_results.drop() 154 | 155 | # This will delete all the entries for a specific experiment. 156 | # db.user_results.delete_many({'experiment_id':experiment_of_interest}) 157 | ``` -------------------------------------------------------------------------------- /code/teaching_app/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.8.2 2 | Flask==1.0.2 3 | Jinja2==2.10.1 4 | gunicorn==19.7.1 5 | pymongo==3.4.0 6 | -------------------------------------------------------------------------------- /code/teaching_app/runtime.txt: -------------------------------------------------------------------------------- 1 | python-2.7.14 -------------------------------------------------------------------------------- /code/teaching_app/static/style.css: -------------------------------------------------------------------------------- 1 | canvas { 2 | padding: 0; 3 | margin: auto; 4 | display: block; 5 | } 6 | 7 | .disabled-group { 8 | pointer-events: none; 9 | } 10 | 11 | .form-group input[type="radio"] { 12 | -webkit-appearance: none; 13 | } 14 | 15 | 
.form-group{ 16 | margin-bottom:0px; 17 | } 18 | 19 | .form-group label { 20 | width: 100%; 21 | margin-right:0px; 22 | } 23 | 24 | .form-group span { 25 | margin-top:0px; 26 | } 27 | 28 | img.tutorial { 29 | padding:1px; 30 | border:1px solid #010d21; 31 | background-color:#010d21; 32 | } 33 | -------------------------------------------------------------------------------- /code/teaching_app/static/tutorial_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/explain_teach/ac681a061924ef39a4f07e8b89af86a201477fc4/code/teaching_app/static/tutorial_0.jpg -------------------------------------------------------------------------------- /code/teaching_app/static/tutorial_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/explain_teach/ac681a061924ef39a4f07e8b89af86a201477fc4/code/teaching_app/static/tutorial_1.jpg -------------------------------------------------------------------------------- /code/teaching_app/static/tutorial_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/explain_teach/ac681a061924ef39a4f07e8b89af86a201477fc4/code/teaching_app/static/tutorial_2.jpg -------------------------------------------------------------------------------- /code/teaching_app/static/tutorial_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/macaodha/explain_teach/ac681a061924ef39a4f07e8b89af86a201477fc4/code/teaching_app/static/tutorial_3.jpg -------------------------------------------------------------------------------- /code/teaching_app/templates/dashboard.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Dashboard 9 | 10 | 11 | 12 |
13 |
14 |
15 |

Test Set Results for Experiment {{ params['experiment_id'] }}

16 |

Current number of completed HITS is {{ params['num_turkers'] }}.

17 |
18 | 19 | 20 | 21 | 22 | 23 | 24 | {% for rr in range(params['strat_names']|length) %} 25 | 26 | 27 | 28 | 29 | 30 | {% endfor %} 31 | 32 |
Strategy Num TurkersTest Average
{{ params['strat_names'][rr]}}{{ params['users_per_strat'][rr]}}{{ params['test_scores'][rr]}}
33 | 34 |
35 |
36 | 37 | 38 | 39 | 40 |
41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /code/teaching_app/templates/debug.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Debug 9 | 10 | 11 | 12 |
13 |
14 |
15 |

Images

16 | 17 | 18 | 19 | 20 | 21 | {% for rr in range(params['images']|length) %} 22 | 23 | 24 | 25 | 26 | {% endfor %} 27 | 28 |
num data
{{ rr }}{{ params['images'][rr]}}
29 | 30 |

Class names

31 | 32 | 33 | 34 | 35 | 36 | {% for rr in range(params['class_names']|length) %} 37 | 38 | 39 | 40 | 41 | {% endfor %} 42 | 43 |
class num class name
{{ rr }}{{ params['class_names'][rr]}}
44 | 45 |

Strategies

46 | 47 | 48 | 49 | 50 | 51 | {% for key, value in params['strats'].items() %} 52 | 53 | 54 | 55 | 56 | {% endfor %} 57 | 58 |
strategy name data
{{ key }}{{ value }}
59 | 60 |
61 |
62 | 63 |
64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /code/teaching_app/templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Image Quiz - Welcome 9 | 10 | 11 | 12 |
13 |
14 |
15 |

You will be presented with a sequence of {{ params['num_ims'] }} images and asked to guess if the image is one of the following:

16 | {% for cc in range(params['class_names']|length) %} 17 |

{{ params['class_names'][cc] }}

18 | {% endfor %} 19 |
20 |

After every guess you will get feedback so you can answer better in the future.

21 | 22 |
23 |
24 | 25 |

26 |
27 |
28 | Begin Tutorial 29 |
30 |
31 | 32 |
33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /code/teaching_app/templates/results.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Results 9 | 10 | 11 | 12 |
13 |

Thank you for participating!

14 |
15 | 16 |

You got {{params['score']}}% correct.

17 |
18 |
19 | 20 |

Enter the following code into the Mechanical Turk survey link box to receive payment:

21 |

{{params['mturk_code']}}

22 | 23 |
24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /code/teaching_app/templates/teaching.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | Image Quiz 10 | 11 | 12 | 13 | 14 |
15 | 16 |
17 |
18 | 19 |
20 | 21 |
22 |
23 |

Please select one of the following options:

24 | 25 |
26 |
27 | {% for bb in params['button_order'] %} 28 |
29 | 32 |
33 | {% endfor %} 34 |
35 | 36 |

37 | 38 | 39 |
40 | 41 | 42 | 43 |
44 | 45 |
46 |
47 |

{{ params['total_num_ims'] - params['len_resp'] -1 }} images left.

48 |
49 | 50 | 68 | 69 | 87 | 88 | 89 |
90 | 91 | 92 | 93 | 94 | 95 | 221 | 222 | 223 | 224 | 225 | -------------------------------------------------------------------------------- /code/teaching_app/templates/tutorial.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | Image Quiz - Tutorial 10 | 11 | 12 | 13 |
14 |
15 | 16 |
17 | 18 |
19 |
20 |
21 | {% if params['im_id'] == params['tutorial_images']|length-1 %} 22 | Next 23 | {% else %} 24 | Next 25 | {% endif%} 26 |
27 |
28 | 29 | 47 | 48 |
49 | 50 | 51 | 52 | 53 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /code/teaching_app/utils.py: -------------------------------------------------------------------------------- 1 | # Helper functions that deal with file loading 2 | 3 | from __future__ import print_function 4 | import numpy as np 5 | import os 6 | import json 7 | 8 | 9 | def load_ims(im_file): 10 | # load images and class names 11 | print('\nloading images') 12 | with open(im_file, 'r') as fs: 13 | images = json.load(fs) 14 | 15 | for ii, im in enumerate(images): 16 | if 'explain_url' not in im.keys(): 17 | im['explain_url'] = im['image_url'] 18 | 19 | for ii, im in enumerate(images): 20 | print(ii, os.path.basename(im['image_url'])) 21 | 22 | print(len(images), 'images loaded') 23 | return images 24 | 25 | 26 | def load_settings(settings_file): 27 | print('\nloading settings') 28 | with open(settings_file, 'r') as fs: 29 | settings = json.load(fs) 30 | return settings['class_names'], settings['train_indices'], settings['test_indices'], settings['test_sequence'], settings['experiment_id'], settings['scale'] 31 | 32 | 33 | def load_strats(strat_files, test_sequence): 34 | print('\nloading strategy files') 35 | strats = {} 36 | for ii, sf in enumerate(strat_files): 37 | strat_name = os.path.basename(sf)[:-6] 38 | print(ii, strat_name) 39 | with open(sf, 'r') as fp: 40 | strat_data = json.load(fp) 41 | res = {} 42 | 43 | res['num_train'] = strat_data['num_train'] 44 | res['num_test'] = len(test_sequence) 45 | 46 | if 'random' in strat_name: 47 | res['test_sequence'] = list(test_sequence) 48 | res['display_explain_image'] = strat_data['display_explain_image'] 49 | else: 50 | res['image_id'] = strat_data['im_ids'] + list(test_sequence) 51 | res['display_explain_image'] = strat_data['display_explain_image'] + [0]*res['num_test'] 52 | res['is_train'] = [1]*res['num_train'] + [0]*res['num_test'] 53 | strats[strat_name] = res 54 | 55 | 
print(len(strats), 'strategies loaded\n') 56 | return strats 57 | 58 | 59 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | oct/ 2 | -------------------------------------------------------------------------------- /data/readme.md: -------------------------------------------------------------------------------- 1 | # Teaching Categories to Human Learners with Visual Explanations 2 | Datasets used in our CVPR 2018 paper. 3 | 4 | For patient privacy reasons we cannot re-distribute the images in the OCT dataset. 5 | 6 | The datasets can be downloaded from [here](https://homepages.inf.ed.ac.uk/omacaod/files/cvpr_2018_teach_data.zip). 7 | 8 | ## Reference 9 | If you find our work useful in your research please consider citing our paper. 10 | ``` 11 | @inproceedings{explainteachcvpr18, 12 | title = {Teaching Categories to Human Learners with Visual Explanations}, 13 | author = {Mac Aodha, Oisin and Su, Shihan and Chen, Yuxin and Perona, Pietro and Yue, Yisong}, 14 | booktitle = {CVPR}, 15 | year = {2018} 16 | } 17 | ``` 18 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Teaching Categories to Human Learners with Visual Explanations 2 | Code for recreating the results in our CVPR 2018 paper. 3 | 4 | `code` contains the main code for the teaching algorithms and data generation. 5 | `data` contains the image datasets. 6 | `results` contains the results files and plot generation scripts. 7 | 8 | 9 | ## Reference 10 | If you find our work useful in your research please consider citing our paper. 
11 | ``` 12 | @inproceedings{explainteachcvpr18, 13 | title = {Teaching Categories to Human Learners with Visual Explanations}, 14 | author = {Mac Aodha, Oisin and Su, Shihan and Chen, Yuxin and Perona, Pietro and Yue, Yisong}, 15 | booktitle = {CVPR}, 16 | year = {2018} 17 | } 18 | ``` 19 | -------------------------------------------------------------------------------- /results/create_plots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import json 4 | import seaborn as sns 5 | import datetime as dt 6 | import os 7 | from matplotlib.ticker import MultipleLocator, FormatStrFormatter 8 | from sklearn.metrics import confusion_matrix 9 | sns.set_style("whitegrid") 10 | 11 | 12 | def get_time_diff(user_time): 13 | times = [dt.datetime.strptime(tt, '%H:%M:%S') for tt in user_time] 14 | time_diff = [] 15 | for tt in range(len(times)-1): 16 | diff = (times[tt+1] - times[tt]).seconds 17 | time_diff.append(diff) 18 | time_diff = np.hstack((0, time_diff)) 19 | return time_diff 20 | 21 | 22 | # select which daatset to plot 23 | exp_name = 'oct' 24 | #exp_name = 'butterflies_crop' 25 | #exp_name = 'chinese_chars' 26 | #exp_name = 'chinese_chars_crowd' 27 | 28 | 29 | plt.close('all') 30 | save_data = False 31 | remove_bottom = False 32 | rm_fraction = 0.2 33 | 34 | col_p = sns.color_palette() 35 | col_p[3], col_p[2] = col_p[2], col_p[3] 36 | majorLocator = MultipleLocator(5) 37 | minorLocator = MultipleLocator(1) 38 | majorFormatter = FormatStrFormatter('%d') 39 | 40 | 41 | if not save_data: 42 | print '***\nNot saving plots\n***\n' 43 | 44 | if remove_bottom: 45 | print '***\nRemoving worst workers\n***\n' 46 | else: 47 | print '***\nUsing all data\n***\n' 48 | 49 | 50 | base_dir = 'experiments/' + exp_name + '/' 51 | results_file = base_dir + 'results.json' 52 | settings_file = base_dir + 'settings.json' 53 | op_dir = base_dir + '/plots/' 54 | if (not os.path.isdir(op_dir)) 
and save_data: 55 | os.makedirs(op_dir) 56 | 57 | # load data 58 | with open(settings_file) as f: 59 | settings = json.load(f) 60 | num_classes = len(settings['class_names']) 61 | 62 | with open(results_file) as f: 63 | user_data = json.load(f) 64 | 65 | strats = [uu['strategy'] for uu in user_data] 66 | un_strats = np.unique(strats) 67 | ip_strats_names = ['random', 'random_feedback', 'strict_1vall', 'explain_1vall'] 68 | ip_strats_names_full = ['RAND_IM', 'RAND_EXP', 'STRICT', 'EXPLAIN'] 69 | op_strat_names = [ip_strats_names_full[ip_strats_names.index(ss)] for ss in un_strats] 70 | 71 | if (np.unique(un_strats) == np.unique(strats)).mean() != 1.0: 72 | print '\n*****Warning - missing strat\n*******' 73 | 74 | print '\n', len(user_data), 'users completed the task' 75 | print 'strategies\t', un_strats 76 | print 'classes \t', settings['class_names'] 77 | 78 | 79 | train_inds = np.where(np.asarray(user_data[0]['is_train'])==1)[0] 80 | test_inds = np.where(np.asarray(user_data[0]['is_train'])==0)[0] 81 | 82 | # load data 83 | train_acc = {} 84 | test_acc = {} 85 | test_pred = {} 86 | test_gt = {} 87 | scores_all = {} 88 | user_times = {} 89 | print '\n'.ljust(20), ' train'.ljust(10), ' test'.ljust(10), 'med tst'.ljust(10), 'num people' 90 | for fid, ss in enumerate(un_strats): 91 | resp = np.asarray([uu['response'] for uu in user_data if uu['strategy'] == ss]) 92 | gt = np.asarray([uu['gt_label'] for uu in user_data if uu['strategy'] == ss]) 93 | scores = np.asarray([uu['score'] for uu in user_data if uu['strategy'] == ss]) 94 | tm = np.asarray([uu['time'] for uu in user_data if uu['strategy'] == ss]) 95 | 96 | if remove_bottom: 97 | # remove bottom X% 98 | keep_inds = np.argsort(scores)[int(len(scores)*rm_fraction):] 99 | else: 100 | keep_inds = np.arange(scores.shape[0]) 101 | 102 | scores = scores[keep_inds].copy() 103 | resp = resp[keep_inds, :].copy() 104 | gt = gt[keep_inds, :].copy() 105 | tm = tm[keep_inds, :].copy() 106 | 107 | train_acc[ss] = (resp[:, 
train_inds]==gt[:, train_inds]).mean(0)*100 108 | test_acc[ss] = (resp[:, test_inds]==gt[:, test_inds]).mean(0)*100 109 | scores_all[ss] = scores 110 | user_times[ss] = tm 111 | test_pred[ss] = resp[:, test_inds] 112 | test_gt[ss] = gt[:, test_inds] 113 | 114 | print ss.ljust(20), str(round(np.mean(train_acc[ss]),2)).ljust(10), str(round(np.mean(test_acc[ss]),2)).ljust(10), str(round(np.median(scores),2)).ljust(10), len(scores) 115 | 116 | 117 | # hist of test acc 118 | fig = plt.figure(0, figsize=(7, 6)) 119 | fig.suptitle(exp_name + ' - Test Accuracy', fontsize=14) 120 | for fid, ss in enumerate(un_strats): 121 | plt.subplot(2, 2, fid+1) 122 | if ss == 'explain_1vall': 123 | plt.hist(scores_all[ss], bins=10, range=(0,100), color=col_p[3]) 124 | else: 125 | plt.hist(scores_all[ss], bins=10, range=(0,100), color=col_p[0]) 126 | 127 | plt.xlim(0, 100) 128 | plt.title(op_strat_names[fid]) 129 | plt.tight_layout(rect=[0, 0, 1, 0.95]) 130 | 131 | if save_data: 132 | plt.savefig(op_dir + '0.png') 133 | plt.savefig(op_dir + '0.pdf') 134 | 135 | 136 | # test acc boxplot 137 | fig = plt.figure(1) 138 | fig.suptitle(exp_name + ' - Test Accuracy', fontsize=14) 139 | plt.boxplot([scores_all[ss].tolist() for ss in un_strats], labels=un_strats) 140 | plt.ylim([0, 100]) 141 | 142 | if save_data: 143 | plt.savefig(op_dir + '1.png') 144 | plt.savefig(op_dir + '1.pdf') 145 | 146 | 147 | # train accuracy over time 148 | fig = plt.figure(2) 149 | fig.suptitle(exp_name + ' - Train Accuracy', fontsize=14) 150 | for fid, ss in enumerate(un_strats): 151 | plt.plot(np.arange(len(train_acc[ss]))+1, train_acc[ss], label=op_strat_names[fid], color=col_p[fid]) 152 | plt.legend() 153 | plt.xlabel('Training Image') 154 | plt.ylabel('Average Accuracy') 155 | 156 | plt.gca().xaxis.set_major_locator(majorLocator) 157 | plt.gca().xaxis.set_major_formatter(majorFormatter) 158 | plt.gca().xaxis.set_minor_locator(minorLocator) 159 | 160 | plt.ylim([0, 100]) 161 | plt.xlim([1, train_inds.shape[0]]) 162 
| plt.show() 163 | 164 | if save_data: 165 | plt.savefig(op_dir + '2.png') 166 | plt.savefig(op_dir + '2.pdf') 167 | 168 | 169 | # test accuracy over time 170 | fig = plt.figure(3) 171 | fig.suptitle(exp_name + ' - Test Accuracy', fontsize=14) 172 | for fid, ss in enumerate(un_strats): 173 | plt.plot(np.arange(len(test_acc[ss]))+1, test_acc[ss], label=op_strat_names[fid], color=col_p[fid]) 174 | plt.legend() 175 | plt.ylim([0, 100]) 176 | plt.xlim([1, test_inds.shape[0]]) 177 | plt.show() 178 | 179 | if save_data: 180 | plt.savefig(op_dir + '3.png') 181 | plt.savefig(op_dir + '3.pdf') 182 | 183 | 184 | # average time against accuracy 185 | fig = plt.figure(4) 186 | fig.suptitle(exp_name + ' - Test Timings', fontsize=14) 187 | for fid, ss in enumerate(un_strats): 188 | time_diff = [np.mean(get_time_diff(uut)[test_inds]) for uut in user_times[ss]] # mean of workers 189 | plt.plot(time_diff, scores_all[ss], '.', label=op_strat_names[fid]) 190 | plt.legend() 191 | plt.ylim([0, 102]) 192 | plt.xlim(xmin=0) 193 | plt.xlabel('Mean time (seconds)') 194 | plt.ylabel('Test accuracy') 195 | 196 | if save_data: 197 | plt.savefig(op_dir + '4.png') 198 | plt.savefig(op_dir + '4.pdf') 199 | 200 | 201 | # average time by strategy 202 | fig = plt.figure(5) 203 | fig.suptitle(exp_name + ' - Test Timings', fontsize=14) 204 | mean_times = [] 205 | for fid, ss in enumerate(un_strats): 206 | time_diff = [np.mean(get_time_diff(uut)[test_inds]) for uut in user_times[ss]] # mean of workers 207 | mean_times.append(time_diff) 208 | 209 | plt.boxplot(mean_times, labels=un_strats) 210 | plt.ylabel('Test Time (seconds per image)') 211 | 212 | if save_data: 213 | plt.savefig(op_dir + '5.png') 214 | plt.savefig(op_dir + '5.pdf') 215 | 216 | 217 | # confusion matrices 218 | cms = [] 219 | for fid, ss in enumerate(un_strats): 220 | cm = confusion_matrix(test_gt[ss].ravel(), test_pred[ss].ravel()) 221 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 222 | cms.append(cm*100) 223 | 224 | if 
len(un_strats) == 4: 225 | fig, axes = plt.subplots(nrows=2, ncols=2, num=6) 226 | else: 227 | fig, axes = plt.subplots(nrows=1, ncols=2, num=6) 228 | 229 | fig.suptitle(exp_name + ' - Average Class Confusion', fontsize=14) 230 | for fid, ss in enumerate(un_strats): 231 | #for fid, ax in enumerate(axes.flat): 232 | ax = axes.flat[fid] 233 | im = ax.imshow(cms[fid], cmap='plasma', vmin=0, vmax=100.0) 234 | ax.set_yticks(np.arange(num_classes)) 235 | ax.grid('off') 236 | ax.set_title(op_strat_names[fid]) 237 | plt.tight_layout(rect=[0, 0, 1, 0.95]) 238 | cax = fig.add_axes([0.9, 0.1, 0.03, 0.75]) 239 | fig.colorbar(im, cax=cax) 240 | 241 | if save_data: 242 | plt.savefig(op_dir + '6.png') 243 | plt.savefig(op_dir + '6.pdf') 244 | 245 | 246 | plt.show() -------------------------------------------------------------------------------- /results/experiments/butterflies_crop/explain_1vall.strat: -------------------------------------------------------------------------------- 1 | {"display_text": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "num_train": 20, "positive_feature": [[0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]], "im_ids": [1133, 380, 71, 115, 355, 234, 1084, 1379, 23, 184, 1431, 182, 879, 978, 885, 488, 516, 1037, 941, 433], "display_explain_image": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "feat_ids": [[-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1]]} -------------------------------------------------------------------------------- /results/experiments/butterflies_crop/random.strat: -------------------------------------------------------------------------------- 1 | { 2 | "num_train": 20, 3 | "feat_dims_to_show": [], 4 | "num_feats_to_show": 0, 5 | "display_explain_image": 0, 6 | "display_text": 0 7 | } 8 | -------------------------------------------------------------------------------- 
/results/experiments/butterflies_crop/random_feedback.strat: -------------------------------------------------------------------------------- 1 | { 2 | "num_train": 20, 3 | "feat_dims_to_show": [0], 4 | "num_feats_to_show": 0, 5 | "display_explain_image": 1, 6 | "display_text": 0 7 | } 8 | -------------------------------------------------------------------------------- /results/experiments/butterflies_crop/settings.json: -------------------------------------------------------------------------------- 1 | {"scale": 2.0, "experiment_id": 17, "test_indices": [1479, 1480, 1481, 1482, 1483, 1484, 1485, 1486, 1487, 1488, 1489, 1490, 1491, 1492, 1493, 1494, 1495, 1496, 1497, 1498, 1499, 1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509, 1510, 1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538, 1539, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1570, 1571, 1572, 1573, 1574, 1575, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597, 1598, 1599, 1600, 1601, 1602, 1603, 1604, 1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1613, 1614, 1615, 1616, 1617, 1618, 1619, 1620, 1621, 1622, 1623, 1624, 1625, 1626, 1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634, 1635, 1636, 1637, 1638, 1639, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650, 1651, 1652, 1653, 1654, 1655, 1656, 1657, 1658, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1669, 1670, 1671, 1672, 1673, 1674, 1675, 1676, 1677, 1678, 1679, 1680, 1681, 1682, 1683, 1684, 1685, 1686, 1687, 1688, 1689, 1690, 1691, 1692, 1693, 1694, 1695, 1696, 1697, 1698, 1699, 1700, 1701, 1702, 1703, 1704, 1705, 1706, 1707, 1708, 1709, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1717, 
1718, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1726, 1727, 1728, 1729, 1730, 1731, 1732, 1733, 1734, 1735, 1736, 1737, 1738, 1739, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 1748, 1749, 1750, 1751, 1752, 1753, 1754, 1755, 1756, 1757, 1758, 1759, 1760, 1761, 1762, 1763, 1764, 1765, 1766, 1767, 1768, 1769, 1770, 1771, 1772, 1773, 1774, 1775, 1776, 1777, 1778, 1779, 1780, 1781, 1782, 1783, 1784, 1785, 1786, 1787, 1788, 1789, 1790, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1840, 1841, 1842, 1843, 1844, 1845, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1858, 1859, 1860, 1861, 1862, 1863, 1864], "train_indices": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 
241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 
641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 
1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059, 1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115, 1116, 1117, 1118, 1119, 1120, 1121, 1122, 1123, 1124, 1125, 1126, 1127, 1128, 1129, 1130, 1131, 1132, 1133, 1134, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143, 1144, 1145, 1146, 1147, 1148, 1149, 1150, 1151, 1152, 1153, 1154, 1155, 1156, 1157, 1158, 1159, 1160, 1161, 1162, 1163, 1164, 1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 1185, 1186, 1187, 1188, 1189, 1190, 1191, 1192, 1193, 1194, 1195, 1196, 1197, 1198, 1199, 1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1220, 1221, 1222, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237, 1238, 1239, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247, 1248, 1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1260, 1261, 1262, 1263, 1264, 1265, 1266, 1267, 1268, 1269, 1270, 1271, 1272, 1273, 1274, 1275, 1276, 1277, 1278, 1279, 1280, 1281, 1282, 1283, 1284, 1285, 1286, 1287, 1288, 1289, 1290, 1291, 1292, 1293, 1294, 1295, 1296, 1297, 1298, 1299, 1300, 1301, 1302, 1303, 1304, 1305, 1306, 1307, 1308, 1309, 1310, 1311, 1312, 1313, 1314, 1315, 1316, 1317, 1318, 1319, 1320, 1321, 1322, 1323, 1324, 1325, 1326, 1327, 1328, 1329, 1330, 1331, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1344, 1345, 1346, 1347, 1348, 1349, 1350, 1351, 1352, 1353, 1354, 1355, 1356, 1357, 1358, 1359, 1360, 1361, 1362, 1363, 1364, 1365, 1366, 
1367, 1368, 1369, 1370, 1371, 1372, 1373, 1374, 1375, 1376, 1377, 1378, 1379, 1380, 1381, 1382, 1383, 1384, 1385, 1386, 1387, 1388, 1389, 1390, 1391, 1392, 1393, 1394, 1395, 1396, 1397, 1398, 1399, 1400, 1401, 1402, 1403, 1404, 1405, 1406, 1407, 1408, 1409, 1410, 1411, 1412, 1413, 1414, 1415, 1416, 1417, 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428, 1429, 1430, 1431, 1432, 1433, 1434, 1435, 1436, 1437, 1438, 1439, 1440, 1441, 1442, 1443, 1444, 1445, 1446, 1447, 1448, 1449, 1450, 1451, 1452, 1453, 1454, 1455, 1456, 1457, 1458, 1459, 1460, 1461, 1462, 1463, 1464, 1465, 1466, 1467, 1468, 1469, 1470, 1471, 1472, 1473, 1474, 1475, 1476, 1477, 1478], "test_sequence": [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], "class_names": ["Cabbage White", "Monarch", "Queen", "Red Admiral", "Viceroy"]} -------------------------------------------------------------------------------- /results/experiments/butterflies_crop/strict_1vall.strat: -------------------------------------------------------------------------------- 1 | {"display_text": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "num_train": 20, "positive_feature": [[0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]], "im_ids": [458, 21, 988, 115, 355, 636, 1432, 182, 234, 1017, 191, 750, 380, 992, 1407, 311, 488, 1139, 1379, 978], "display_explain_image": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "feat_ids": [[-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1]]} -------------------------------------------------------------------------------- /results/experiments/chinese_chars/explain_1vall.strat: -------------------------------------------------------------------------------- 1 | {"display_text": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "num_train": 20, "positive_feature": [[0], [0], [0], [0], [0], [0], 
[0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]], "im_ids": [57, 355, 297, 220, 502, 457, 54, 12, 71, 441, 214, 197, 53, 194, 392, 58, 317, 331, 17, 29], "display_explain_image": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "feat_ids": [[-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1]]} 2 | -------------------------------------------------------------------------------- /results/experiments/chinese_chars/random.strat: -------------------------------------------------------------------------------- 1 | { 2 | "num_train": 20, 3 | "feat_dims_to_show": [], 4 | "num_feats_to_show": 0, 5 | "display_explain_image": 0, 6 | "display_text": 0 7 | } 8 | -------------------------------------------------------------------------------- /results/experiments/chinese_chars/random_feedback.strat: -------------------------------------------------------------------------------- 1 | { 2 | "num_train": 20, 3 | "feat_dims_to_show": [0], 4 | "num_feats_to_show": 0, 5 | "display_explain_image": 1, 6 | "display_text": 0 7 | } 8 | -------------------------------------------------------------------------------- /results/experiments/chinese_chars/settings.json: -------------------------------------------------------------------------------- 1 | {"scale": 2.0, "experiment_id": 22, "test_indices": [559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 
679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701], "train_indices": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 
395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558], "test_sequence": [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], "class_names": ["grass", "mound", "stem"]} 2 | -------------------------------------------------------------------------------- /results/experiments/chinese_chars/strict_1vall.strat: -------------------------------------------------------------------------------- 1 | {"display_text": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "num_train": 20, "positive_feature": [[0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]], "im_ids": [113, 139, 159, 297, 389, 24, 214, 153, 331, 229, 107, 517, 54, 12, 197, 502, 93, 271, 220, 364], "display_explain_image": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "feat_ids": [[-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1]]} -------------------------------------------------------------------------------- /results/experiments/chinese_chars_crowd/explain_1vall.strat: -------------------------------------------------------------------------------- 1 | {"display_text": [0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "num_train": 20, "positive_feature": [[0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]], "im_ids": [307, 154, 435, 174, 61, 318, 339, 391, 423, 64, 146, 0, 205, 189, 68, 362, 340, 252, 462, 395], "display_explain_image": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "feat_ids": [[-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1]]} -------------------------------------------------------------------------------- /results/experiments/chinese_chars_crowd/settings.json: -------------------------------------------------------------------------------- 1 | {"scale": 2.0, "experiment_id": 25, "test_indices": [527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669], "train_indices": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 
118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 
518, 519, 520, 521, 522, 523, 524, 525, 526], "test_sequence": [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], "class_names": ["grass", "mound", "stem"]} 2 | -------------------------------------------------------------------------------- /results/experiments/chinese_chars_crowd/strict_1vall.strat: -------------------------------------------------------------------------------- 1 | {"display_text": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "num_train": 20, "positive_feature": [[0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]], "im_ids": [433, 200, 147, 158, 391, 483, 143, 150, 91, 301, 426, 213, 6, 443, 169, 395, 333, 173, 353, 253], "display_explain_image": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "feat_ids": [[-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1]]} -------------------------------------------------------------------------------- /results/experiments/oct/explain_1vall.strat: -------------------------------------------------------------------------------- 1 | {"display_text": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "num_train": 20, "positive_feature": [[0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]], "im_ids": [303, 268, 387, 549, 771, 334, 540, 809, 718, 552, 756, 656, 58, 57, 204, 368, 613, 454, 528, 166], "display_explain_image": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "feat_ids": [[-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1]]} -------------------------------------------------------------------------------- /results/experiments/oct/random.strat: -------------------------------------------------------------------------------- 1 | { 2 | "num_train": 20, 3 | "feat_dims_to_show": [], 4 | 
"num_feats_to_show": 0, 5 | "display_explain_image": 0, 6 | "display_text": 0 7 | } 8 | -------------------------------------------------------------------------------- /results/experiments/oct/random_feedback.strat: -------------------------------------------------------------------------------- 1 | { 2 | "num_train": 20, 3 | "feat_dims_to_show": [0], 4 | "num_feats_to_show": 0, 5 | "display_explain_image": 1, 6 | "display_text": 0 7 | } 8 | -------------------------------------------------------------------------------- /results/experiments/oct/settings.json: -------------------------------------------------------------------------------- 1 | {"scale": 2.0, "experiment_id": 21, "test_indices": [849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059, 1060, 1061], "train_indices": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 
451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848], 
"test_sequence": [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], "class_names": ["Macular edema", "Normal drusen", "Subretinal fluid"]} 2 | -------------------------------------------------------------------------------- /results/experiments/oct/strict_1vall.strat: -------------------------------------------------------------------------------- 1 | {"display_text": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "num_train": 20, "positive_feature": [[0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0]], "im_ids": [15, 362, 8, 809, 606, 747, 148, 152, 20, 115, 85, 334, 771, 549, 228, 58, 656, 780, 278, 638], "display_explain_image": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "feat_ids": [[-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1], [-1]]} -------------------------------------------------------------------------------- /results/readme.md: -------------------------------------------------------------------------------- 1 | # Teaching Categories to Human Learners with Visual Explanations 2 | Code for recreating the plots in our CVPR 2018 paper. Each of the results files were obtained after running the web interface with the settings specified in the corresponding `experiments` folder. 3 | 4 | ## Reference 5 | If you find our work useful in your research please consider citing our paper. 6 | ``` 7 | @inproceedings{explainteachcvpr18, 8 | title = {Teaching Categories to Human Learners with Visual Explanations}, 9 | author = {Mac Aodha, Oisin and Su, Shihan and Chen, Yuxin and Perona, Pietro and Yue, Yisong}, 10 | booktitle = {CVPR}, 11 | year = {2018} 12 | } 13 | ``` --------------------------------------------------------------------------------