├── README.md
├── __init__.py
├── coco_voc.py
├── dataloader.py
├── extract_cls_oid.py
├── extract_det_oid.py
├── model
│   ├── __init__.py
│   ├── coco_voc.py
│   ├── eval_utils.py
│   ├── mil_vocab.pkl
│   ├── models.py
│   ├── readme.md
│   ├── resnet.py
│   ├── resnet_mil.py
│   ├── resnet_utils.py
│   ├── utils.py
│   └── vgg_mil.py
├── opts.py
├── scripts
│   ├── convert_tf2pth.sh
│   └── graphs
│       ├── events.out.tfevents.1525752019.jxgu
│       ├── events.out.tfevents.1525752072.jxgu
│       └── events.out.tfevents.1525752099.jxgu
├── test.py
├── test.sh
├── test_v2.py
├── train.py
├── train.sh
└── vocabs
    ├── vocab_train.pkl
    └── vocab_words.txt

/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch implementation of multiple-instance learning (MIL)
2 |
3 |
4 | ## Updates
5 |
6 | - [ ] Training/testing on MS COCO
7 | - [x] Testing on Open Images (object detection and classification)
8 | - [x] Testing on a single image
9 |
10 | ## About
11 | This repository contains the MIL implementation for the experiments in:
12 | ```
13 | https://github.com/s-gupta/visual-concepts
14 | ```
15 |
16 | If you want to use the original Caffe implementation, please follow the instructions below:
17 | ```bash
18 | git clone git@github.com:s-gupta/caffe.git code/caffe
19 | cd code/caffe
20 | git checkout mil
21 | make -j 16
22 | make pycaffe
23 | cd ../../
24 | ```
25 |
26 | Alternatively, you can use my converted PyTorch model, available from [Tencent Weiyun](https://share.weiyun.com/5TxJAM4) or [Google Drive](https://drive.google.com/open?id=1wgzA7giTKEsZpSJt-NB0JnpSJgK2unVF). The download contains:
27 | ```
28 | model/coco_valid1_eval.pkl
29 | model/mil.pth
30 | ```
31 |
32 | ## Test results
33 | Change the image URL in test.py to point at your own test image, then run:
34 | ```bash
35 | python test.py
36 | ```
37 |
38 | *(test image 0)*
39 |
40 | ```
41 | ['beach', 'dog', 'brown', 'standing', 'people', 'his', 'sandy', 'white', 'sitting', 'laying']
42 | [1.0, 0.62, 0.62, 0.5, 0.45, 0.43, 0.37, 0.36, 0.34, 0.28]
43 | ```
44 |
45 | *(test image 1)*
46 |
47 | ```
48 | ['cat', 'sink', 'sitting', 'black', 'bathroom', 'white', 'top', 'sits', 'counter', 'looking']
49 | [1.0, 0.68, 0.57, 0.56, 0.39, 0.35, 0.31, 0.26, 0.23, 0.22]
50 | ```
51 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/__init__.py
--------------------------------------------------------------------------------
/coco_voc.py:
--------------------------------------------------------------------------------
1 | pycoco/coco_voc.py
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | import cPickle
5 | import json
6 | import h5py
7 | import os
8 | import numpy as np
9 | import random
10 | import torch
11 | import cv2
12 | from torchvision import transforms as trn
13 | from multiprocessing.dummy import Pool
14 | import math
15 | import gc
16 |
17 | preprocess = trn.Compose([
18 |     #trn.ToTensor(),
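    # Standard ImageNet channel means/stds used by torchvision's pretrained
    # ResNets; inputs are RGB tensors already scaled to [0, 1] (get_batch
    # below divides the uint8 image by 255 before applying this transform).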
trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 20 | ]) 21 | 22 | preprocess_vgg16 = trn.Compose([ 23 | #trn.ToTensor(), 24 | trn.Normalize([123.680, 103.939, 116.779], [1.000, 1.000, 1.000]) 25 | ]) 26 | 27 | def upsample_image(im, sz): 28 | h = im.shape[0] 29 | w = im.shape[1] 30 | s = np.float(max(h, w)) 31 | #I_out = np.zeros((sz, sz, 3), dtype=np.float); 32 | #I = cv2.resize(im, None, None, fx=np.float(sz) / s, fy=np.float(sz) / s, interpolation=cv2.INTER_CUBIC); #INTER_CUBIC, INTER_LINEAR 33 | I = cv2.resize(im, (sz, sz), interpolation=cv2.INTER_LINEAR) 34 | SZ = I.shape; 35 | #I_out[0:I.shape[0], 0:I.shape[1], :] = I; 36 | return I, I, SZ 37 | 38 | def preprocess_vgg19_mil(Image): 39 | if len(Image.shape) == 2: 40 | Image = Image[:, :, np.newaxis] 41 | Image = np.concatenate((Image, Image, Image), axis=2) 42 | 43 | mean = np.array([[[103.939, 116.779, 123.68]]]); 44 | base_image_size = 565; 45 | Image = cv2.resize(np.transpose(Image, axes=(1, 2, 0)), (base_image_size, base_image_size), interpolation=cv2.INTER_CUBIC) 46 | Image_orig = Image.astype(np.float32, copy=True) 47 | Image_orig -= mean 48 | im = Image_orig 49 | #im, gr, grr = upsample_image(Image_orig, base_image_size) 50 | # im = cv2.resize(Image_orig, (base_image_size, base_image_size), interpolation=cv2.INTER_CUBIC) 51 | im = np.transpose(im, axes=(2, 0, 1)) 52 | im = im[np.newaxis, :, :, :] 53 | return im 54 | 55 | ''' 56 | Load data from h5 files 57 | ''' 58 | class DataLoader(): 59 | def reset_iterator(self, split): 60 | # if load files from directory, then reset the prefetch process 61 | self.iterators[split] = 0 62 | 63 | def get_vocab_size(self): 64 | return self.vocab_size 65 | 66 | def get_vocab(self): 67 | return self.ix_to_word 68 | 69 | def get_seq_length(self): 70 | return self.seq_length 71 | 72 | def __init__(self, opt): 73 | self.type = 'h5' 74 | self.opt = opt 75 | self.model = getattr(opt, 'model', 'resnet101') 76 | self.attrs_in = getattr(opt, 'attrs_in', 0) 77 | self.attrs_out = getattr(opt, 'attrs_out', 0) 78 | self.att_im = getattr(opt, 'att_im', 1) 79 | self.pre_ft = getattr(opt, 'pre_ft', 1) 80 | self.mil_vocab_outsize = 1000 81 | self.top_attrs = 10 82 | self.fc_feat_size = opt.fc_feat_size 83 | self.att_feat_size = opt.att_feat_size 84 | self.batch_size = opt.batch_size 85 | self.seq_per_img = opt.seq_per_img 86 | 87 | # load the json file which contains additional information about the dataset 88 | print('DataLoader loading json file: ', opt.input_json) 89 | self.info = json.load(open(self.opt.input_json)) 90 | self.mil_vocab = cPickle.load(open('model/mil_vocab.pkl')) 91 | self.ix_to_word = self.info['ix_to_word'] 92 | self.vocab_size = len(self.ix_to_word) 93 | print('vocab size is ', self.vocab_size) 94 | 95 | # open the hdf5 file 96 | print('DataLoader loading h5 file: ', opt.input_im_h5) 97 | self.h5_im_file = h5py.File(self.opt.input_im_h5) 98 | # extract image size from dataset 99 | images_size = self.h5_im_file['images'].shape 100 | assert len(images_size) == 4, 'images should be a 4D tensor' 101 | assert images_size[2] == images_size[3], 'width and height must match' 102 | self.num_images = images_size[0] 103 | self.num_channels = images_size[1] 104 | self.max_image_size = images_size[2] 105 | print('read %d images of size %dx%dx%d' %(self.num_images, 106 | self.num_channels, self.max_image_size, self.max_image_size)) 107 | # load in the sequence data 108 | self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core') 109 | seq_size = 
self.h5_label_file['labels'].shape 110 | self.seq_length = seq_size[1] 111 | semantic_attrs_size = self.h5_label_file['semantic_words'].shape 112 | self.semantic_attrs_length = semantic_attrs_size[1] 113 | print('max sequence length in data is', self.seq_length) 114 | print('max semantic words length in data is', self.semantic_attrs_length) 115 | # load the pointers in full to RAM (should be small enough) 116 | self.label_start_ix = self.h5_label_file['label_start_ix'][:] 117 | self.label_end_ix = self.h5_label_file['label_end_ix'][:] 118 | 119 | self.num_images = self.label_start_ix.shape[0] 120 | print('read %d image / features' %(self.num_images)) 121 | 122 | # separate out indexes for each of the provided splits 123 | self.split_ix = {'train': [], 'val': [], 'test': []} 124 | for ix in range(len(self.info['images'])): 125 | img = self.info['images'][ix] 126 | if img['split'] == 'train': 127 | self.split_ix['train'].append(ix) 128 | elif img['split'] == 'val': 129 | self.split_ix['val'].append(ix) 130 | elif img['split'] == 'test': 131 | self.split_ix['test'].append(ix) 132 | elif opt.train_only == 0: # restval 133 | self.split_ix['train'].append(ix) 134 | 135 | print('assigned %d images to split train' %len(self.split_ix['train'])) 136 | print('assigned %d images to split val' %len(self.split_ix['val'])) 137 | print('assigned %d images to split test' %len(self.split_ix['test'])) 138 | 139 | self.iterators = {'train': 0, 'val': 0, 'test': 0} 140 | 141 | def gen_mil_gt(self, attrs): 142 | mil_batch = np.zeros([1, self.mil_vocab_outsize], dtype='int') 143 | for k in range(len(attrs)): 144 | if attrs[k] > 0: 145 | for i in range(self.mil_vocab_outsize): 146 | if self.ix_to_word[str(attrs[k])] == self.mil_vocab[i]: 147 | mil_batch[0, i] = 1 148 | 149 | return mil_batch 150 | 151 | def get_batch(self, split, batch_size=None, seq_per_img=None): 152 | split_ix = self.split_ix[split] 153 | batch_size = batch_size or self.batch_size 154 | seq_per_img = seq_per_img or self.seq_per_img 155 | 156 | if 'vgg19' in self.model: 157 | img_batch = np.ndarray([batch_size, 3, 565, 565], dtype='float32') 158 | else: 159 | img_batch = np.ndarray([batch_size, 3, 224, 224], dtype='float32') 160 | label_batch = np.zeros([batch_size * self.seq_per_img, self.seq_length + 2], dtype = 'int') 161 | mask_batch = np.zeros([batch_size * self.seq_per_img, self.seq_length + 2], dtype = 'float32') 162 | attrs_batch = np.zeros([batch_size, self.top_attrs], dtype = 'int') 163 | mil_batch = np.zeros([batch_size, self.mil_vocab_outsize], dtype='int') 164 | max_index = len(split_ix) 165 | wrapped = False 166 | 167 | infos = [] 168 | gts = [] 169 | 170 | for i in range(batch_size): 171 | import time 172 | t_start = time.time() 173 | 174 | ri = self.iterators[split] 175 | ri_next = ri + 1 176 | if ri_next >= max_index: 177 | ri_next = 0 178 | wrapped = True 179 | self.iterators[split] = ri_next 180 | ix = split_ix[ri] 181 | 182 | #img = self.load_image(self.image_info[ix]['filename']) 183 | img = self.h5_im_file['images'][ix, :, :, :] 184 | if 'resnet' in self.model: 185 | img_batch[i] = preprocess(torch.from_numpy(img[:, 16:-16, 16:-16].astype('float32')/255.0)).numpy() 186 | else: 187 | #img_batch[i] = preprocess_vgg16(torch.from_numpy(img[:, 16:-16, 16:-16].astype('float32'))).numpy() 188 | img_batch[i] = preprocess_vgg19_mil(img) 189 | 190 | # fetch the semantic_attributes 191 | attrs_batch[i] = self.h5_label_file['semantic_words'][ix, : self.top_attrs] 192 | mil_batch[i] = self.gen_mil_gt(attrs_batch[i]) 193 | 194 | # fetch 
the sequence labels 195 | ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1 196 | ix2 = self.label_end_ix[ix] - 1 197 | ncap = ix2 - ix1 + 1 # number of captions available for this image 198 | assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t' 199 | 200 | # record associated info as well 201 | info_dict = {} 202 | info_dict['ix'] = ix 203 | info_dict['id'] = self.info['images'][ix]['id'] 204 | info_dict['file_path'] = self.info['images'][ix]['file_path'] 205 | infos.append(info_dict) 206 | 207 | # generate mask 208 | t_start = time.time() 209 | nonzeros = np.array(map(lambda x: (x != 0).sum()+2, label_batch)) 210 | for ix, row in enumerate(mask_batch): 211 | row[:nonzeros[ix]] = 1 212 | 213 | data = {} 214 | 215 | data['images'] = img_batch # if pre_ft is 1, then it equals None 216 | data['semantic_words'] = attrs_batch # if attributes is 1, then it equals None 217 | data['mil_label'] = mil_batch 218 | data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(split_ix), 'wrapped': wrapped} 219 | data['infos'] = infos 220 | 221 | gc.collect() 222 | 223 | return data -------------------------------------------------------------------------------- /extract_cls_oid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2017 The Open Images Authors. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | r"""Classifier inference utility. 18 | 19 | This code takes a resnet_v1_101 checkpoint, runs the classifier on the image and 20 | prints predictions in human-readable form. 21 | 22 | ------------------------------- 23 | Example command: 24 | ------------------------------- 25 | 26 | # 0. Create directory for model/data 27 | WORK_PATH="/tmp/oidv2" 28 | mkdir -p "${WORK_PATH}" 29 | cd "${WORK_PATH}" 30 | 31 | # 1. Download the model, inference code, and sample image 32 | wget https://storage.googleapis.com/openimages/2017_07/classes-trainable.txt 33 | wget https://storage.googleapis.com/openimages/2017_07/class-descriptions.csv 34 | wget https://storage.googleapis.com/openimages/2017_07/oidv2-resnet_v1_101.ckpt.tar.gz 35 | wget https://raw.githubusercontent.com/openimages/dataset/master/tools/classify_oidv2.py 36 | tar -xzf oidv2-resnet_v1_101.ckpt.tar.gz 37 | 38 | wget -O cat.jpg https://farm6.staticflickr.com/5470/9372235876_d7d69f1790_b.jpg 39 | 40 | # 2. 
Run inference 41 | python classify_oidv2.py \ 42 | --checkpoint_path='oidv2-resnet_v1_101.ckpt' \ 43 | --labelmap='classes-trainable.txt' \ 44 | --dict='class-descriptions.csv' \ 45 | --image="cat.jpg" \ 46 | --top_k=10 \ 47 | --score_threshold=0.3 48 | 49 | # Sample output: 50 | Image: "cat.jpg" 51 | 52 | 3272: /m/068hy - Pet (score = 0.96) 53 | 1076: /m/01yrx - Cat (score = 0.95) 54 | 0708: /m/01l7qd - Whiskers (score = 0.90) 55 | 4755: /m/0jbk - Animal (score = 0.90) 56 | 2847: /m/04rky - Mammal (score = 0.89) 57 | 2036: /m/0307l - Felidae (score = 0.79) 58 | 3574: /m/07k6w8 - Small to medium-sized cats (score = 0.77) 59 | 4799: /m/0k0pj - Nose (score = 0.70) 60 | 1495: /m/02cqfm - Close-up (score = 0.55) 61 | 0036: /m/012c9l - Domestic short-haired cat (score = 0.40) 62 | 63 | ------------------------------- 64 | Note on image preprocessing: 65 | ------------------------------- 66 | 67 | This is the code used to perform preprocessing: 68 | -------- 69 | from preprocessing import preprocessing_factory 70 | 71 | def PreprocessImage(image, network='resnet_v1_101', image_size=299): 72 | # If resolution is larger than 224 we need to adjust some internal resizing 73 | # parameters for vgg preprocessing. 74 | if any(network.startswith(x) for x in ['resnet', 'vgg']): 75 | preprocessing_kwargs = { 76 | 'resize_side_min': int(256 * image_size / 224), 77 | 'resize_side_max': int(512 * image_size / 224) 78 | } 79 | else: 80 | preprocessing_kwargs = {} 81 | preprocessing_fn = preprocessing_factory.get_preprocessing( 82 | name=network, is_training=False) 83 | 84 | height = image_size 85 | width = image_size 86 | image = preprocessing_fn(image, height, width, **preprocessing_kwargs) 87 | image.set_shape([height, width, 3]) 88 | return image 89 | -------- 90 | 91 | Note that there appears to be a small difference between the public version 92 | of slim image processing library and the internal version (which the meta 93 | graph is based on). Results that are very close, but not exactly identical to 94 | that of the metagraph. 95 | """ 96 | from __future__ import absolute_import 97 | from __future__ import division 98 | from __future__ import print_function 99 | import tensorflow as tf 100 | import argparse 101 | import random 102 | import numpy as np 103 | import time, os, sys 104 | import json 105 | import cv2 106 | import os 107 | import six.moves.urllib as urllib 108 | import sys 109 | import tarfile 110 | import zipfile 111 | 112 | from collections import defaultdict 113 | from io import StringIO 114 | from matplotlib import pyplot as plt 115 | from PIL import Image 116 | 117 | # This is needed since the notebook is stored in the object_detection folder. 118 | flags = tf.app.flags 119 | FLAGS = flags.FLAGS 120 | 121 | def load_image_ids(split_name): 122 | ''' Load a list of (path,image_id tuples). Modify this to suit your data locations. 
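    For example, under the MS COCO layout assumed above (filename shown for
    illustration only):
        load_image_ids('coco_val2014')
        # -> [('data/mscoco/val2014/COCO_val2014_000000391895.jpg', 391895), ...]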
''' 123 | split = [] 124 | if split_name == 'coco_test2014': 125 | with open('data/mscoco/annotations/image_info_test2014.json') as f: 126 | data = json.load(f) 127 | for item in data['images']: 128 | image_id = int(item['id']) 129 | filepath = os.path.join('data/mscoco/test2014/', item['file_name']) 130 | split.append((filepath,image_id)) 131 | elif split_name == 'coco_val2014': 132 | with open('data/mscoco/annotations/captions_val2014.json') as f: 133 | data = json.load(f) 134 | for item in data['images']: 135 | image_id = int(item['id']) 136 | filepath = os.path.join('data/mscoco/val2014/', item['file_name']) 137 | split.append((filepath,image_id)) 138 | elif split_name == 'coco_train2014': 139 | with open('data/mscoco/annotations/captions_train2014.json') as f: 140 | data = json.load(f) 141 | for item in data['images']: 142 | image_id = int(item['id']) 143 | filepath = os.path.join('data/mscoco/train2014/', item['file_name']) 144 | split.append((filepath,image_id)) 145 | elif split_name == 'coco_test2015': 146 | with open('data/mscoco/annotations/image_info_test2015.json') as f: 147 | data = json.load(f) 148 | for item in data['images']: 149 | image_id = int(item['id']) 150 | filepath = os.path.join('data/mscoco/test2015/', item['file_name']) 151 | split.append((filepath,image_id)) 152 | elif split_name == 'genome': 153 | with open('data/visualgenome/image_data.json') as f: 154 | for item in json.load(f): 155 | image_id = int(item['image_id']) 156 | filepath = os.path.join('data/visualgenome/', item['url'].split('rak248/')[-1]) 157 | split.append((filepath,image_id)) 158 | elif split_name == 'chinese': 159 | with open('data/aic_i2t/ai_challenger_caption_train_20170902/caption_train_annotations_20170902.json') as f: 160 | for item in json.load(f): 161 | image_id = item['image_id'] 162 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_train_20170902/caption_train_images_20170902', image_id) 163 | split.append((filepath,image_id)) 164 | elif split_name == 'chinese_val': 165 | with open('data/aic_i2t/ai_challenger_caption_validation_20170910/caption_validation_annotations_20170910.json') as f: 166 | for item in json.load(f): 167 | image_id = item['image_id'] 168 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_validation_20170910/caption_validation_images_20170910', image_id) 169 | split.append((filepath,image_id)) 170 | elif split_name == 'chinese_test1': 171 | with open('data/aic_i2t/ai_challenger_caption_test1_20170923/caption_test1_annotations_20170923.json') as f: 172 | for item in json.load(f): 173 | image_id = item['image_id'] 174 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_test1_20170923/caption_test1_images_20170923', image_id) 175 | split.append((filepath,image_id)) 176 | else: 177 | print('Unknown split') 178 | return split 179 | 180 | def get_classifications_from_im(args, g, sess, image_ids): 181 | save_dir = 'oid_data/' 182 | input_values = g.get_tensor_by_name('input_values:0') 183 | predictions = g.get_tensor_by_name('multi_predictions:0') 184 | count = 0 185 | for im_file, image_id in image_ids: 186 | compressed_image = tf.gfile.FastGFile(im_file, 'rb').read() 187 | predictions_eval = sess.run(predictions, feed_dict={input_values: [compressed_image]}) 188 | if 'chinese' in args.data_split: 189 | np.savez_compressed(save_dir + 'aic_i2t/oid_cls/' + str(image_id), feat=predictions_eval) 190 | else: 191 | np.savez_compressed(save_dir + 'mscoco/oid_cls/' + str(image_id), feat=predictions_eval) 192 | if (count % 100) == 0: 193 | 
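            # progress logging: count starts at 0, so this prints 1, 101, 201, ...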
print('{:d}'.format(count + 1)) 194 | count += 1 195 | 196 | def parse_args(): 197 | """ 198 | Parse input arguments 199 | """ 200 | parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network') 201 | parser.add_argument('--gpu', dest='gpu_id', help='GPU id(s) to use', default='0', type=str) 202 | parser.add_argument('--type', dest='type', help='', default='det', type=str) 203 | parser.add_argument('--def', dest='prototxt', help='prototxt file defining the network', default='../models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt', type=str) 204 | parser.add_argument('--net', dest='caffemodel', help='model to use', default='../data/faster_rcnn_models/resnet101_faster_rcnn_final.caffemodel', type=str) 205 | parser.add_argument('--out', dest='outfile', help='output filepath', default='karpathy_train_resnet101_faster_rcnn_genome', type=str) 206 | parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default='../experiments/cfgs/faster_rcnn_end2end_resnet.yml', type=str) 207 | parser.add_argument('--split', dest='data_split', help='dataset to use', default='coco_val2014', type=str) 208 | parser.add_argument('--checkpoint_path', dest='checkpoint_path', help='', default='model/oidv2_resnet_v1_101/oidv2-resnet_v1_101.ckpt', type=str) 209 | parser.add_argument('--set', dest='set_cfgs', help='set config keys', default=None, nargs=argparse.REMAINDER) 210 | 211 | args = parser.parse_args() 212 | return args 213 | 214 | 215 | def main(_): 216 | args = parse_args() 217 | 218 | print('Called with args:') 219 | print(args) 220 | g = tf.Graph() 221 | with g.as_default(): 222 | with tf.Session() as sess: 223 | saver = tf.train.import_meta_graph(args.checkpoint_path + '.meta') 224 | saver.restore(sess, args.checkpoint_path) 225 | image_ids = load_image_ids(args.data_split) 226 | get_classifications_from_im(args, g, sess, image_ids) 227 | 228 | if __name__ == '__main__': 229 | tf.app.run() 230 | -------------------------------------------------------------------------------- /extract_det_oid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2017 The Open Images Authors. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | r"""Classifier inference utility. 18 | 19 | This code takes a resnet_v1_101 checkpoint, runs the classifier on the image and 20 | prints predictions in human-readable form. 21 | 22 | ------------------------------- 23 | Example command: 24 | ------------------------------- 25 | 26 | # 0. Create directory for model/data 27 | WORK_PATH="/tmp/oidv2" 28 | mkdir -p "${WORK_PATH}" 29 | cd "${WORK_PATH}" 30 | 31 | # 1. 
Download the model, inference code, and sample image 32 | wget https://storage.googleapis.com/openimages/2017_07/classes-trainable.txt 33 | wget https://storage.googleapis.com/openimages/2017_07/class-descriptions.csv 34 | wget https://storage.googleapis.com/openimages/2017_07/oidv2-resnet_v1_101.ckpt.tar.gz 35 | wget https://raw.githubusercontent.com/openimages/dataset/master/tools/classify_oidv2.py 36 | tar -xzf oidv2-resnet_v1_101.ckpt.tar.gz 37 | 38 | wget -O cat.jpg https://farm6.staticflickr.com/5470/9372235876_d7d69f1790_b.jpg 39 | 40 | # 2. Run inference 41 | python classify_oidv2.py \ 42 | --checkpoint_path='oidv2-resnet_v1_101.ckpt' \ 43 | --labelmap='classes-trainable.txt' \ 44 | --dict='class-descriptions.csv' \ 45 | --image="cat.jpg" \ 46 | --top_k=10 \ 47 | --score_threshold=0.3 48 | 49 | # Sample output: 50 | Image: "cat.jpg" 51 | 52 | 3272: /m/068hy - Pet (score = 0.96) 53 | 1076: /m/01yrx - Cat (score = 0.95) 54 | 0708: /m/01l7qd - Whiskers (score = 0.90) 55 | 4755: /m/0jbk - Animal (score = 0.90) 56 | 2847: /m/04rky - Mammal (score = 0.89) 57 | 2036: /m/0307l - Felidae (score = 0.79) 58 | 3574: /m/07k6w8 - Small to medium-sized cats (score = 0.77) 59 | 4799: /m/0k0pj - Nose (score = 0.70) 60 | 1495: /m/02cqfm - Close-up (score = 0.55) 61 | 0036: /m/012c9l - Domestic short-haired cat (score = 0.40) 62 | 63 | ------------------------------- 64 | Note on image preprocessing: 65 | ------------------------------- 66 | 67 | This is the code used to perform preprocessing: 68 | -------- 69 | from preprocessing import preprocessing_factory 70 | 71 | def PreprocessImage(image, network='resnet_v1_101', image_size=299): 72 | # If resolution is larger than 224 we need to adjust some internal resizing 73 | # parameters for vgg preprocessing. 74 | if any(network.startswith(x) for x in ['resnet', 'vgg']): 75 | preprocessing_kwargs = { 76 | 'resize_side_min': int(256 * image_size / 224), 77 | 'resize_side_max': int(512 * image_size / 224) 78 | } 79 | else: 80 | preprocessing_kwargs = {} 81 | preprocessing_fn = preprocessing_factory.get_preprocessing( 82 | name=network, is_training=False) 83 | 84 | height = image_size 85 | width = image_size 86 | image = preprocessing_fn(image, height, width, **preprocessing_kwargs) 87 | image.set_shape([height, width, 3]) 88 | return image 89 | -------- 90 | 91 | Note that there appears to be a small difference between the public version 92 | of slim image processing library and the internal version (which the meta 93 | graph is based on). Results that are very close, but not exactly identical to 94 | that of the metagraph. 95 | """ 96 | from __future__ import absolute_import 97 | from __future__ import division 98 | from __future__ import print_function 99 | import tensorflow as tf 100 | import argparse 101 | import random 102 | import numpy as np 103 | import time, os, sys 104 | import json 105 | import cv2 106 | import os 107 | import six.moves.urllib as urllib 108 | import sys 109 | import tarfile 110 | import zipfile 111 | 112 | from collections import defaultdict 113 | from io import StringIO 114 | from matplotlib import pyplot as plt 115 | from PIL import Image 116 | 117 | # This is needed since the notebook is stored in the object_detection folder. 
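# The label_map_util / visualization_utils / ops imports below come from the
# TensorFlow object_detection package in the tensorflow/models research repo;
# the hard-coded path is machine-specific and must point at your own clone's
# 'research' directory for them to resolve.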
118 | sys.path.append("/home/jxgu/github/MIL.pytorch/misc/models/research") 119 | from utils import label_map_util 120 | from utils import visualization_utils as vis_util 121 | import utils.ops as utils_ops 122 | flags = tf.app.flags 123 | FLAGS = flags.FLAGS 124 | 125 | def load_image_ids(split_name): 126 | ''' Load a list of (path,image_id tuples). Modify this to suit your data locations. ''' 127 | split = [] 128 | if split_name == 'coco_test2014': 129 | with open('data/mscoco/annotations/image_info_test2014.json') as f: 130 | data = json.load(f) 131 | for item in data['images']: 132 | image_id = int(item['id']) 133 | filepath = os.path.join('data/mscoco/test2014/', item['file_name']) 134 | split.append((filepath,image_id)) 135 | elif split_name == 'coco_val2014': 136 | with open('data/mscoco/annotations/captions_val2014.json') as f: 137 | data = json.load(f) 138 | for item in data['images']: 139 | image_id = int(item['id']) 140 | filepath = os.path.join('data/mscoco/val2014/', item['file_name']) 141 | split.append((filepath,image_id)) 142 | elif split_name == 'coco_train2014': 143 | with open('data/mscoco/annotations/captions_train2014.json') as f: 144 | data = json.load(f) 145 | for item in data['images']: 146 | image_id = int(item['id']) 147 | filepath = os.path.join('data/mscoco/train2014/', item['file_name']) 148 | split.append((filepath,image_id)) 149 | elif split_name == 'coco_test2015': 150 | with open('data/mscoco/annotations/image_info_test2015.json') as f: 151 | data = json.load(f) 152 | for item in data['images']: 153 | image_id = int(item['id']) 154 | filepath = os.path.join('data/mscoco/test2015/', item['file_name']) 155 | split.append((filepath,image_id)) 156 | elif split_name == 'genome': 157 | with open('data/visualgenome/image_data.json') as f: 158 | for item in json.load(f): 159 | image_id = int(item['image_id']) 160 | filepath = os.path.join('data/visualgenome/', item['url'].split('rak248/')[-1]) 161 | split.append((filepath,image_id)) 162 | elif split_name == 'chinese': 163 | with open('data/aic_i2t/ai_challenger_caption_train_20170902/caption_train_annotations_20170902.json') as f: 164 | for item in json.load(f): 165 | image_id = item['image_id'] 166 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_train_20170902/caption_train_images_20170902', image_id) 167 | split.append((filepath,image_id)) 168 | elif split_name == 'chinese_val': 169 | with open('data/aic_i2t/ai_challenger_caption_validation_20170910/caption_validation_annotations_20170910.json') as f: 170 | for item in json.load(f): 171 | image_id = item['image_id'] 172 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_validation_20170910/caption_validation_images_20170910', image_id) 173 | split.append((filepath,image_id)) 174 | elif split_name == 'chinese_test1': 175 | with open('data/aic_i2t/ai_challenger_caption_test1_20170923/caption_test1_annotations_20170923.json') as f: 176 | for item in json.load(f): 177 | image_id = item['image_id'] 178 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_test1_20170923/caption_test1_images_20170923', image_id) 179 | split.append((filepath,image_id)) 180 | else: 181 | print('Unknown split') 182 | return split 183 | 184 | def run_inference_for_single_image(image, graph): 185 | with graph.as_default(): 186 | with tf.Session() as sess: 187 | # Get handles to input and output tensors 188 | ops = tf.get_default_graph().get_operations() 189 | all_tensor_names = {output.name for op in ops for output in op.outputs} 190 | tensor_dict = {} 191 | 
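      # Collect a handle for every detection output the frozen graph exposes;
      # 'detection_masks' is only present for instance-segmentation models.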
for key in [ 192 | 'num_detections', 'detection_boxes', 'detection_scores', 193 | 'detection_classes', 'detection_masks' 194 | ]: 195 | tensor_name = key + ':0' 196 | if tensor_name in all_tensor_names: 197 | tensor_dict[key] = tf.get_default_graph().get_tensor_by_name( 198 | tensor_name) 199 | if 'detection_masks' in tensor_dict: 200 | # The following processing is only for single image 201 | detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0]) 202 | detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0]) 203 | # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size. 204 | real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32) 205 | detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1]) 206 | detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1]) 207 | detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks( 208 | detection_masks, detection_boxes, image.shape[0], image.shape[1]) 209 | detection_masks_reframed = tf.cast(tf.greater(detection_masks_reframed, 0.2), tf.uint8) 210 | # Follow the convention by adding back the batch dimension 211 | tensor_dict['detection_masks'] = tf.expand_dims( 212 | detection_masks_reframed, 0) 213 | image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0') 214 | 215 | # Run inference 216 | output_dict = sess.run(tensor_dict, feed_dict={image_tensor: np.expand_dims(image, 0)}) 217 | 218 | # all outputs are float32 numpy arrays, so convert types as appropriate 219 | output_dict['num_detections'] = int(output_dict['num_detections'][0]) 220 | output_dict['detection_classes'] = output_dict[ 221 | 'detection_classes'][0].astype(np.uint8) 222 | output_dict['detection_boxes'] = output_dict['detection_boxes'][0] 223 | output_dict['detection_scores'] = output_dict['detection_scores'][0] 224 | if 'detection_masks' in output_dict: 225 | output_dict['detection_masks'] = output_dict['detection_masks'][0] 226 | return output_dict 227 | 228 | def get_detections_from_im(args, detection_graph, image_ids): 229 | save_dir = '/home/jxgu/github/MIL.pytorch/oid_data/' 230 | count = 0 231 | for im_file, image_id in image_ids: 232 | image = Image.open(im_file) 233 | # the array based representation of the image will be used later in order to prepare the 234 | # result image with boxes and labels on it. 235 | image_np = load_image_into_numpy_array(image) 236 | # Expand dimensions since the model expects images to have shape: [1, None, None, 3] 237 | image_np_expanded = np.expand_dims(image_np, axis=0) 238 | # Actual detection. 
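    # output_dict holds numpy arrays keyed by tensor name: 'num_detections'
    # (scalar), 'detection_boxes' (N x 4, normalized [ymin, xmin, ymax, xmax]),
    # 'detection_scores' (N,), 'detection_classes' (N, label-map ids), plus
    # 'detection_masks' when the model emits instance masks.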
239 | output_dict = run_inference_for_single_image(image_np, detection_graph) 240 | if 'chinese' in args.data_split: 241 | np.savez_compressed(save_dir + 'aic_i2t/oid_det/' + str(image_id), feat=output_dict) 242 | else: 243 | np.savez_compressed(save_dir + 'mscoco/oid_det/' + str(image_id), feat=output_dict) 244 | 245 | if (count % 100) == 0: 246 | print('{:d}'.format(count + 1)) 247 | count += 1 248 | 249 | def load_image_into_numpy_array(image): 250 | (im_width, im_height) = image.size 251 | return np.array(image.getdata()).reshape( 252 | (im_height, im_width, 3)).astype(np.uint8) 253 | 254 | def parse_args(): 255 | """ 256 | Parse input arguments 257 | """ 258 | parser = argparse.ArgumentParser(description='Generate bbox output from a Fast R-CNN network') 259 | parser.add_argument('--gpu', dest='gpu_id', help='GPU id(s) to use', default='0', type=str) 260 | parser.add_argument('--type', dest='type', help='', default='det', type=str) 261 | parser.add_argument('--def', dest='prototxt', help='prototxt file defining the network', default='../models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt', type=str) 262 | parser.add_argument('--net', dest='caffemodel', help='model to use', default='../../../../data/faster_rcnn_models/resnet101_faster_rcnn_final.caffemodel', type=str) 263 | parser.add_argument('--out', dest='outfile', help='output filepath', default='karpathy_train_resnet101_faster_rcnn_genome', type=str) 264 | parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default='../experiments/cfgs/faster_rcnn_end2end_resnet.yml', type=str) 265 | parser.add_argument('--split', dest='data_split', help='dataset to use', default='coco_train2014', type=str) 266 | parser.add_argument('--checkpoint_path', dest='checkpoint_path', help='', default='/home/jxgu/github/MIL.pytorch/model/faster_rcnn_inception_resnet_v2_atrous_oid_2018_01_28.tar.gz', type=str) 267 | parser.add_argument('--set', dest='set_cfgs', help='set config keys', default=None, nargs=argparse.REMAINDER) 268 | 269 | args = parser.parse_args() 270 | return args 271 | 272 | 273 | def main(_): 274 | args = parse_args() 275 | 276 | # What model to download. 
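    # NOTE: the frozen graph is extracted twice below, once from MODEL_FILE and
    # once from args.checkpoint_path; with the default arguments both name the
    # same archive, so one of the two extraction loops is redundant.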
277 | MODEL_NAME = '/home/jxgu/github/MIL.pytorch/model/faster_rcnn_inception_resnet_v2_atrous_oid_2018_01_28' 278 | MODEL_FILE = MODEL_NAME + '.tar.gz' 279 | PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 280 | tar_file = tarfile.open(MODEL_FILE) 281 | for file in tar_file.getmembers(): 282 | file_name = os.path.basename(file.name) 283 | if 'frozen_inference_graph.pb' in file_name: 284 | tar_file.extract(file, os.getcwd()) 285 | 286 | tar_file = tarfile.open(args.checkpoint_path) 287 | for file in tar_file.getmembers(): 288 | file_name = os.path.basename(file.name) 289 | if 'frozen_inference_graph.pb' in file_name: 290 | tar_file.extract(file, os.getcwd()) 291 | 292 | detection_graph = tf.Graph() 293 | with detection_graph.as_default(): 294 | od_graph_def = tf.GraphDef() 295 | with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 296 | serialized_graph = fid.read() 297 | od_graph_def.ParseFromString(serialized_graph) 298 | tf.import_graph_def(od_graph_def, name='') 299 | image_ids = load_image_ids(args.data_split) 300 | get_detections_from_im(args, detection_graph, image_ids) 301 | if __name__ == '__main__': 302 | tf.app.run() 303 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/model/__init__.py -------------------------------------------------------------------------------- /model/coco_voc.py: -------------------------------------------------------------------------------- 1 | ../pycoco/coco_voc.py -------------------------------------------------------------------------------- /model/eval_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | 9 | import numpy as np 10 | import json 11 | from json import encoder 12 | import random 13 | import string 14 | import time 15 | import os 16 | import sys 17 | import model.utils as utils 18 | 19 | def eval_split(opt, model, crit, loader, eval_kwargs={}): 20 | verbose = eval_kwargs.get('verbose', True) 21 | num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) 22 | split = eval_kwargs.get('split', 'val') 23 | dataset = eval_kwargs.get('dataset', 'coco') 24 | 25 | # Make sure in the evaluation mode 26 | model.eval() 27 | loader.reset_iterator(split) 28 | 29 | n = 0 30 | loss = 0 31 | loss_sum = 0 32 | loss_evals = 1e-8 33 | predictions = [] 34 | while True: 35 | data = loader.get_batch(split) 36 | n = n + loader.batch_size 37 | # forward the model to get loss 38 | tmp = [data['images'], data['mil_label']] 39 | tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp] 40 | images, mil_label = tmp 41 | 42 | loss = crit(model(images), mil_label).data[0] 43 | loss_sum = loss_sum + loss 44 | loss_evals = loss_evals + 1 45 | 46 | if data['bounds']['wrapped']: 47 | break 48 | if num_images >= 0 and n >= num_images: 49 | break 50 | 51 | # Switch back to training mode 52 | model.train() 53 | return loss_sum / loss_evals -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 
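# models.py collects the training-side factories: Criterion (a multi-label
# soft-margin loss over the 1000-word MIL vocabulary), build_mil (ResNet-101 or
# VGG MIL network, with optional resume from opt.start_from), and the
# build_optimizer / build_models / load_models glue used at training time.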
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import collections
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 | import time
10 | from PIL import Image
11 | import os
12 | import os.path as osp
13 | import sys
14 | import platform
15 |
16 | import urllib
17 | import cv2, numpy as np
18 | from scipy.interpolate import interp1d
19 | from matplotlib.pyplot import show
20 | import matplotlib.pyplot as plt
21 |
22 |
23 | import torch.optim as optim
24 | import torch.nn.functional as F
25 | import torch.backends.cudnn as cudnn
26 | import tensorflow as tf
27 | import torchvision
28 | import torchvision.transforms as transforms
29 | from six.moves import cPickle
30 | from model.vgg_mil import *
31 | from model.resnet_mil import *
32 | from model.utils import *
33 | import itertools
34 |
35 | class Criterion(nn.Module):
36 |     def __init__(self):
37 |         super(Criterion, self).__init__()
38 |         #self.loss0 = nn.MultiLabelMarginLoss()
39 |         self.loss0 = nn.MultiLabelSoftMarginLoss()
40 |         #self.loss0 = nn.MultiMarginLoss()
41 |         #self.loss0 = nn.CrossEntropyLoss()
42 |         #self.loss0 = nn.NLLLoss()
43 |
44 |         #self.loss1 = nn.MultiMarginLoss()
45 |
46 |     def forward(self, input, target):
47 |         output0 = self.loss0(input, target.float())
48 |         #output1 = self.loss1(input, target.long())
49 |         return output0
50 |
51 | def build_mil(opt):
52 |     opt.n_gpus = getattr(opt, 'n_gpus', 1)
53 |
54 |     if 'resnet101' in opt.model:
55 |         mil_model = resnet_mil(opt)
56 |     else:
57 |         mil_model = vgg_mil(opt)
58 |
59 |     if opt.n_gpus > 1:
60 |         print('Construct multi-gpu model ...')
61 |         model = nn.DataParallel(mil_model, device_ids=opt.gpus, dim=0)
62 |     else:
63 |         model = mil_model
64 |     # check compatibility if training is continued from a previously saved model
65 |     if len(opt.start_from) != 0:
66 |         # check that all necessary files exist
67 |         assert os.path.isdir(opt.start_from), "%s must be a path" % opt.start_from
68 |         lm_info_path = os.path.join(opt.start_from, os.path.basename(opt.start_from) + '.infos-best.pkl')
69 |         lm_pth_path = os.path.join(opt.start_from, os.path.basename(opt.start_from) + '.model-best.pth')
70 |         assert os.path.isfile(lm_info_path), "infos.pkl file does not exist in path %s" % opt.start_from
71 |         model.load_state_dict(torch.load(lm_pth_path))
72 |     model.cuda()
73 |     model.train()  # ensure training mode
74 |     return model
75 |
76 | def build_optimizer(opt, model, infos):
77 |     opt.pre_ft = getattr(opt, 'pre_ft', 1)
78 |
79 |     #model_parameters = itertools.ifilter(lambda p: p.requires_grad, model.parameters())
80 |     optimize = opt.optim
81 |     if optimize == 'adam':
82 |         optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
83 |     elif optimize == 'sgd':
84 |         optimizer = torch.optim.SGD(model.parameters(), lr=opt.learning_rate, momentum=0.999, weight_decay=0.0005)
85 |     elif optimize == 'Adadelta':
86 |         optimizer = torch.optim.Adadelta(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
87 |     elif optimize == 'Adagrad':
88 |         optimizer = torch.optim.Adagrad(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
89 |     elif optimize == 'Adamax':
90 |         optimizer = torch.optim.Adamax(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
91 |     elif optimize == 'ASGD':
92 |         optimizer = torch.optim.ASGD(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
93 |     elif optimize == 'LBFGS':
94 |         optimizer = torch.optim.LBFGS(model.parameters(),
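        # NOTE: torch.optim.LBFGS accepts no weight_decay argument, so this
        # branch raises a TypeError as written; drop weight_decay (or choose a
        # different optimizer) to actually use LBFGS.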
lr=opt.learning_rate, weight_decay=0.0005) 95 | elif optimize == 'RMSprop': 96 | optimizer = torch.optim.RMSprop(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005) 97 | 98 | infos['optimized'] = True 99 | 100 | # Load the optimizer 101 | if len(opt.start_from) != 0: 102 | if os.path.isfile(os.path.join(opt.start_from, opt.model_id + '.optimizer.pth')): 103 | optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, opt.model_id + '.optimizer.pth'))) 104 | 105 | return optimizer, infos 106 | 107 | def build_models(opt, infos): 108 | model = build_mil(opt) 109 | optimizer, infos = build_optimizer(opt, model, infos) 110 | crit = Criterion() # Training with RL, then add reward crit 111 | model.cuda() 112 | model.train() # Assure in training mode 113 | return model, crit, optimizer, infos 114 | 115 | def load_models(opt, infos): 116 | model = build_mil(opt) 117 | crit = Criterion(opt) 118 | return model, crit -------------------------------------------------------------------------------- /model/readme.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/model/readme.md -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | "3x3 convolution with padding" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, 
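        # 1x1 conv that expands back to planes * 4 (Bottleneck.expansion); the
        # '# change' comments in this file mark deviations from torchvision's
        # reference ResNet, e.g. stride-2 on this block's first 1x1 conv as in
        # the original Caffe ResNet definition.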
kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | def __init__(self, block, layers, num_classes=1000): 98 | self.inplanes = 64 99 | super(ResNet, self).__init__() 100 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 101 | bias=False) 102 | self.bn1 = nn.BatchNorm2d(64) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 109 | self.avgpool = nn.AvgPool2d(7) 110 | self.fc = nn.Linear(512 * block.expansion, num_classes) 111 | 112 | for m in self.modules(): 113 | if isinstance(m, nn.Conv2d): 114 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 115 | m.weight.data.normal_(0, math.sqrt(2. / n)) 116 | elif isinstance(m, nn.BatchNorm2d): 117 | m.weight.data.fill_(1) 118 | m.bias.data.zero_() 119 | 120 | def _make_layer(self, block, planes, blocks, stride=1): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | nn.Conv2d(self.inplanes, planes * block.expansion, 125 | kernel_size=1, stride=stride, bias=False), 126 | nn.BatchNorm2d(planes * block.expansion), 127 | ) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, downsample)) 131 | self.inplanes = planes * block.expansion 132 | for i in range(1, blocks): 133 | layers.append(block(self.inplanes, planes)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | x = self.maxpool(x) 142 | 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | 148 | x = self.avgpool(x) 149 | x = x.view(x.size(0), -1) 150 | x = self.fc(x) 151 | 152 | return x 153 | 154 | 155 | def resnet18(pretrained=False): 156 | """Constructs a ResNet-18 model. 157 | 158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | """ 161 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 162 | if pretrained: 163 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 164 | return model 165 | 166 | 167 | def resnet34(pretrained=False): 168 | """Constructs a ResNet-34 model. 169 | 170 | Args: 171 | pretrained (bool): If True, returns a model pre-trained on ImageNet 172 | """ 173 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 174 | if pretrained: 175 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 176 | return model 177 | 178 | 179 | def resnet50(pretrained=False): 180 | """Constructs a ResNet-50 model. 
181 | 182 | Args: 183 | pretrained (bool): If True, returns a model pre-trained on ImageNet 184 | """ 185 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 186 | if pretrained: 187 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 188 | return model 189 | 190 | 191 | def resnet101(pretrained=False): 192 | """Constructs a ResNet-101 model. 193 | 194 | Args: 195 | pretrained (bool): If True, returns a model pre-trained on ImageNet 196 | """ 197 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 198 | if pretrained: 199 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 200 | return model 201 | 202 | 203 | def resnet152(pretrained=False): 204 | """Constructs a ResNet-152 model. 205 | 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | """ 209 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 210 | if pretrained: 211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 212 | return model -------------------------------------------------------------------------------- /model/resnet_mil.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | 7 | class resnet_mil(nn.Module): 8 | def __init__(self, opt): 9 | super(resnet_mil, self).__init__() 10 | import model.resnet as resnet 11 | resnet = resnet.resnet101() 12 | resnet.load_state_dict(torch.load('/media/jxgu/d2tb/model/resnet/resnet101.pth')) 13 | self.conv = torch.nn.Sequential() 14 | self.conv.add_module("conv1", resnet.conv1) 15 | self.conv.add_module("bn1", resnet.bn1) 16 | self.conv.add_module("relu", resnet.relu) 17 | self.conv.add_module("maxpool", resnet.maxpool) 18 | self.conv.add_module("layer1", resnet.layer1) 19 | self.conv.add_module("layer2", resnet.layer2) 20 | self.conv.add_module("layer3", resnet.layer3) 21 | self.conv.add_module("layer4", resnet.layer4) 22 | self.l1 = nn.Sequential(nn.Linear(2048, 1000), 23 | nn.ReLU(True), 24 | nn.Dropout(0.5)) 25 | self.att_size = 7 26 | self.pool_mil = nn.MaxPool2d(kernel_size=self.att_size, stride=0) 27 | 28 | def forward(self, img, att_size=14): 29 | x0 = self.conv(img) 30 | x = self.pool_mil(x0) 31 | x = x.squeeze(2).squeeze(2) 32 | x = self.l1(x) 33 | x1 = torch.add(torch.mul(x.view(x.size(0), 1000, -1), -1), 1) 34 | cumprod = torch.cumprod(x1, 2) 35 | out = torch.max(x, torch.add(torch.mul(cumprod[:, :, -1], -1), 1)) 36 | return out 37 | -------------------------------------------------------------------------------- /model/resnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | 6 | 7 | class myResnet(nn.Module): 8 | def __init__(self, resnet): 9 | super(myResnet, self).__init__() 10 | self.resnet = resnet 11 | 12 | def forward(self, img, att_size=14): 13 | x = img.unsqueeze(0) 14 | 15 | x = self.resnet.conv1(x) 16 | x = self.resnet.bn1(x) 17 | x = self.resnet.relu(x) 18 | x = self.resnet.maxpool(x) 19 | 20 | x = self.resnet.layer1(x) 21 | x = self.resnet.layer2(x) 22 | x = self.resnet.layer3(x) 23 | x = self.resnet.layer4(x) 24 | 25 | fc = x.mean(3).mean(2) 26 | att = F.adaptive_avg_pool2d(x, [att_size, att_size]).squeeze().permute(1, 2, 0) 27 | 28 | return fc, att 29 | -------------------------------------------------------------------------------- 
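The pooling in `resnet_mil.forward` above is a noisy-OR style MIL combination: given per-instance word probabilities p_ij, the image-level score is the elementwise max of the strongest instance response and 1 - prod_j(1 - p_ij), with the full product obtained via `torch.cumprod(...)[:, :, -1]`. A minimal standalone sketch of that pooling (function and variable names are mine, for illustration only):

```python
import torch

def noisy_or_mil_pool(p):
    """MIL pooling over instance probabilities.

    p: (batch, num_words, num_instances) tensor of per-instance word
    probabilities. Returns (batch, num_words) image-level probabilities:
    the elementwise max of the strongest instance and the noisy-OR score
    1 - prod_j(1 - p_j), mirroring resnet_mil.forward / vgg_mil.forward.
    """
    max_prob = p.max(dim=2)[0]                    # strongest single instance
    noisy_or = 1.0 - torch.prod(1.0 - p, dim=2)   # 1 - prod_j(1 - p_j)
    return torch.max(max_prob, noisy_or)

# toy check: one confident instance dominates the pooled score
p = torch.tensor([[[0.1, 0.9, 0.2]]])             # batch=1, one word, 3 instances
print(noisy_or_mil_pool(p))                       # tensor([[0.9280]]) = 1 - 0.9*0.1*0.8
```

In `resnet_mil` the instance dimension has already been collapsed by the 7x7 max pool, so the product runs over a single instance; in the VGG variant the product runs over the full spatial response map.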
/model/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import collections 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | import time 10 | import Image 11 | import os 12 | import os.path as osp 13 | import sys 14 | import platform 15 | import cPickle 16 | import urllib 17 | import cv2, numpy as np 18 | from scipy.interpolate import interp1d 19 | from matplotlib.pyplot import show 20 | import matplotlib.pyplot as plt 21 | import torch 22 | import torch.nn as nn 23 | import torch.optim as optim 24 | import torch.nn.functional as F 25 | import torch.backends.cudnn as cudnn 26 | import tensorflow as tf 27 | import torchvision 28 | import torchvision.transforms as transforms 29 | from six.moves import cPickle 30 | from model.vgg_mil import * 31 | from model.resnet_mil import * 32 | import itertools 33 | 34 | ''' 35 | -- Learning rate annealing schedule. We will build a new optimizer for 36 | -- each epoch. 37 | -- 38 | -- By default we follow a known recipe for a 55-epoch training. If 39 | -- the learningRate command-line parameter has been specified, though, 40 | -- we trust the user is doing something manual, and will use her 41 | -- exact settings for all optimization. 42 | -- 43 | -- Return values: 44 | -- diff to apply to optimState, 45 | -- true IFF this is the first epoch of a new regime 46 | ''' 47 | def set_lr(optimizer, lr): 48 | for group in optimizer.param_groups: 49 | group['lr'] = lr 50 | 51 | def set_weightDecay(optimizer, weightDecay): 52 | for group in optimizer.param_groups: 53 | group['weight_decay'] = weightDecay 54 | 55 | def update_lr(opt, epoch, optimizer): 56 | if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: 57 | frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every 58 | decay_factor = opt.learning_rate_decay_rate ** frac 59 | opt.current_lr = opt.learning_rate * decay_factor 60 | set_lr(optimizer, opt.current_lr) # set the decayed rate 61 | else: 62 | opt.current_lr = opt.learning_rate 63 | 64 | def paramsForEpoch(opt, epoch, optimizer): 65 | # start, end, LR, WD, 66 | regimes = [ [ 0, 5, 1e-2, 5e-4, ], 67 | [ 6, 10, 5e-3, 5e-4 ], 68 | [ 11, 20, 1e-3, 0 ], 69 | [ 31, 30, 5e-4, 0 ], 70 | [ 31, 1e8, 1e-4, 0 ]] 71 | for row in regimes: 72 | if epoch>=row[0] and epoch<= row[1]: 73 | learningRate = row[2] 74 | weightDecay = row[3] 75 | opt.learning_rate = learningRate 76 | opt.weight_decay = weightDecay 77 | set_lr(optimizer, learningRate) 78 | set_weightDecay(optimizer, weightDecay) 79 | 80 | def add_summary_value(writer, key, value, iteration): 81 | summary = tf.Summary(value=[tf.Summary.Value(tag=key, simple_value=value)]) 82 | writer.add_summary(summary, iteration) 83 | 84 | def history_infos(opt): 85 | infos = {} 86 | if len(opt.start_from) != 0: # open old infos and check if models are compatible 87 | model_id = opt.start_from 88 | infos_id = model_id.replace('save/', '') + '.infos-best.pkl' 89 | with open(os.path.join(opt.start_from, infos_id)) as f: 90 | infos = cPickle.load(f) 91 | saved_model_opt = infos['opt'] 92 | 93 | iteration = infos.get('iter', 0) 94 | epoch = infos.get('epoch', 0) 95 | val_result_history = infos.get('val_result_history', {}) 96 | loss_history = infos.get('loss_history', {}) 97 | lr_history = infos.get('lr_history', {}) 98 | best_val_score = infos.get('best_val_score', None) if 
opt.load_best_score == 1 else 0 99 | val_loss = 0.0 100 | val_history = [val_result_history, best_val_score, val_loss] 101 | train_history = [loss_history, lr_history] 102 | return opt, infos, iteration, epoch, val_history, train_history 103 | 104 | 105 | def add_path(path): 106 | if path not in sys.path: 107 | sys.path.insert(0, path) 108 | print('added {}'.format(path)) 109 | 110 | def save_variables(pickle_file_name, var, info, overwrite=False): 111 | if os.path.exists(pickle_file_name) and overwrite == False: 112 | raise Exception('{:s} exists and over write is false.'.format(pickle_file_name)) 113 | # Construct the dictionary 114 | assert (type(var) == list); 115 | assert (type(info) == list); 116 | d = {} 117 | for i in xrange(len(var)): 118 | d[info[i]] = var[i] 119 | with open(pickle_file_name, 'wb') as f: 120 | cPickle.dump(d, f, cPickle.HIGHEST_PROTOCOL) 121 | 122 | 123 | def load_variables(pickle_file_name): 124 | # d is a dictionary of variables stored in the pickle file. 125 | if os.path.exists(pickle_file_name): 126 | with open(pickle_file_name, 'rb') as f: 127 | d = cPickle.load(f) 128 | return d 129 | else: 130 | raise Exception('{:s} does not exists.'.format(pickle_file_name)) 131 | 132 | def to_contiguous(tensor): 133 | if tensor.is_contiguous(): 134 | return tensor 135 | else: 136 | return tensor.contiguous() 137 | 138 | def clip_gradient(optimizer, grad_clip): 139 | for group in optimizer.param_groups: 140 | for param in group['params']: 141 | param.grad.data.clamp_(-grad_clip, grad_clip) 142 | 143 | # METHOD #1: OpenCV, NumPy, and urllib 144 | def url_to_image(url): 145 | # download the image, convert it to a NumPy array, and then read 146 | # it into OpenCV format 147 | resp = urllib.urlopen(url) 148 | image = np.asarray(bytearray(resp.read()), dtype="uint8") 149 | image = cv2.imdecode(image, cv2.IMREAD_COLOR) 150 | 151 | # return the image 152 | return image 153 | 154 | def upsample_image(im, sz): 155 | h = im.shape[0] 156 | w = im.shape[1] 157 | s = np.float(max(h, w)) 158 | I_out = np.zeros((sz, sz, 3), dtype=np.float); 159 | I = cv2.resize(im, None, None, fx=np.float(sz) / s, fy=np.float(sz) / s, interpolation=cv2.INTER_LINEAR); 160 | SZ = I.shape; 161 | I_out[0:I.shape[0], 0:I.shape[1], :] = I; 162 | return I_out, I, SZ 163 | 164 | def compute_precision_score_mapping_torch(thresh, prec, score): 165 | thresh, ind_thresh = torch.sort(torch.from_numpy(thresh), 0, descending=False) 166 | 167 | prec, ind_prec = torch.sort(torch.from_numpy(prec), 0, descending=False) 168 | val = None 169 | return val 170 | 171 | def compute_precision_mapping(pt): 172 | thresh_all = [] 173 | prec_all = [] 174 | for jj in xrange(1000): 175 | thresh = pt['details']['score'][:, jj] 176 | prec = pt['details']['precision'][:, jj] 177 | ind = np.argsort(thresh); # thresh, ind = torch.sort(thresh) 178 | thresh = thresh[ind]; 179 | indexes = np.unique(thresh, return_index=True)[1] 180 | indexes = np.sort(indexes); 181 | thresh = thresh[indexes] 182 | 183 | thresh = np.vstack((min(-1000, min(thresh) - 1), thresh[:, np.newaxis], max(1000, max(thresh) + 1))); 184 | 185 | prec = prec[ind]; 186 | for i in xrange(1, len(prec)): 187 | prec[i] = max(prec[i], prec[i - 1]); 188 | prec = prec[indexes] 189 | 190 | prec = np.vstack((prec[0], prec[:, np.newaxis], prec[-1])); 191 | thresh_all.append(thresh) 192 | prec_all.append(prec) 193 | precision_score = {'thresh': thresh_all, "prec": prec_all} 194 | return precision_score 195 | 196 | def compute_precision_score_mapping(thresh, prec, score): 197 | ind = 
np.argsort(thresh); # thresh, ind = torch.sort(thresh) 198 | thresh = thresh[ind]; 199 | indexes = np.unique(thresh, return_index=True)[1] 200 | indexes = np.sort(indexes); 201 | thresh = thresh[indexes] 202 | 203 | thresh = np.vstack((min(-1000, min(thresh) - 1), thresh[:, np.newaxis], max(1000, max(thresh) + 1))); 204 | 205 | prec = prec[ind]; 206 | for i in xrange(1, len(prec)): 207 | prec[i] = max(prec[i], prec[i - 1]); 208 | prec = prec[indexes] 209 | 210 | prec = np.vstack((prec[0], prec[:, np.newaxis], prec[-1])); 211 | 212 | f = interp1d(thresh[:, 0], prec[:, 0]) 213 | val = f(score) 214 | return val 215 | 216 | def load_vocabulary(): 217 | # Load the vocabulary 218 | vocab_file = os.getcwd()+'/vocabs/vocab_train.pkl' 219 | vocab = load_variables(vocab_file) 220 | 221 | # define functional words 222 | functional_words = ['a', 'on', 'of', 'the', 'in', 'with', 'and', 'is', 'to', 'an', 'two', 'at', 'next', 'are'] 223 | is_functional = np.array([x not in functional_words for x in vocab['words']]) 224 | 225 | # load the score precision mapping file 226 | eval_file = os.getcwd()+'/model/coco_valid1_eval.pkl' 227 | pt = load_variables(eval_file) 228 | return vocab, functional_words, is_functional, pt 229 | 230 | def tic_toc_print(interval, string): 231 | global tic_toc_print_time_old 232 | if 'tic_toc_print_time_old' not in globals(): 233 | tic_toc_print_time_old = time.time() 234 | print(string) 235 | else: 236 | new_time = time.time() 237 | if new_time - tic_toc_print_time_old > interval: 238 | tic_toc_print_time_old = new_time; 239 | print(string) 240 | 241 | -------------------------------------------------------------------------------- /model/vgg_mil.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from torch.autograd import Variable 5 | from torch.utils.serialization import load_lua 6 | import torch.nn.functional as F 7 | 8 | class vgg_mil(nn.Module): 9 | def __init__(self, opt): 10 | super(vgg_mil, self).__init__() 11 | self.conv = torch.nn.Sequential() 12 | self.conv.add_module("conv1_1", nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)) 13 | self.conv.add_module("relu_1_1", torch.nn.ReLU()) 14 | self.conv.add_module("conv1_2", nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)) 15 | self.conv.add_module("relu_1_2", torch.nn.ReLU()) 16 | self.conv.add_module("maxpool_1", torch.nn.MaxPool2d(kernel_size=2)) 17 | 18 | self.conv.add_module("conv2_1", nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)) 19 | self.conv.add_module("relu_2_1", torch.nn.ReLU()) 20 | self.conv.add_module("conv2_2", nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)) 21 | self.conv.add_module("relu_2_2", torch.nn.ReLU()) 22 | self.conv.add_module("maxpool_2", torch.nn.MaxPool2d(kernel_size=2)) 23 | 24 | self.conv.add_module("conv3_1", nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)) 25 | self.conv.add_module("relu_3_1", torch.nn.ReLU()) 26 | self.conv.add_module("conv3_2", nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)) 27 | self.conv.add_module("relu_3_2", torch.nn.ReLU()) 28 | self.conv.add_module("conv3_3", nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)) 29 | self.conv.add_module("relu_3_3", torch.nn.ReLU()) 30 | self.conv.add_module("maxpool_3", torch.nn.MaxPool2d(kernel_size=2)) 31 | 32 | self.conv.add_module("conv4_1", nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)) 33 | self.conv.add_module("relu_4_1", torch.nn.ReLU()) 34 | self.conv.add_module("conv4_2", 
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)) 35 | self.conv.add_module("relu_4_2", torch.nn.ReLU()) 36 | self.conv.add_module("conv4_3", nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)) 37 | self.conv.add_module("relu_4_3", torch.nn.ReLU()) 38 | self.conv.add_module("maxpool_4", torch.nn.MaxPool2d(kernel_size=2)) 39 | 40 | self.conv.add_module("conv5_1", nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)) 41 | self.conv.add_module("relu_5_1", torch.nn.ReLU()) 42 | self.conv.add_module("conv5_2", nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)) 43 | self.conv.add_module("relu_5_2", torch.nn.ReLU()) 44 | self.conv.add_module("conv5_3", nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)) 45 | self.conv.add_module("relu_5_3", torch.nn.ReLU()) 46 | self.conv.add_module("maxpool_5", torch.nn.MaxPool2d(kernel_size=2)) 47 | 48 | self.conv.add_module("fc6_conv", nn.Conv2d(512, 4096, kernel_size=7, stride=1, padding=0)) 49 | self.conv.add_module("relu_6_1", torch.nn.ReLU()) 50 | 51 | self.conv.add_module("fc7_conv", nn.Conv2d(4096, 4096, kernel_size=1, stride=1, padding=0)) 52 | self.conv.add_module("relu_7_1", torch.nn.ReLU()) 53 | 54 | self.conv.add_module("fc8_conv", nn.Conv2d(4096, 1000, kernel_size=1, stride=1, padding=0)) 55 | self.conv.add_module("sigmoid_8", torch.nn.Sigmoid()) # per-region word probabilities 56 | 57 | self.pool_mil = nn.MaxPool2d(kernel_size=11, stride=11) # global max over the 11x11 response map (a 565x565 input yields 11x11 here) 58 | 59 | self.weight_init() 60 | 61 | def weight_init(self): 62 | self.cnn_weight = 'model/vgg16_full_conv_mil.pth' 63 | self.conv.load_state_dict(torch.load(self.cnn_weight)) 64 | print("Load pretrained CNN model from " + self.cnn_weight) 65 | 66 | def forward(self, x): 67 | x0 = self.conv.forward(x.float()) # (N, 1000, 11, 11) per-region word probabilities 68 | x = self.pool_mil(x0) # max over regions 69 | x = x.squeeze(2).squeeze(2) 70 | x1 = torch.add(torch.mul(x0.view(x.size(0), 1000, -1), -1), 1) # 1 - p_i for every region i 71 | cumprod = torch.cumprod(x1, 2) # running product of (1 - p_i) 72 | out = torch.max(x, torch.add(torch.mul(cumprod[:, :, -1], -1), 1)) # noisy-OR 1 - prod(1 - p_i), floored by the max 73 | out = F.softmax(out) # legacy call signature; newer PyTorch expects an explicit dim=1 74 | return out 75 | 76 | class MIL_Precision_Score_Mapping(nn.Module): 77 | def __init__(self): 78 | super(MIL_Precision_Score_Mapping, self).__init__() 79 | self.mil = nn.MaxPool2d(kernel_size=11, stride=11) 80 | 81 | def forward(self, x, score, precision, mil_prob): 82 | out = self.mil(x) # NOTE: score, precision and mil_prob are currently unused 83 | return out 84 | 85 | 
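Added illustration (not part of the repository): a small numeric check of the noisy-OR pooling in `forward()` above, written against a newer PyTorch API (`torch.tensor` postdates the code in this repo). For each word, the final score is the larger of the strongest single region and 1 - prod(1 - p_i) over all regions.

```python
import torch

# 1 image, 2 words, 4 spatial regions of per-region probabilities
p = torch.tensor([[[0.1, 0.2, 0.0, 0.4],
                   [0.9, 0.0, 0.0, 0.0]]])
max_pool = p.max(2)[0]                             # strongest single region
noisy_or = 1 - torch.cumprod(1 - p, 2)[:, :, -1]   # 1 - prod_i (1 - p_i)
print(torch.max(max_pool, noisy_or))               # tensor([[0.5680, 0.9000]])
```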
-------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | 5 | 6 | def parse_opt(): 7 | parser = argparse.ArgumentParser() 8 | # RL setting 9 | parser.add_argument('--model', type=str, default='vgg19') 10 | parser.add_argument('--learning', type=str, default='mle') 11 | parser.add_argument('--start_from', type=str, default='') 12 | # Data input settings 13 | parser.add_argument('--fc_feat_size', type=int, default=2048) # 2048 for resnet, 4096 for vgg 14 | parser.add_argument('--att_feat_size', type=int, default=2048) # 2048 for resnet, 512 for vgg 15 | # Optimization: General 16 | parser.add_argument('--max_epochs', type=int, default=-1) # number of epochs (-1 = no limit) 17 | parser.add_argument('--batch_size', type=int, default=2) # minibatch size 18 | parser.add_argument('--seq_per_img', type=int, default=5) # number of captions to sample for each image during training 19 | # Optimization: for the Language Model 20 | parser.add_argument('--optim', type=str, default='adam') # rmsprop|sgd|sgdmom|adagrad|adam 21 | parser.add_argument('--learning_rate', type=float, default=4e-4) # learning rate 22 | parser.add_argument('--learning_rate_decay_start', type=int, default=0) # at what epoch to start decaying the learning rate (-1 = don't) 23 | parser.add_argument('--learning_rate_decay_every', type=int, default=5000) # decay the learning rate every this many epochs 24 | parser.add_argument('--learning_rate_decay_rate', type=float, default=0.8) # multiplicative factor applied at each decay step 25 | parser.add_argument('--optim_alpha', type=float, default=0.8) # alpha for adam 26 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M') 27 | parser.add_argument('--optim_beta', type=float, default=0.999) # beta used for adam 28 | parser.add_argument('--optim_epsilon', type=float, default=1e-8) # epsilon that goes into denominator for smoothing 29 | parser.add_argument('--weight_decay', type=float, default=1e-4) # weight_decay 30 | parser.add_argument('--grad_clip', type=float, default=0.1) # clip gradients at this value 31 | parser.add_argument('--drop_prob_lm', type=float, default=0.5) # strength of dropout in the Language Model RNN 32 | # Datasets 33 | parser.add_argument('--input_json', type=str, default='data/mscoco/cocotalk_karpathy.json') 34 | parser.add_argument('--input_im_h5', type=str, default='data/mscoco/cocotalk_karpathy.h5') 35 | parser.add_argument('--input_label_h5', type=str, default='data/mscoco/cocotalk_karpathy_label_semantic_words.h5') 36 | # Evaluation/Checkpointing 37 | parser.add_argument('--split', type=str, default='train') # dataset split type 38 | parser.add_argument('--val_images_use', type=int, default=5000) # number of images for periodic validation (-1 = all) 39 | parser.add_argument('--save_checkpoint_every', type=int, default=100) 40 | parser.add_argument('--checkpoint_path', type=str, default='save') # directory to store checkpointed models 41 | parser.add_argument('--losses_log_every', type=int, default=25) # how often we snapshot losses (0 = disable) 42 | parser.add_argument('--load_best_score', type=int, default=1) # load previous best score when resuming training 43 | # misc 44 | parser.add_argument('--n_gpus', type=int, default=1) 45 | parser.add_argument('--train_only', type=int, default=0) # if true then use 80k, else use 110k 46 | parser.add_argument('--gpus', default=[0, 1], nargs='+', type=int) # use CUDA on the listed devices 47 | parser.add_argument('--model_id', type=str, default='') # Id identifying this run/job. 
48 | # used in cross-val and appended when writing progress files 49 | 50 | args = parser.parse_args() 51 | # Check if args are valid 52 | assert args.batch_size > 0, "batch_size should be greater than 0" 53 | assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1" 54 | assert args.seq_per_img > 0, "seq_per_img should be greater than 0" 55 | assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0" 56 | assert args.losses_log_every > 0, "losses_log_every should be greater than 0" 57 | assert args.load_best_score == 0 or args.load_best_score == 1, "load_best_score should be 0 or 1" 58 | 59 | # Update args 60 | args.gpus = range(args.n_gpus) # NOTE: this overrides any --gpus list passed on the command line 61 | last_name = os.path.basename(args.start_from) 62 | last_time = last_name[0:8] # timestamp prefix of a previous run, if resuming 63 | if len(args.start_from): 64 | args.model_id = last_name 65 | else: 66 | args.model_id = datetime.datetime.now().strftime("%m%d%H%M") + "_mil_" + args.model + '_' + args.learning 67 | args.checkpoint_path = args.checkpoint_path + '/' + args.model_id 68 | return args 69 | -------------------------------------------------------------------------------- /scripts/convert_tf2pth.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # pip install mmdnn 3 | convert_tf_2_pytorch() 4 | { 5 | if [ ! -d tmp ]; then 6 | mkdir tmp 7 | fi 8 | # Extract tf model files 9 | python -m mmdnn.conversion.examples.tensorflow.extract_model -n resnet_v1_101 -ckpt /home/jxgu/github/MIL.pytorch/model/oidv2_resnet_v1_101/oidv2-resnet_v1_101.ckpt 10 | # Convert tf to IR 11 | python -m mmdnn.conversion._script.convertToIR -f tensorflow -d kit_imagenet -n imagenet_resnet_v1_101.ckpt.meta --dstNodeName Squeeze -w imagenet_resnet_v1_101.ckpt 12 | # Convert IR to PyTorch 13 | python -m mmdnn.conversion._script.IRToCode -f pytorch --IRModelPath kit_imagenet.pb --dstModelPath kit_imagenet.py --IRWeightPath kit_imagenet.npy -dw kit_pytorch.npy 14 | # Dump the PyTorch model 15 | python -m mmdnn.conversion.examples.pytorch.imagenet_test --dump resnet.pth -n kit_imagenet.py -w tmp/kit_pytorch.npy 16 | } 17 | 18 | convert_tf_2_pytorch -------------------------------------------------------------------------------- /scripts/graphs/events.out.tfevents.1525752019.jxgu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/scripts/graphs/events.out.tfevents.1525752019.jxgu -------------------------------------------------------------------------------- /scripts/graphs/events.out.tfevents.1525752072.jxgu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/scripts/graphs/events.out.tfevents.1525752072.jxgu -------------------------------------------------------------------------------- /scripts/graphs/events.out.tfevents.1525752099.jxgu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/scripts/graphs/events.out.tfevents.1525752099.jxgu -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """Set up paths.""" 2 | import os 3 | import os.path as osp 4 | import sys 5 | import platform 6 | import cPickle 7 | import 
cv2, numpy as np 8 | from matplotlib.pyplot import show 9 | import matplotlib.pyplot as plt 10 | from scipy.interpolate import interp1d 11 | import time 12 | from PIL import Image 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torch.nn.functional as F 17 | import torch.backends.cudnn as cudnn 18 | import tensorflow as tf 19 | import torchvision 20 | import torchvision.transforms as transforms 21 | from six.moves import cPickle 22 | import gc 23 | import os 24 | import pickle 25 | import argparse 26 | from model.models import * 27 | from model.utils import * 28 | import coco_voc 29 | 30 | this_dir = osp.dirname(__file__) 31 | ############################################################################################## 32 | 33 | def test_img(im, net, base_image_size, means): 34 | """ 35 | Runs the network on one image and returns the MIL word probabilities 36 | """ 37 | batch_size = 1 38 | # Resize image 39 | im_orig = im.astype(np.float32, copy=True) 40 | im_orig -= means 41 | 42 | im, gr, grr = upsample_image(im_orig, base_image_size) 43 | im = np.transpose(im, axes=(2, 0, 1)) 44 | im = im[np.newaxis, :, :, :] 45 | 46 | # Pass into model 47 | mil_prob = net(Variable(torch.from_numpy(im), requires_grad=False).cuda()) 48 | return mil_prob 49 | 50 | 51 | def output_words_image(threshold_metric, output_metric, min_words, threshold, vocab, is_functional): 52 | ind_output = np.argsort(threshold_metric) 53 | ind_output = ind_output[::-1] 54 | must_keep1 = threshold_metric[ind_output] >= threshold # words above the score threshold 55 | must_keep2 = np.cumsum(is_functional[ind_output]) < 1 + min_words # always keep the top min_words content words 56 | ind_output = [ind for j, ind in enumerate(ind_output) if must_keep1[j] or must_keep2[j]] 57 | out = [(vocab['words'][ind], output_metric[ind], threshold_metric[ind]) for ind in ind_output] 58 | return out 59 | 60 | ############################################################################################## 61 | 62 | '''load vocabulary''' 63 | vocab, functional_words, is_functional, pt = load_vocabulary() 64 | 65 | parser = argparse.ArgumentParser(description='PyTorch MIL Testing') 66 | parser.add_argument('--start_from', type=str, default='') 67 | parser.add_argument('--load_precision_score', type=str, default='') 68 | parser.add_argument('--cnn_weight', default='model/mil.pth', 69 | help='cnn weights') 70 | opt = parser.parse_args() 71 | 72 | mil_model = vgg_mil(opt) 73 | mil_model.cuda() 74 | mil_model.eval() 75 | 76 | 77 | '''preprocessing constants of the original Caffe model''' 78 | mean = np.array([[[103.939, 116.779, 123.68]]]) 79 | base_image_size = 565 80 | 81 | '''Load the image''' 82 | imageurl = 'http://img1.10bestmedia.com/Images/Photos/333810/Montrose_54_990x660.jpg' 83 | im = url_to_image(imageurl) 84 | im = cv2.resize(im, (base_image_size, base_image_size), interpolation=cv2.INTER_CUBIC) 85 | 86 | # Run the model 87 | mil_prob = test_img(im, mil_model, base_image_size, mean) 88 | mil_prob = mil_prob.data.cpu().float().numpy() 89 | # Compute precision mapping - slow in per-image mode, much faster in batch mode 90 | prec = np.zeros(mil_prob.shape) 91 | if len(opt.load_precision_score) > 0: 92 | precision_score = pickle.load(open(opt.load_precision_score, 'rb')) 93 | else: 94 | precision_score = compute_precision_mapping(pt) 95 | 96 | for jj in xrange(prec.shape[1]): 97 | f = interp1d(precision_score['thresh'][jj][:,0], precision_score['prec'][jj][:,0]) 98 | #prec[:, jj] = f(mil_prob[:, jj]) 99 | prec[:, jj] = mil_prob[:, jj] # NOTE: the interp1d calibration above is bypassed; raw MIL probabilities are reported 100 | mil_prec = prec 101 | 
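An added illustration (not part of the original script) of the calibration step that the loop above skips: each word's raw score would be mapped through its monotone threshold-to-precision curve by linear interpolation. The numbers here are toy values.

```python
import numpy as np
from scipy.interpolate import interp1d

thresh = np.array([0.0, 0.5, 1.0])  # toy score thresholds for one word
prec = np.array([0.1, 0.6, 0.9])    # precision measured at each threshold
f = interp1d(thresh, prec)
print(f(0.75))                      # 0.75 -- the precision-calibrated score
```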
102 | #cv2.imshow('image', im) 103 | # Output words 104 | out = output_words_image(mil_prec[0, :], mil_prec[0, :], \ 105 | min_words=10, threshold=0.0, vocab=vocab, is_functional=is_functional) 106 | 107 | plt.rcParams['figure.figsize'] = (10, 10) 108 | plt.imshow(im[:, :, [2, 1, 0]]) 109 | plt.gca().set_axis_off() 110 | det_atts = [] 111 | det_atts_w = [] 112 | index = 0 113 | for (a, b, c) in out: 114 | if a not in functional_words: 115 | if index < 10: 116 | det_atts.append(a) 117 | det_atts_w.append(np.round(b, 2)) 118 | index += 1 119 | # print '{:s} [{:.2f}, {:.2f}] '.format(a, np.round(b,2), np.round(c,2)) 120 | 121 | print det_atts 122 | print det_atts_w 123 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | clear 4 | #----------------------------------------------------------------------------------------------------------------------- 5 | func_cls() 6 | { 7 | CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split coco_val2014 8 | CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split coco_test2014 9 | CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split coco_train2014 10 | #CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split coco_test2015 11 | #CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split chinese 12 | #CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split chinese_val 13 | } 14 | 15 | func_det() 16 | { export PYTHONPATH=$PYTHONPATH:/home/jxgu/github/MIL.pytorch/misc/models/research/object_detection 17 | CUDA_VISIBLE_DEVICES=1 python extract_det_oid.py --split coco_val2014 18 | CUDA_VISIBLE_DEVICES=1 python extract_det_oid.py --split coco_test2014 19 | CUDA_VISIBLE_DEVICES=1 python extract_det_oid.py --split coco_train2014 20 | CUDA_VISIBLE_DEVICES=1 python extract_det_oid.py --split coco_test2015 21 | } 22 | func_cls 23 | -------------------------------------------------------------------------------- /test_v2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright 2017 The Open Images Authors. All Rights Reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | r"""Classifier inference utility. 18 | 19 | This code takes a resnet_v1_101 checkpoint, runs the classifier on the image and 20 | prints predictions in human-readable form. 21 | 22 | ------------------------------- 23 | Example command: 24 | ------------------------------- 25 | 26 | # 0. Create directory for model/data 27 | WORK_PATH="/tmp/oidv2" 28 | mkdir -p "${WORK_PATH}" 29 | cd "${WORK_PATH}" 30 | 31 | # 1. 
Download the model, inference code, and sample image 32 | wget https://storage.googleapis.com/openimages/2017_07/classes-trainable.txt 33 | wget https://storage.googleapis.com/openimages/2017_07/class-descriptions.csv 34 | wget https://storage.googleapis.com/openimages/2017_07/oidv2-resnet_v1_101.ckpt.tar.gz 35 | wget https://raw.githubusercontent.com/openimages/dataset/master/tools/classify_oidv2.py 36 | tar -xzf oidv2-resnet_v1_101.ckpt.tar.gz 37 | 38 | wget -O cat.jpg https://farm6.staticflickr.com/5470/9372235876_d7d69f1790_b.jpg 39 | 40 | # 2. Run inference 41 | python classify_oidv2.py \ 42 | --checkpoint_path='oidv2-resnet_v1_101.ckpt' \ 43 | --labelmap='classes-trainable.txt' \ 44 | --dict='class-descriptions.csv' \ 45 | --image="cat.jpg" \ 46 | --top_k=10 \ 47 | --score_threshold=0.3 48 | 49 | # Sample output: 50 | Image: "cat.jpg" 51 | 52 | 3272: /m/068hy - Pet (score = 0.96) 53 | 1076: /m/01yrx - Cat (score = 0.95) 54 | 0708: /m/01l7qd - Whiskers (score = 0.90) 55 | 4755: /m/0jbk - Animal (score = 0.90) 56 | 2847: /m/04rky - Mammal (score = 0.89) 57 | 2036: /m/0307l - Felidae (score = 0.79) 58 | 3574: /m/07k6w8 - Small to medium-sized cats (score = 0.77) 59 | 4799: /m/0k0pj - Nose (score = 0.70) 60 | 1495: /m/02cqfm - Close-up (score = 0.55) 61 | 0036: /m/012c9l - Domestic short-haired cat (score = 0.40) 62 | 63 | ------------------------------- 64 | Note on image preprocessing: 65 | ------------------------------- 66 | 67 | This is the code used to perform preprocessing: 68 | -------- 69 | from preprocessing import preprocessing_factory 70 | 71 | def PreprocessImage(image, network='resnet_v1_101', image_size=299): 72 | # If resolution is larger than 224 we need to adjust some internal resizing 73 | # parameters for vgg preprocessing. 74 | if any(network.startswith(x) for x in ['resnet', 'vgg']): 75 | preprocessing_kwargs = { 76 | 'resize_side_min': int(256 * image_size / 224), 77 | 'resize_side_max': int(512 * image_size / 224) 78 | } 79 | else: 80 | preprocessing_kwargs = {} 81 | preprocessing_fn = preprocessing_factory.get_preprocessing( 82 | name=network, is_training=False) 83 | 84 | height = image_size 85 | width = image_size 86 | image = preprocessing_fn(image, height, width, **preprocessing_kwargs) 87 | image.set_shape([height, width, 3]) 88 | return image 89 | -------- 90 | 91 | Note that there appears to be a small difference between the public version 92 | of slim image processing library and the internal version (which the meta 93 | graph is based on). Results that are very close, but not exactly identical to 94 | that of the metagraph. 
95 | """ 96 | 97 | from __future__ import absolute_import 98 | from __future__ import division 99 | from __future__ import print_function 100 | 101 | import tensorflow as tf 102 | 103 | flags = tf.app.flags 104 | FLAGS = flags.FLAGS 105 | 106 | flags.DEFINE_string('labelmap', 'classes-trainable.txt', 107 | 'Labels, one per line.') 108 | 109 | flags.DEFINE_string('dict', 'class-descriptions.csv', 110 | 'Descriptive string for each label.') 111 | 112 | flags.DEFINE_string('checkpoint_path', 'oidv2-resnet_v1_101.ckpt', 113 | 'Path to checkpoint file.') 114 | 115 | flags.DEFINE_string('image', '', 116 | 'Comma separated paths to image files on which to perform ' 117 | 'inference.') 118 | 119 | flags.DEFINE_integer('top_k', 10, 'Maximum number of results to show.') 120 | 121 | flags.DEFINE_float('score_threshold', None, 'Score threshold.') 122 | 123 | 124 | def LoadLabelMap(labelmap_path, dict_path): 125 | """Load index->mid and mid->display name maps. 126 | 127 | Args: 128 | labelmap_path: path to the file with the list of mids, describing 129 | predictions. 130 | dict_path: path to the dict.csv that translates from mids to display names. 131 | Returns: 132 | labelmap: an index to mid list 133 | label_dict: mid to display name dictionary 134 | """ 135 | labelmap = [line.rstrip() for line in tf.gfile.GFile(labelmap_path)] 136 | 137 | label_dict = {} 138 | for line in tf.gfile.GFile(dict_path): 139 | words = [word.strip(' "\n') for word in line.split(',', 1)] 140 | label_dict[words[0]] = words[1] 141 | 142 | return labelmap, label_dict 143 | 144 | 145 | def main(_): 146 | # Load labelmap and dictionary from disk. 147 | labelmap, label_dict = LoadLabelMap(FLAGS.labelmap, FLAGS.dict) 148 | 149 | g = tf.Graph() 150 | with g.as_default(): 151 | with tf.Session() as sess: 152 | saver = tf.train.import_meta_graph(FLAGS.checkpoint_path + '.meta') 153 | saver.restore(sess, FLAGS.checkpoint_path) 154 | 155 | input_values = g.get_tensor_by_name('input_values:0') 156 | predictions = g.get_tensor_by_name('multi_predictions:0') 157 | 158 | for image_filename in FLAGS.image.split(','): 159 | compressed_image = tf.gfile.FastGFile(image_filename, 'rb').read() 160 | predictions_eval = sess.run( 161 | predictions, feed_dict={ 162 | input_values: [compressed_image] 163 | }) 164 | top_k = predictions_eval.argsort()[::-1] # indices sorted by score 165 | if FLAGS.top_k > 0: 166 | top_k = top_k[:FLAGS.top_k] 167 | if FLAGS.score_threshold is not None: 168 | top_k = [i for i in top_k 169 | if predictions_eval[i] >= FLAGS.score_threshold] 170 | print('Image: "%s"\n' % image_filename) 171 | for idx in top_k: 172 | mid = labelmap[idx] 173 | display_name = label_dict[mid] 174 | score = predictions_eval[idx] 175 | print('{:04d}: {} - {} (score = {:.2f})'.format( 176 | idx, mid, display_name, score)) 177 | 178 | 179 | if __name__ == '__main__': 180 | tf.app.run() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Use tensorboard 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import functools 8 | import os 9 | import time 10 | from six.moves import cPickle 11 | from dataloader import * 12 | from model import * 13 | import tensorflow as tf 14 | import torch.nn as nn 15 | import torch.utils.model_zoo as model_zoo 16 | import numpy as np 17 | import math 18 | import torch 19 | import torch.nn.init as init 20 | import 
torch.nn.functional as F 21 | from torch.autograd import Variable 22 | import opts 23 | from model import eval_utils 24 | from model import utils 25 | from model import models 26 | 27 | rusage_denom = 1024 28 | printf = functools.partial(print, end="") 29 | 30 | def extract_fts(opt, data): 31 | images = Variable(torch.from_numpy(data['images']), volatile=False).cuda() 32 | mil_label = Variable(torch.from_numpy(data['mil_label']), volatile=False).cuda() 33 | return images, mil_label 34 | 35 | def record_training(opt, model, iteration, tf_summary_writer, current_record, history_record): 36 | [train_loss] = current_record 37 | [loss_history, lr_history] = history_record 38 | utils.add_summary_value(tf_summary_writer, 'train_lr', opt.learning_rate, iteration) 39 | utils.add_summary_value(tf_summary_writer, 'train_weight_decay', opt.weight_decay, iteration) 40 | utils.add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration) 41 | utils.add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration) 42 | tf_summary_writer.flush() 43 | loss_history[iteration] = train_loss 44 | lr_history[iteration] = opt.current_lr 45 | return history_record 46 | 47 | def record_ckpt(opt, infos, model, optimizer, best_flag): 48 | tag = '-best' if best_flag else '' 49 | print("Saving model weights") 50 | checkpoint_path = os.path.join(opt.checkpoint_path, opt.model_id + '.model' + tag + '.pth') 51 | optimizer_path = os.path.join(opt.checkpoint_path, opt.model_id + '.optimizer' + tag + '.pth') 52 | torch.save(model.state_dict(), checkpoint_path) 53 | torch.save(optimizer.state_dict(), optimizer_path) 54 | print("Saving infos") 55 | with open(os.path.join(opt.checkpoint_path, opt.model_id + '.infos' + tag + '.pkl'), 'wb') as f: 56 | cPickle.dump(infos, f) 57 | print("model saved to {}".format(checkpoint_path)) 58 | 
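An added sketch (not part of the original file) of the file layout `record_ckpt()` produces under `opt.checkpoint_path`, following the `model_id` convention built in opts.py. The timestamp in `model_id` is hypothetical.

```python
import os

# Hypothetical run id: opts.py builds "%m%d%H%M" + "_mil_" + model + "_" + learning
model_id = "05081230_mil_vgg19_mle"
checkpoint_path = os.path.join("save", model_id)
for suffix in (".model-best.pth", ".optimizer-best.pth", ".infos-best.pkl"):
    print(os.path.join(checkpoint_path, model_id + suffix))
# resume that run later with: python train.py --start_from save/05081230_mil_vgg19_mle
```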
") 67 | opt, infos, iteration, epoch, val_history, train_history = utils.history_infos(opt) 68 | [loss_history, lr_history] = train_history 69 | [val_result_history, best_val_score, val_loss] = val_history 70 | 71 | # Update dataloader info 72 | loader.iterators = infos.get('iterators', loader.iterators) 73 | loader.split_ix = infos.get('split_ix', loader.split_ix) 74 | 75 | print("Build image cnn model, and initialize it with pre-trained cnn model") 76 | model, crit, optimizer, infos = models.build_models(opt, infos) 77 | 78 | update_lr_flag = True 79 | while True: 80 | gc.collect() # collect cpu memory 81 | if update_lr_flag: 82 | utils.paramsForEpoch(opt, epoch, optimizer) 83 | utils.update_lr(opt, epoch, optimizer) # Assign the learning rate 84 | 85 | data = loader.get_batch('train') # Load data from train split (0) 86 | torch.cuda.synchronize() 87 | start = time.time() 88 | 89 | images, mil_label = extract_fts(opt, data) 90 | optimizer.zero_grad() 91 | crit_outputs = crit(model(images), mil_label) 92 | loss = crit_outputs[0] 93 | loss.backward() 94 | utils.clip_gradient(optimizer, opt.grad_clip) 95 | optimizer.step() 96 | torch.cuda.synchronize() 97 | train_loss = loss.data[0] 98 | 99 | last_name = os.path.basename(opt.model_id) 100 | last_time = last_name[0:8] 101 | print( 102 | "{}/{},{}/{},loss(t|{:.4f},v|{:.4f})|T/B({:.2f})" \ 103 | .format(opt.model+'.'+last_time, iteration, epoch, opt.batch_size, 104 | train_loss, val_loss, 105 | time.time() - start)) 106 | 107 | # Update the iteration and epoch 108 | iteration += 1 109 | if data['bounds']['wrapped']: 110 | epoch += 1 111 | update_lr_flag = True 112 | 113 | # Write the training loss summary 114 | if (iteration % opt.losses_log_every == 0): 115 | current_record = [train_loss] 116 | history_record = [loss_history, lr_history] 117 | history_record = record_training(opt, model, iteration, tf_summary_writer, current_record, history_record) 118 | [loss_history, lr_history] = history_record 119 | 120 | # make evaluation on validation set, and save model 121 | if (iteration % opt.save_checkpoint_every == 0): 122 | eval_kwargs = {'split': 'test', 'dataset': opt.input_json} 123 | eval_kwargs.update(vars(opt)) 124 | eval_kwargs['split'] = 'test' 125 | eval_kwargs['dataset'] = opt.input_json 126 | val_loss = eval_utils.eval_split(opt, model, crit, loader, eval_kwargs) 127 | 128 | utils.add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration) 129 | tf_summary_writer.flush() 130 | val_result_history[iteration] = {'loss': val_loss} 131 | 132 | # Save model if is improving on validation result 133 | current_score = val_loss 134 | best_flag = False 135 | if True: # if true 136 | if best_val_score is None or current_score > best_val_score: 137 | best_val_score = current_score 138 | best_flag = True 139 | # Dump miscalleous informations 140 | infos['iter'] = iteration 141 | infos['epoch'] = epoch 142 | infos['iterators'] = loader.iterators 143 | infos['split_ix'] = loader.split_ix 144 | infos['best_val_score'] = best_val_score 145 | infos['opt'] = opt 146 | infos['val_result_history'] = val_result_history 147 | infos['loss_history'] = loss_history 148 | infos['lr_history'] = lr_history 149 | infos['vocab'] = loader.get_vocab() 150 | # Dump checkpoint 151 | record_ckpt(opt, infos, model, optimizer, best_flag) 152 | # Stop if reaching max epochs 153 | if epoch >= opt.max_epochs and opt.max_epochs != -1: 154 | break 155 | 156 | 157 | ''' 158 | Main function: Start from here !!! 
159 | ''' 160 | opt = opts.parse_opt() 161 | train(opt) 162 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | clear 4 | 5 | case "$1" in 6 | 0) 7 | echo "run resnet 101 debug" 8 | CUDA_VISIBLE_DEVICES=0,1 python train.py --model 'resnet101' --n_gpus=2 --batch_size 50 --optim='sgd' --learning_rate_decay_start=0 9 | ;; 10 | 11 | 1) 12 | echo "run vgg19 debug" 13 | CUDA_VISIBLE_DEVICES=0,1 python train.py --model 'vgg19' --n_gpus=2 --batch_size 10 --optim='sgd' --learning_rate_decay_start=0 14 | ;; 15 | 16 | 2) 17 | echo "run resnet101 debug" 18 | CUDA_VISIBLE_DEVICES=1 python train.py --model 'resnet101' --batch_size 10 --learning_rate 1e-4 19 | ;; 20 | 21 | *) 22 | echo 23 | echo "No input" 24 | ;; 25 | esac 26 | 27 | -------------------------------------------------------------------------------- /vocabs/vocab_train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/vocabs/vocab_train.pkl -------------------------------------------------------------------------------- /vocabs/vocab_words.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/vocabs/vocab_words.txt --------------------------------------------------------------------------------