├── README.md
├── __init__.py
├── coco_voc.py
├── dataloader.py
├── extract_cls_oid.py
├── extract_det_oid.py
├── model
│   ├── __init__.py
│   ├── coco_voc.py
│   ├── eval_utils.py
│   ├── mil_vocab.pkl
│   ├── models.py
│   ├── readme.md
│   ├── resnet.py
│   ├── resnet_mil.py
│   ├── resnet_utils.py
│   ├── utils.py
│   └── vgg_mil.py
├── opts.py
├── scripts
│   ├── convert_tf2pth.sh
│   └── graphs
│       ├── events.out.tfevents.1525752019.jxgu
│       ├── events.out.tfevents.1525752072.jxgu
│       └── events.out.tfevents.1525752099.jxgu
├── test.py
├── test.sh
├── test_v2.py
├── train.py
├── train.sh
└── vocabs
    ├── vocab_train.pkl
    └── vocab_words.txt
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch implementation of multiple-instance learning (MIL)
2 |
3 |
4 | ## Updates
5 |
6 | - [ ] Training/Testing on MS COCO
7 | - [x] Testing on Open Images: object detection and classification
8 | - [x] Testing on single image
9 |
10 | ## About
11 | This repository contains a PyTorch MIL implementation for the experiments in:
12 | ```
13 | https://github.com/s-gupta/visual-concepts
14 | ```
15 |
16 | If you want to use the original Caffe implementation, please follow the instructions below:
17 | ```bash
18 | git clone git@github.com:s-gupta/caffe.git code/caffe
19 | cd code/caffe
20 | git checkout mil
21 | make -j 16
22 | make pycaffe
23 | cd ../../
24 | ```
25 |
26 | Alternatively, you can use my converted PyTorch model, available from [Tencent Weiyun](https://share.weiyun.com/5TxJAM4) or [Google Drive](https://drive.google.com/open?id=1wgzA7giTKEsZpSJt-NB0JnpSJgK2unVF). The download contains:
27 | ```
28 | model/coco_valid1_eval.pkl
29 | model/mil.pth
30 | ```
31 |
32 | ## Test results
33 | Change the image URL in `test.py` to point at your test image, then run:
34 | ```bash
35 | python test.py
36 | ```
37 |
38 |
39 |
40 | ```
41 | ['beach', 'dog', 'brown', 'standing', 'people', 'his', 'sandy', 'white', 'sitting', 'laying']
42 | [1.0, 0.62, 0.62, 0.5, 0.45000000000000001, 0.42999999999999999, 0.37, 0.35999999999999999, 0.34000000000000002, 0.28000000000000003]
43 | ```
44 |
45 |
46 |
47 | ```
48 | ['cat', 'sink', 'sitting', 'black', 'bathroom', 'white', 'top', 'sits', 'counter', 'looking']
49 | [1.0, 0.68000000000000005, 0.56999999999999995, 0.56000000000000005, 0.39000000000000001, 0.34999999999999998, 0.31, 0.26000000000000001, 0.23000000000000001, 0.22]
50 | ```
51 |
--------------------------------------------------------------------------------
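
For reference, a minimal sketch of the kind of inference `test.py` performs with the converted weights. The `vgg_mil` constructor argument, the vocabulary file format, and the output shape are assumptions for illustration, not a guaranteed API of this repo:

```python
import numpy as np
import torch
from torch.autograd import Variable
import cv2

from dataloader import preprocess_vgg19_mil
from model.vgg_mil import vgg_mil  # assumed entry point in model/vgg_mil.py

# 1000-word MIL vocabulary (one word per line is assumed here).
words = [w.strip() for w in open('vocabs/vocab_words.txt')]

model = vgg_mil(None)  # placeholder opt; match your opts.py settings
model.load_state_dict(torch.load('model/mil.pth'))
model.cuda().eval()

# cv2 gives BGR HWC; preprocess_vgg19_mil expects CHW (the h5 layout).
img = np.transpose(cv2.imread('cat.jpg'), (2, 0, 1))
inp = Variable(torch.from_numpy(preprocess_vgg19_mil(img)).float(), volatile=True).cuda()
probs = model(inp).data.cpu().numpy().ravel()

# Print the ten highest-scoring words, as in the sample outputs above.
top = probs.argsort()[::-1][:10]
print([words[i] for i in top])
print([round(float(probs[i]), 2) for i in top])
```
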
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/__init__.py
--------------------------------------------------------------------------------
/coco_voc.py:
--------------------------------------------------------------------------------
1 | pycoco/coco_voc.py
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | import cPickle
5 | import json
6 | import h5py
7 | import os
8 | import numpy as np
9 | import random
10 | import torch
11 | import cv2
12 | from torchvision import transforms as trn
13 | from multiprocessing.dummy import Pool
14 | import math
15 | import gc
16 |
17 | preprocess = trn.Compose([
18 | #trn.ToTensor(),
19 | trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
20 | ])
21 |
22 | preprocess_vgg16 = trn.Compose([
23 | #trn.ToTensor(),
24 | trn.Normalize([123.680, 103.939, 116.779], [1.000, 1.000, 1.000])
25 | ])
26 |
27 | def upsample_image(im, sz):
28 | h = im.shape[0]
29 | w = im.shape[1]
30 | s = np.float(max(h, w))
31 | #I_out = np.zeros((sz, sz, 3), dtype=np.float);
32 | #I = cv2.resize(im, None, None, fx=np.float(sz) / s, fy=np.float(sz) / s, interpolation=cv2.INTER_CUBIC); #INTER_CUBIC, INTER_LINEAR
33 | I = cv2.resize(im, (sz, sz), interpolation=cv2.INTER_LINEAR)
34 |     SZ = I.shape
35 | #I_out[0:I.shape[0], 0:I.shape[1], :] = I;
36 | return I, I, SZ
37 |
38 | def preprocess_vgg19_mil(Image):
39 | if len(Image.shape) == 2:
40 | Image = Image[:, :, np.newaxis]
41 | Image = np.concatenate((Image, Image, Image), axis=2)
42 |
43 | mean = np.array([[[103.939, 116.779, 123.68]]]);
44 | base_image_size = 565;
45 |     Image = cv2.resize(np.transpose(Image, axes=(1, 2, 0)), (base_image_size, base_image_size), interpolation=cv2.INTER_CUBIC)  # input is CHW (as stored in the h5 file); make it HWC for cv2.resize
46 | Image_orig = Image.astype(np.float32, copy=True)
47 | Image_orig -= mean
48 | im = Image_orig
49 | #im, gr, grr = upsample_image(Image_orig, base_image_size)
50 | # im = cv2.resize(Image_orig, (base_image_size, base_image_size), interpolation=cv2.INTER_CUBIC)
51 | im = np.transpose(im, axes=(2, 0, 1))
52 | im = im[np.newaxis, :, :, :]
53 | return im
54 |
55 | '''
56 | Load data from h5 files
57 | '''
58 | class DataLoader():
59 | def reset_iterator(self, split):
60 | # if load files from directory, then reset the prefetch process
61 | self.iterators[split] = 0
62 |
63 | def get_vocab_size(self):
64 | return self.vocab_size
65 |
66 | def get_vocab(self):
67 | return self.ix_to_word
68 |
69 | def get_seq_length(self):
70 | return self.seq_length
71 |
72 | def __init__(self, opt):
73 | self.type = 'h5'
74 | self.opt = opt
75 | self.model = getattr(opt, 'model', 'resnet101')
76 | self.attrs_in = getattr(opt, 'attrs_in', 0)
77 | self.attrs_out = getattr(opt, 'attrs_out', 0)
78 | self.att_im = getattr(opt, 'att_im', 1)
79 | self.pre_ft = getattr(opt, 'pre_ft', 1)
80 | self.mil_vocab_outsize = 1000
81 | self.top_attrs = 10
82 | self.fc_feat_size = opt.fc_feat_size
83 | self.att_feat_size = opt.att_feat_size
84 | self.batch_size = opt.batch_size
85 | self.seq_per_img = opt.seq_per_img
86 |
87 | # load the json file which contains additional information about the dataset
88 | print('DataLoader loading json file: ', opt.input_json)
89 | self.info = json.load(open(self.opt.input_json))
90 |         self.mil_vocab = cPickle.load(open('model/mil_vocab.pkl', 'rb'))
91 | self.ix_to_word = self.info['ix_to_word']
92 | self.vocab_size = len(self.ix_to_word)
93 | print('vocab size is ', self.vocab_size)
94 |
95 | # open the hdf5 file
96 | print('DataLoader loading h5 file: ', opt.input_im_h5)
97 |         self.h5_im_file = h5py.File(self.opt.input_im_h5, 'r')
98 | # extract image size from dataset
99 | images_size = self.h5_im_file['images'].shape
100 | assert len(images_size) == 4, 'images should be a 4D tensor'
101 | assert images_size[2] == images_size[3], 'width and height must match'
102 | self.num_images = images_size[0]
103 | self.num_channels = images_size[1]
104 | self.max_image_size = images_size[2]
105 | print('read %d images of size %dx%dx%d' %(self.num_images,
106 | self.num_channels, self.max_image_size, self.max_image_size))
107 | # load in the sequence data
108 | self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
109 | seq_size = self.h5_label_file['labels'].shape
110 | self.seq_length = seq_size[1]
111 | semantic_attrs_size = self.h5_label_file['semantic_words'].shape
112 | self.semantic_attrs_length = semantic_attrs_size[1]
113 | print('max sequence length in data is', self.seq_length)
114 | print('max semantic words length in data is', self.semantic_attrs_length)
115 | # load the pointers in full to RAM (should be small enough)
116 | self.label_start_ix = self.h5_label_file['label_start_ix'][:]
117 | self.label_end_ix = self.h5_label_file['label_end_ix'][:]
118 |
119 | self.num_images = self.label_start_ix.shape[0]
120 | print('read %d image / features' %(self.num_images))
121 |
122 | # separate out indexes for each of the provided splits
123 | self.split_ix = {'train': [], 'val': [], 'test': []}
124 | for ix in range(len(self.info['images'])):
125 | img = self.info['images'][ix]
126 | if img['split'] == 'train':
127 | self.split_ix['train'].append(ix)
128 | elif img['split'] == 'val':
129 | self.split_ix['val'].append(ix)
130 | elif img['split'] == 'test':
131 | self.split_ix['test'].append(ix)
132 | elif opt.train_only == 0: # restval
133 | self.split_ix['train'].append(ix)
134 |
135 | print('assigned %d images to split train' %len(self.split_ix['train']))
136 | print('assigned %d images to split val' %len(self.split_ix['val']))
137 | print('assigned %d images to split test' %len(self.split_ix['test']))
138 |
139 | self.iterators = {'train': 0, 'val': 0, 'test': 0}
140 |
141 | def gen_mil_gt(self, attrs):
142 | mil_batch = np.zeros([1, self.mil_vocab_outsize], dtype='int')
143 | for k in range(len(attrs)):
144 | if attrs[k] > 0:
145 | for i in range(self.mil_vocab_outsize):
146 | if self.ix_to_word[str(attrs[k])] == self.mil_vocab[i]:
147 | mil_batch[0, i] = 1
148 |
149 | return mil_batch
150 |
151 | def get_batch(self, split, batch_size=None, seq_per_img=None):
152 | split_ix = self.split_ix[split]
153 | batch_size = batch_size or self.batch_size
154 | seq_per_img = seq_per_img or self.seq_per_img
155 |
156 | if 'vgg19' in self.model:
157 | img_batch = np.ndarray([batch_size, 3, 565, 565], dtype='float32')
158 | else:
159 | img_batch = np.ndarray([batch_size, 3, 224, 224], dtype='float32')
160 | label_batch = np.zeros([batch_size * self.seq_per_img, self.seq_length + 2], dtype = 'int')
161 | mask_batch = np.zeros([batch_size * self.seq_per_img, self.seq_length + 2], dtype = 'float32')
162 | attrs_batch = np.zeros([batch_size, self.top_attrs], dtype = 'int')
163 | mil_batch = np.zeros([batch_size, self.mil_vocab_outsize], dtype='int')
164 | max_index = len(split_ix)
165 | wrapped = False
166 |
167 | infos = []
168 | gts = []
169 |
170 | for i in range(batch_size):
171 | import time
172 | t_start = time.time()
173 |
174 | ri = self.iterators[split]
175 | ri_next = ri + 1
176 | if ri_next >= max_index:
177 | ri_next = 0
178 | wrapped = True
179 | self.iterators[split] = ri_next
180 | ix = split_ix[ri]
181 |
182 | #img = self.load_image(self.image_info[ix]['filename'])
183 | img = self.h5_im_file['images'][ix, :, :, :]
184 | if 'resnet' in self.model:
185 | img_batch[i] = preprocess(torch.from_numpy(img[:, 16:-16, 16:-16].astype('float32')/255.0)).numpy()
186 | else:
187 | #img_batch[i] = preprocess_vgg16(torch.from_numpy(img[:, 16:-16, 16:-16].astype('float32'))).numpy()
188 | img_batch[i] = preprocess_vgg19_mil(img)
189 |
190 | # fetch the semantic_attributes
191 | attrs_batch[i] = self.h5_label_file['semantic_words'][ix, : self.top_attrs]
192 | mil_batch[i] = self.gen_mil_gt(attrs_batch[i])
193 |
194 | # fetch the sequence labels
195 | ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
196 | ix2 = self.label_end_ix[ix] - 1
197 | ncap = ix2 - ix1 + 1 # number of captions available for this image
198 | assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
199 |
200 | # record associated info as well
201 | info_dict = {}
202 | info_dict['ix'] = ix
203 | info_dict['id'] = self.info['images'][ix]['id']
204 | info_dict['file_path'] = self.info['images'][ix]['file_path']
205 | infos.append(info_dict)
206 |
207 | # generate mask
208 | t_start = time.time()
209 |         nonzeros = np.array([(x != 0).sum() + 2 for x in label_batch])
210 | for ix, row in enumerate(mask_batch):
211 | row[:nonzeros[ix]] = 1
212 |
213 | data = {}
214 |
215 | data['images'] = img_batch # if pre_ft is 1, then it equals None
216 | data['semantic_words'] = attrs_batch # if attributes is 1, then it equals None
217 | data['mil_label'] = mil_batch
218 | data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(split_ix), 'wrapped': wrapped}
219 | data['infos'] = infos
220 |
221 | gc.collect()
222 |
223 | return data
--------------------------------------------------------------------------------
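
The inner loop of `gen_mil_gt` above scans the full 1000-word MIL vocabulary once per attribute. A self-contained sketch of the same multi-hot target construction, with toy stand-ins for `ix_to_word` and `mil_vocab` and a dict for O(1) lookup instead of the linear scan:

```python
import numpy as np

# Toy stand-ins for self.ix_to_word (dataset word index -> word) and
# self.mil_vocab (the fixed 1000-word MIL vocabulary).
ix_to_word = {'1': 'cat', '2': 'sink', '3': 'bathroom'}
mil_vocab = ['dog', 'cat', 'sink', 'counter']
word_to_mil = {w: i for i, w in enumerate(mil_vocab)}

def gen_mil_gt(attrs):
    """Multi-hot MIL target: 1 for each attribute word found in the MIL vocab."""
    mil = np.zeros((1, len(mil_vocab)), dtype='int')
    for a in attrs:
        word = ix_to_word.get(str(a)) if a > 0 else None
        if word in word_to_mil:
            mil[0, word_to_mil[word]] = 1
    return mil

print(gen_mil_gt([1, 2, 0]))  # [[0 1 1 0]]
```
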
/extract_cls_oid.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # Copyright 2017 The Open Images Authors. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | r"""Classifier inference utility.
18 |
19 | This code takes a resnet_v1_101 checkpoint, runs the classifier on the image and
20 | prints predictions in human-readable form.
21 |
22 | -------------------------------
23 | Example command:
24 | -------------------------------
25 |
26 | # 0. Create directory for model/data
27 | WORK_PATH="/tmp/oidv2"
28 | mkdir -p "${WORK_PATH}"
29 | cd "${WORK_PATH}"
30 |
31 | # 1. Download the model, inference code, and sample image
32 | wget https://storage.googleapis.com/openimages/2017_07/classes-trainable.txt
33 | wget https://storage.googleapis.com/openimages/2017_07/class-descriptions.csv
34 | wget https://storage.googleapis.com/openimages/2017_07/oidv2-resnet_v1_101.ckpt.tar.gz
35 | wget https://raw.githubusercontent.com/openimages/dataset/master/tools/classify_oidv2.py
36 | tar -xzf oidv2-resnet_v1_101.ckpt.tar.gz
37 |
38 | wget -O cat.jpg https://farm6.staticflickr.com/5470/9372235876_d7d69f1790_b.jpg
39 |
40 | # 2. Run inference
41 | python classify_oidv2.py \
42 | --checkpoint_path='oidv2-resnet_v1_101.ckpt' \
43 | --labelmap='classes-trainable.txt' \
44 | --dict='class-descriptions.csv' \
45 | --image="cat.jpg" \
46 | --top_k=10 \
47 | --score_threshold=0.3
48 |
49 | # Sample output:
50 | Image: "cat.jpg"
51 |
52 | 3272: /m/068hy - Pet (score = 0.96)
53 | 1076: /m/01yrx - Cat (score = 0.95)
54 | 0708: /m/01l7qd - Whiskers (score = 0.90)
55 | 4755: /m/0jbk - Animal (score = 0.90)
56 | 2847: /m/04rky - Mammal (score = 0.89)
57 | 2036: /m/0307l - Felidae (score = 0.79)
58 | 3574: /m/07k6w8 - Small to medium-sized cats (score = 0.77)
59 | 4799: /m/0k0pj - Nose (score = 0.70)
60 | 1495: /m/02cqfm - Close-up (score = 0.55)
61 | 0036: /m/012c9l - Domestic short-haired cat (score = 0.40)
62 |
63 | -------------------------------
64 | Note on image preprocessing:
65 | -------------------------------
66 |
67 | This is the code used to perform preprocessing:
68 | --------
69 | from preprocessing import preprocessing_factory
70 |
71 | def PreprocessImage(image, network='resnet_v1_101', image_size=299):
72 | # If resolution is larger than 224 we need to adjust some internal resizing
73 | # parameters for vgg preprocessing.
74 | if any(network.startswith(x) for x in ['resnet', 'vgg']):
75 | preprocessing_kwargs = {
76 | 'resize_side_min': int(256 * image_size / 224),
77 | 'resize_side_max': int(512 * image_size / 224)
78 | }
79 | else:
80 | preprocessing_kwargs = {}
81 | preprocessing_fn = preprocessing_factory.get_preprocessing(
82 | name=network, is_training=False)
83 |
84 | height = image_size
85 | width = image_size
86 | image = preprocessing_fn(image, height, width, **preprocessing_kwargs)
87 | image.set_shape([height, width, 3])
88 | return image
89 | --------
90 |
91 | Note that there appears to be a small difference between the public version
92 | of the slim image-processing library and the internal version (which the
93 | meta graph is based on): the results are very close, but not exactly
94 | identical, to those of the meta graph.
95 | """
96 | from __future__ import absolute_import
97 | from __future__ import division
98 | from __future__ import print_function
99 | import tensorflow as tf
100 | import argparse
101 | import random
102 | import numpy as np
103 | import time, os, sys
104 | import json
105 | import cv2
106 | import os
107 | import six.moves.urllib as urllib
108 | import sys
109 | import tarfile
110 | import zipfile
111 |
112 | from collections import defaultdict
113 | from io import StringIO
114 | from matplotlib import pyplot as plt
115 | from PIL import Image
116 |
117 | # Command-line flags, kept from the original Open Images tooling.
118 | flags = tf.app.flags
119 | FLAGS = flags.FLAGS
120 |
121 | def load_image_ids(split_name):
122 |     ''' Load a list of (path, image_id) tuples. Modify this to suit your data locations. '''
123 | split = []
124 | if split_name == 'coco_test2014':
125 | with open('data/mscoco/annotations/image_info_test2014.json') as f:
126 | data = json.load(f)
127 | for item in data['images']:
128 | image_id = int(item['id'])
129 | filepath = os.path.join('data/mscoco/test2014/', item['file_name'])
130 | split.append((filepath,image_id))
131 | elif split_name == 'coco_val2014':
132 | with open('data/mscoco/annotations/captions_val2014.json') as f:
133 | data = json.load(f)
134 | for item in data['images']:
135 | image_id = int(item['id'])
136 | filepath = os.path.join('data/mscoco/val2014/', item['file_name'])
137 | split.append((filepath,image_id))
138 | elif split_name == 'coco_train2014':
139 | with open('data/mscoco/annotations/captions_train2014.json') as f:
140 | data = json.load(f)
141 | for item in data['images']:
142 | image_id = int(item['id'])
143 | filepath = os.path.join('data/mscoco/train2014/', item['file_name'])
144 | split.append((filepath,image_id))
145 | elif split_name == 'coco_test2015':
146 | with open('data/mscoco/annotations/image_info_test2015.json') as f:
147 | data = json.load(f)
148 | for item in data['images']:
149 | image_id = int(item['id'])
150 | filepath = os.path.join('data/mscoco/test2015/', item['file_name'])
151 | split.append((filepath,image_id))
152 | elif split_name == 'genome':
153 | with open('data/visualgenome/image_data.json') as f:
154 | for item in json.load(f):
155 | image_id = int(item['image_id'])
156 | filepath = os.path.join('data/visualgenome/', item['url'].split('rak248/')[-1])
157 | split.append((filepath,image_id))
158 | elif split_name == 'chinese':
159 | with open('data/aic_i2t/ai_challenger_caption_train_20170902/caption_train_annotations_20170902.json') as f:
160 | for item in json.load(f):
161 | image_id = item['image_id']
162 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_train_20170902/caption_train_images_20170902', image_id)
163 | split.append((filepath,image_id))
164 | elif split_name == 'chinese_val':
165 | with open('data/aic_i2t/ai_challenger_caption_validation_20170910/caption_validation_annotations_20170910.json') as f:
166 | for item in json.load(f):
167 | image_id = item['image_id']
168 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_validation_20170910/caption_validation_images_20170910', image_id)
169 | split.append((filepath,image_id))
170 | elif split_name == 'chinese_test1':
171 | with open('data/aic_i2t/ai_challenger_caption_test1_20170923/caption_test1_annotations_20170923.json') as f:
172 | for item in json.load(f):
173 | image_id = item['image_id']
174 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_test1_20170923/caption_test1_images_20170923', image_id)
175 | split.append((filepath,image_id))
176 | else:
177 | print('Unknown split')
178 | return split
179 |
180 | def get_classifications_from_im(args, g, sess, image_ids):
181 | save_dir = 'oid_data/'
182 | input_values = g.get_tensor_by_name('input_values:0')
183 | predictions = g.get_tensor_by_name('multi_predictions:0')
184 | count = 0
185 | for im_file, image_id in image_ids:
186 | compressed_image = tf.gfile.FastGFile(im_file, 'rb').read()
187 | predictions_eval = sess.run(predictions, feed_dict={input_values: [compressed_image]})
188 | if 'chinese' in args.data_split:
189 | np.savez_compressed(save_dir + 'aic_i2t/oid_cls/' + str(image_id), feat=predictions_eval)
190 | else:
191 | np.savez_compressed(save_dir + 'mscoco/oid_cls/' + str(image_id), feat=predictions_eval)
192 | if (count % 100) == 0:
193 | print('{:d}'.format(count + 1))
194 | count += 1
195 |
196 | def parse_args():
197 | """
198 | Parse input arguments
199 | """
200 |     parser = argparse.ArgumentParser(description='Extract Open Images classification predictions')
201 | parser.add_argument('--gpu', dest='gpu_id', help='GPU id(s) to use', default='0', type=str)
202 | parser.add_argument('--type', dest='type', help='', default='det', type=str)
203 | parser.add_argument('--def', dest='prototxt', help='prototxt file defining the network', default='../models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt', type=str)
204 | parser.add_argument('--net', dest='caffemodel', help='model to use', default='../data/faster_rcnn_models/resnet101_faster_rcnn_final.caffemodel', type=str)
205 | parser.add_argument('--out', dest='outfile', help='output filepath', default='karpathy_train_resnet101_faster_rcnn_genome', type=str)
206 | parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default='../experiments/cfgs/faster_rcnn_end2end_resnet.yml', type=str)
207 | parser.add_argument('--split', dest='data_split', help='dataset to use', default='coco_val2014', type=str)
208 | parser.add_argument('--checkpoint_path', dest='checkpoint_path', help='', default='model/oidv2_resnet_v1_101/oidv2-resnet_v1_101.ckpt', type=str)
209 | parser.add_argument('--set', dest='set_cfgs', help='set config keys', default=None, nargs=argparse.REMAINDER)
210 |
211 | args = parser.parse_args()
212 | return args
213 |
214 |
215 | def main(_):
216 | args = parse_args()
217 |
218 | print('Called with args:')
219 | print(args)
220 | g = tf.Graph()
221 | with g.as_default():
222 | with tf.Session() as sess:
223 | saver = tf.train.import_meta_graph(args.checkpoint_path + '.meta')
224 | saver.restore(sess, args.checkpoint_path)
225 | image_ids = load_image_ids(args.data_split)
226 | get_classifications_from_im(args, g, sess, image_ids)
227 |
228 | if __name__ == '__main__':
229 | tf.app.run()
230 |
--------------------------------------------------------------------------------
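
The script above saves the raw `multi_predictions` vector per image with `np.savez_compressed`. A short sketch of reading one of those files back and printing the top classes; the example image id and file locations are illustrative:

```python
import numpy as np

# One saved classification output; the id 391895 is just an example.
scores = np.load('oid_data/mscoco/oid_cls/391895.npz')['feat'].ravel()

# classes-trainable.txt (downloaded in the header above) maps row index
# to a MID such as /m/01yrx; class-descriptions.csv maps MIDs to names.
labelmap = [line.strip() for line in open('classes-trainable.txt')]

for i in scores.argsort()[::-1][:5]:
    print('{:04d}: {} (score = {:.2f})'.format(i, labelmap[i], scores[i]))
```
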
/extract_det_oid.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # Copyright 2017 The Open Images Authors. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | r"""Classifier inference utility.
18 |
19 | This code takes a resnet_v1_101 checkpoint, runs the classifier on the image and
20 | prints predictions in human-readable form.
21 |
22 | -------------------------------
23 | Example command:
24 | -------------------------------
25 |
26 | # 0. Create directory for model/data
27 | WORK_PATH="/tmp/oidv2"
28 | mkdir -p "${WORK_PATH}"
29 | cd "${WORK_PATH}"
30 |
31 | # 1. Download the model, inference code, and sample image
32 | wget https://storage.googleapis.com/openimages/2017_07/classes-trainable.txt
33 | wget https://storage.googleapis.com/openimages/2017_07/class-descriptions.csv
34 | wget https://storage.googleapis.com/openimages/2017_07/oidv2-resnet_v1_101.ckpt.tar.gz
35 | wget https://raw.githubusercontent.com/openimages/dataset/master/tools/classify_oidv2.py
36 | tar -xzf oidv2-resnet_v1_101.ckpt.tar.gz
37 |
38 | wget -O cat.jpg https://farm6.staticflickr.com/5470/9372235876_d7d69f1790_b.jpg
39 |
40 | # 2. Run inference
41 | python classify_oidv2.py \
42 | --checkpoint_path='oidv2-resnet_v1_101.ckpt' \
43 | --labelmap='classes-trainable.txt' \
44 | --dict='class-descriptions.csv' \
45 | --image="cat.jpg" \
46 | --top_k=10 \
47 | --score_threshold=0.3
48 |
49 | # Sample output:
50 | Image: "cat.jpg"
51 |
52 | 3272: /m/068hy - Pet (score = 0.96)
53 | 1076: /m/01yrx - Cat (score = 0.95)
54 | 0708: /m/01l7qd - Whiskers (score = 0.90)
55 | 4755: /m/0jbk - Animal (score = 0.90)
56 | 2847: /m/04rky - Mammal (score = 0.89)
57 | 2036: /m/0307l - Felidae (score = 0.79)
58 | 3574: /m/07k6w8 - Small to medium-sized cats (score = 0.77)
59 | 4799: /m/0k0pj - Nose (score = 0.70)
60 | 1495: /m/02cqfm - Close-up (score = 0.55)
61 | 0036: /m/012c9l - Domestic short-haired cat (score = 0.40)
62 |
63 | -------------------------------
64 | Note on image preprocessing:
65 | -------------------------------
66 |
67 | This is the code used to perform preprocessing:
68 | --------
69 | from preprocessing import preprocessing_factory
70 |
71 | def PreprocessImage(image, network='resnet_v1_101', image_size=299):
72 | # If resolution is larger than 224 we need to adjust some internal resizing
73 | # parameters for vgg preprocessing.
74 | if any(network.startswith(x) for x in ['resnet', 'vgg']):
75 | preprocessing_kwargs = {
76 | 'resize_side_min': int(256 * image_size / 224),
77 | 'resize_side_max': int(512 * image_size / 224)
78 | }
79 | else:
80 | preprocessing_kwargs = {}
81 | preprocessing_fn = preprocessing_factory.get_preprocessing(
82 | name=network, is_training=False)
83 |
84 | height = image_size
85 | width = image_size
86 | image = preprocessing_fn(image, height, width, **preprocessing_kwargs)
87 | image.set_shape([height, width, 3])
88 | return image
89 | --------
90 |
91 | Note that there appears to be a small difference between the public version
92 | of the slim image-processing library and the internal version (which the
93 | meta graph is based on): the results are very close, but not exactly
94 | identical, to those of the meta graph.
95 | """
96 | from __future__ import absolute_import
97 | from __future__ import division
98 | from __future__ import print_function
99 | import tensorflow as tf
100 | import argparse
101 | import random
102 | import numpy as np
103 | import time, os, sys
104 | import json
105 | import cv2
106 | import os
107 | import six.moves.urllib as urllib
108 | import sys
109 | import tarfile
110 | import zipfile
111 |
112 | from collections import defaultdict
113 | from io import StringIO
114 | from matplotlib import pyplot as plt
115 | from PIL import Image
116 |
117 | # This is needed since the notebook is stored in the object_detection folder.
118 | sys.path.append("/home/jxgu/github/MIL.pytorch/misc/models/research")
119 | from utils import label_map_util
120 | from utils import visualization_utils as vis_util
121 | import utils.ops as utils_ops
122 | flags = tf.app.flags
123 | FLAGS = flags.FLAGS
124 |
125 | def load_image_ids(split_name):
126 |     ''' Load a list of (path, image_id) tuples. Modify this to suit your data locations. '''
127 | split = []
128 | if split_name == 'coco_test2014':
129 | with open('data/mscoco/annotations/image_info_test2014.json') as f:
130 | data = json.load(f)
131 | for item in data['images']:
132 | image_id = int(item['id'])
133 | filepath = os.path.join('data/mscoco/test2014/', item['file_name'])
134 | split.append((filepath,image_id))
135 | elif split_name == 'coco_val2014':
136 | with open('data/mscoco/annotations/captions_val2014.json') as f:
137 | data = json.load(f)
138 | for item in data['images']:
139 | image_id = int(item['id'])
140 | filepath = os.path.join('data/mscoco/val2014/', item['file_name'])
141 | split.append((filepath,image_id))
142 | elif split_name == 'coco_train2014':
143 | with open('data/mscoco/annotations/captions_train2014.json') as f:
144 | data = json.load(f)
145 | for item in data['images']:
146 | image_id = int(item['id'])
147 | filepath = os.path.join('data/mscoco/train2014/', item['file_name'])
148 | split.append((filepath,image_id))
149 | elif split_name == 'coco_test2015':
150 | with open('data/mscoco/annotations/image_info_test2015.json') as f:
151 | data = json.load(f)
152 | for item in data['images']:
153 | image_id = int(item['id'])
154 | filepath = os.path.join('data/mscoco/test2015/', item['file_name'])
155 | split.append((filepath,image_id))
156 | elif split_name == 'genome':
157 | with open('data/visualgenome/image_data.json') as f:
158 | for item in json.load(f):
159 | image_id = int(item['image_id'])
160 | filepath = os.path.join('data/visualgenome/', item['url'].split('rak248/')[-1])
161 | split.append((filepath,image_id))
162 | elif split_name == 'chinese':
163 | with open('data/aic_i2t/ai_challenger_caption_train_20170902/caption_train_annotations_20170902.json') as f:
164 | for item in json.load(f):
165 | image_id = item['image_id']
166 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_train_20170902/caption_train_images_20170902', image_id)
167 | split.append((filepath,image_id))
168 | elif split_name == 'chinese_val':
169 | with open('data/aic_i2t/ai_challenger_caption_validation_20170910/caption_validation_annotations_20170910.json') as f:
170 | for item in json.load(f):
171 | image_id = item['image_id']
172 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_validation_20170910/caption_validation_images_20170910', image_id)
173 | split.append((filepath,image_id))
174 | elif split_name == 'chinese_test1':
175 | with open('data/aic_i2t/ai_challenger_caption_test1_20170923/caption_test1_annotations_20170923.json') as f:
176 | for item in json.load(f):
177 | image_id = item['image_id']
178 | filepath = os.path.join('data/aic_i2t/ai_challenger_caption_test1_20170923/caption_test1_images_20170923', image_id)
179 | split.append((filepath,image_id))
180 | else:
181 | print('Unknown split')
182 | return split
183 |
184 | def run_inference_for_single_image(image, graph):
185 | with graph.as_default():
186 | with tf.Session() as sess:
187 | # Get handles to input and output tensors
188 | ops = tf.get_default_graph().get_operations()
189 | all_tensor_names = {output.name for op in ops for output in op.outputs}
190 | tensor_dict = {}
191 | for key in [
192 | 'num_detections', 'detection_boxes', 'detection_scores',
193 | 'detection_classes', 'detection_masks'
194 | ]:
195 | tensor_name = key + ':0'
196 | if tensor_name in all_tensor_names:
197 | tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
198 | tensor_name)
199 | if 'detection_masks' in tensor_dict:
200 | # The following processing is only for single image
201 | detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
202 | detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
203 | # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.
204 | real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
205 | detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
206 | detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
207 | detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
208 | detection_masks, detection_boxes, image.shape[0], image.shape[1])
209 | detection_masks_reframed = tf.cast(tf.greater(detection_masks_reframed, 0.2), tf.uint8)
210 | # Follow the convention by adding back the batch dimension
211 | tensor_dict['detection_masks'] = tf.expand_dims(
212 | detection_masks_reframed, 0)
213 | image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')
214 |
215 | # Run inference
216 | output_dict = sess.run(tensor_dict, feed_dict={image_tensor: np.expand_dims(image, 0)})
217 |
218 | # all outputs are float32 numpy arrays, so convert types as appropriate
219 | output_dict['num_detections'] = int(output_dict['num_detections'][0])
220 | output_dict['detection_classes'] = output_dict[
221 | 'detection_classes'][0].astype(np.uint8)
222 | output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
223 | output_dict['detection_scores'] = output_dict['detection_scores'][0]
224 | if 'detection_masks' in output_dict:
225 | output_dict['detection_masks'] = output_dict['detection_masks'][0]
226 | return output_dict
227 |
228 | def get_detections_from_im(args, detection_graph, image_ids):
229 | save_dir = '/home/jxgu/github/MIL.pytorch/oid_data/'
230 | count = 0
231 | for im_file, image_id in image_ids:
232 | image = Image.open(im_file)
233 | # the array based representation of the image will be used later in order to prepare the
234 | # result image with boxes and labels on it.
235 | image_np = load_image_into_numpy_array(image)
236 | # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
237 | image_np_expanded = np.expand_dims(image_np, axis=0)
238 | # Actual detection.
239 | output_dict = run_inference_for_single_image(image_np, detection_graph)
240 | if 'chinese' in args.data_split:
241 | np.savez_compressed(save_dir + 'aic_i2t/oid_det/' + str(image_id), feat=output_dict)
242 | else:
243 | np.savez_compressed(save_dir + 'mscoco/oid_det/' + str(image_id), feat=output_dict)
244 |
245 | if (count % 100) == 0:
246 | print('{:d}'.format(count + 1))
247 | count += 1
248 |
249 | def load_image_into_numpy_array(image):
250 | (im_width, im_height) = image.size
251 | return np.array(image.getdata()).reshape(
252 | (im_height, im_width, 3)).astype(np.uint8)
253 |
254 | def parse_args():
255 | """
256 | Parse input arguments
257 | """
258 |     parser = argparse.ArgumentParser(description='Extract Open Images detection outputs')
259 | parser.add_argument('--gpu', dest='gpu_id', help='GPU id(s) to use', default='0', type=str)
260 | parser.add_argument('--type', dest='type', help='', default='det', type=str)
261 | parser.add_argument('--def', dest='prototxt', help='prototxt file defining the network', default='../models/vg/ResNet-101/faster_rcnn_end2end_final/test.prototxt', type=str)
262 | parser.add_argument('--net', dest='caffemodel', help='model to use', default='../../../../data/faster_rcnn_models/resnet101_faster_rcnn_final.caffemodel', type=str)
263 | parser.add_argument('--out', dest='outfile', help='output filepath', default='karpathy_train_resnet101_faster_rcnn_genome', type=str)
264 | parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default='../experiments/cfgs/faster_rcnn_end2end_resnet.yml', type=str)
265 | parser.add_argument('--split', dest='data_split', help='dataset to use', default='coco_train2014', type=str)
266 | parser.add_argument('--checkpoint_path', dest='checkpoint_path', help='', default='/home/jxgu/github/MIL.pytorch/model/faster_rcnn_inception_resnet_v2_atrous_oid_2018_01_28.tar.gz', type=str)
267 | parser.add_argument('--set', dest='set_cfgs', help='set config keys', default=None, nargs=argparse.REMAINDER)
268 |
269 | args = parser.parse_args()
270 | return args
271 |
272 |
273 | def main(_):
274 | args = parse_args()
275 |
276 | # What model to download.
277 | MODEL_NAME = '/home/jxgu/github/MIL.pytorch/model/faster_rcnn_inception_resnet_v2_atrous_oid_2018_01_28'
278 | MODEL_FILE = MODEL_NAME + '.tar.gz'
279 | PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
280 | tar_file = tarfile.open(MODEL_FILE)
281 | for file in tar_file.getmembers():
282 | file_name = os.path.basename(file.name)
283 | if 'frozen_inference_graph.pb' in file_name:
284 | tar_file.extract(file, os.getcwd())
285 |
286 | tar_file = tarfile.open(args.checkpoint_path)
287 | for file in tar_file.getmembers():
288 | file_name = os.path.basename(file.name)
289 | if 'frozen_inference_graph.pb' in file_name:
290 | tar_file.extract(file, os.getcwd())
291 |
292 | detection_graph = tf.Graph()
293 | with detection_graph.as_default():
294 | od_graph_def = tf.GraphDef()
295 | with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
296 | serialized_graph = fid.read()
297 | od_graph_def.ParseFromString(serialized_graph)
298 | tf.import_graph_def(od_graph_def, name='')
299 | image_ids = load_image_ids(args.data_split)
300 | get_detections_from_im(args, detection_graph, image_ids)
301 | if __name__ == '__main__':
302 | tf.app.run()
303 |
--------------------------------------------------------------------------------
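
`get_detections_from_im` stores the whole `output_dict` under the `feat` key, so `np.savez_compressed` pickles it into a 0-d object array. Reading it back therefore needs `.item()` (and, on NumPy >= 1.16.3, `allow_pickle=True`); the paths and example id below are illustrative:

```python
import numpy as np

data = np.load('oid_data/mscoco/oid_det/391895.npz', allow_pickle=True)
output_dict = data['feat'].item()  # unwrap the pickled dict

n = output_dict['num_detections']
boxes = output_dict['detection_boxes'][:n]      # normalized [ymin, xmin, ymax, xmax]
classes = output_dict['detection_classes'][:n]
scores = output_dict['detection_scores'][:n]
print(list(zip(classes.tolist(), scores.tolist()))[:5])
```
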
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/model/__init__.py
--------------------------------------------------------------------------------
/model/coco_voc.py:
--------------------------------------------------------------------------------
1 | ../pycoco/coco_voc.py
--------------------------------------------------------------------------------
/model/eval_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 |
9 | import numpy as np
10 | import json
11 | from json import encoder
12 | import random
13 | import string
14 | import time
15 | import os
16 | import sys
17 | import model.utils as utils
18 |
19 | def eval_split(opt, model, crit, loader, eval_kwargs={}):
20 | verbose = eval_kwargs.get('verbose', True)
21 | num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
22 | split = eval_kwargs.get('split', 'val')
23 | dataset = eval_kwargs.get('dataset', 'coco')
24 |
25 | # Make sure in the evaluation mode
26 | model.eval()
27 | loader.reset_iterator(split)
28 |
29 | n = 0
30 | loss = 0
31 | loss_sum = 0
32 | loss_evals = 1e-8
33 | predictions = []
34 | while True:
35 | data = loader.get_batch(split)
36 | n = n + loader.batch_size
37 | # forward the model to get loss
38 | tmp = [data['images'], data['mil_label']]
39 | tmp = [Variable(torch.from_numpy(_), volatile=True).cuda() for _ in tmp]
40 | images, mil_label = tmp
41 |
42 | loss = crit(model(images), mil_label).data[0]
43 | loss_sum = loss_sum + loss
44 | loss_evals = loss_evals + 1
45 |
46 | if data['bounds']['wrapped']:
47 | break
48 | if num_images >= 0 and n >= num_images:
49 | break
50 |
51 | # Switch back to training mode
52 | model.train()
53 | return loss_sum / loss_evals
--------------------------------------------------------------------------------
/model/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import collections
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 | import time
10 | from PIL import Image
11 | import os
12 | import os.path as osp
13 | import sys
14 | import platform
15 | import cPickle
16 | import urllib
17 | import cv2, numpy as np
18 | from scipy.interpolate import interp1d
19 | from matplotlib.pyplot import show
20 | import matplotlib.pyplot as plt
21 | import torch
22 | import torch.nn as nn
23 | import torch.optim as optim
24 | import torch.nn.functional as F
25 | import torch.backends.cudnn as cudnn
26 | import tensorflow as tf
27 | import torchvision
28 | import torchvision.transforms as transforms
29 | from six.moves import cPickle
30 | from model.vgg_mil import *
31 | from model.resnet_mil import *
32 | from model.utils import *
33 | import itertools
34 |
35 | class Criterion(nn.Module):
36 | def __init__(self):
37 | super(Criterion, self).__init__()
38 | #self.loss0 = nn.MultiLabelMarginLoss()
39 | self.loss0 = nn.MultiLabelSoftMarginLoss()
40 | #self.loss0 = nn.MultiMarginLoss()
41 | #self.loss0 = nn.CrossEntropyLoss()
42 | #self.loss0 = nn.NLLLoss()
43 |
44 | #self.loss1 = nn.MultiMarginLoss()
45 |
46 | def forward(self, input, target):
47 | output0 = self.loss0(input, target.float())
48 | #output1 = self.loss1(input, target.long())
49 | return output0
50 |
51 | def build_mil(opt):
52 | opt.n_gpus = getattr(opt, 'n_gpus', 1)
53 |
54 | if 'resnet101' in opt.model:
55 | mil_model = resnet_mil(opt)
56 | else:
57 | mil_model = vgg_mil(opt)
58 |
59 | if opt.n_gpus>1:
60 | print('Construct multi-gpu model ...')
61 | model = nn.DataParallel(mil_model, device_ids=opt.gpus, dim=0)
62 | else:
63 | model = mil_model
64 | # check compatibility if training is continued from previously saved model
65 | if len(opt.start_from) != 0:
66 | # check if all necessary files exist
67 |         assert os.path.isdir(opt.start_from), "%s must be a valid path" % opt.start_from
68 | lm_info_path = os.path.join(opt.start_from, os.path.basename(opt.start_from) + '.infos-best.pkl')
69 | lm_pth_path = os.path.join(opt.start_from, os.path.basename(opt.start_from) + '.model-best.pth')
70 | assert os.path.isfile(lm_info_path), "infos.pkl file does not exist in path %s" % opt.start_from
71 | model.load_state_dict(torch.load(lm_pth_path))
72 | model.cuda()
73 | model.train() # Assure in training mode
74 | return model
75 |
76 | def build_optimizer(opt, model, infos):
77 | opt.pre_ft = getattr(opt, 'pre_ft', 1)
78 |
79 | #model_parameters = itertools.ifilter(lambda p: p.requires_grad, model.parameters())
80 | optimize = opt.optim
81 | if optimize == 'adam':
82 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
83 | elif optimize == 'sgd':
84 | optimizer = torch.optim.SGD(model.parameters(), lr=opt.learning_rate, momentum=0.999, weight_decay=0.0005)
85 | elif optimize == 'Adadelta':
86 | optimizer = torch.optim.Adadelta(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
87 | elif optimize == 'Adagrad':
88 | optimizer = torch.optim.Adagrad(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
89 | elif optimize == 'Adamax':
90 | optimizer = torch.optim.Adamax(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
91 | elif optimize == 'ASGD':
92 | optimizer = torch.optim.ASGD(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
93 | elif optimize == 'LBFGS':
94 |         optimizer = torch.optim.LBFGS(model.parameters(), lr=opt.learning_rate)  # LBFGS takes no weight_decay
95 | elif optimize == 'RMSprop':
96 | optimizer = torch.optim.RMSprop(model.parameters(), lr=opt.learning_rate, weight_decay=0.0005)
97 |
98 | infos['optimized'] = True
99 |
100 | # Load the optimizer
101 | if len(opt.start_from) != 0:
102 | if os.path.isfile(os.path.join(opt.start_from, opt.model_id + '.optimizer.pth')):
103 | optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, opt.model_id + '.optimizer.pth')))
104 |
105 | return optimizer, infos
106 |
107 | def build_models(opt, infos):
108 | model = build_mil(opt)
109 | optimizer, infos = build_optimizer(opt, model, infos)
110 |     crit = Criterion()  # when training with RL, add a reward criterion here
111 | model.cuda()
112 | model.train() # Assure in training mode
113 | return model, crit, optimizer, infos
114 |
115 | def load_models(opt, infos):
116 | model = build_mil(opt)
117 |     crit = Criterion()  # Criterion takes no arguments
118 | return model, crit
--------------------------------------------------------------------------------
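
`Criterion` above is a thin wrapper around `nn.MultiLabelSoftMarginLoss`: the model emits one score per MIL vocabulary word, and the target is the multi-hot vector produced by `DataLoader.gen_mil_gt`. A toy check in the same PyTorch 0.3-style API the rest of the repo uses; the shapes are illustrative:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

crit = nn.MultiLabelSoftMarginLoss()      # what Criterion.loss0 resolves to
scores = Variable(torch.randn(4, 1000))   # a batch of 4 images
targets = Variable(torch.zeros(4, 1000))
targets[:, :3] = 1                        # pretend the first three words apply
loss = crit(scores, targets.float())
print(loss.data[0])
```
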
/model/readme.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/model/readme.md
--------------------------------------------------------------------------------
/model/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 |
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152']
8 |
9 |
10 | model_urls = {
11 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth',
12 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth',
13 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth',
14 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth',
15 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth',
16 | }
17 |
18 |
19 | def conv3x3(in_planes, out_planes, stride=1):
20 | "3x3 convolution with padding"
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22 | padding=1, bias=False)
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 |
28 | def __init__(self, inplanes, planes, stride=1, downsample=None):
29 | super(BasicBlock, self).__init__()
30 | self.conv1 = conv3x3(inplanes, planes, stride)
31 | self.bn1 = nn.BatchNorm2d(planes)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.conv2 = conv3x3(planes, planes)
34 | self.bn2 = nn.BatchNorm2d(planes)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | residual = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | out = self.relu(out)
53 |
54 | return out
55 |
56 |
57 | class Bottleneck(nn.Module):
58 | expansion = 4
59 |
60 | def __init__(self, inplanes, planes, stride=1, downsample=None):
61 | super(Bottleneck, self).__init__()
62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change
63 | self.bn1 = nn.BatchNorm2d(planes)
64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change
65 | padding=1, bias=False)
66 | self.bn2 = nn.BatchNorm2d(planes)
67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
68 | self.bn3 = nn.BatchNorm2d(planes * 4)
69 | self.relu = nn.ReLU(inplace=True)
70 | self.downsample = downsample
71 | self.stride = stride
72 |
73 | def forward(self, x):
74 | residual = x
75 |
76 | out = self.conv1(x)
77 | out = self.bn1(out)
78 | out = self.relu(out)
79 |
80 | out = self.conv2(out)
81 | out = self.bn2(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv3(out)
85 | out = self.bn3(out)
86 |
87 | if self.downsample is not None:
88 | residual = self.downsample(x)
89 |
90 | out += residual
91 | out = self.relu(out)
92 |
93 | return out
94 |
95 |
96 | class ResNet(nn.Module):
97 | def __init__(self, block, layers, num_classes=1000):
98 | self.inplanes = 64
99 | super(ResNet, self).__init__()
100 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
101 | bias=False)
102 | self.bn1 = nn.BatchNorm2d(64)
103 | self.relu = nn.ReLU(inplace=True)
104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
105 | self.layer1 = self._make_layer(block, 64, layers[0])
106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
109 | self.avgpool = nn.AvgPool2d(7)
110 | self.fc = nn.Linear(512 * block.expansion, num_classes)
111 |
112 | for m in self.modules():
113 | if isinstance(m, nn.Conv2d):
114 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
115 | m.weight.data.normal_(0, math.sqrt(2. / n))
116 | elif isinstance(m, nn.BatchNorm2d):
117 | m.weight.data.fill_(1)
118 | m.bias.data.zero_()
119 |
120 | def _make_layer(self, block, planes, blocks, stride=1):
121 | downsample = None
122 | if stride != 1 or self.inplanes != planes * block.expansion:
123 | downsample = nn.Sequential(
124 | nn.Conv2d(self.inplanes, planes * block.expansion,
125 | kernel_size=1, stride=stride, bias=False),
126 | nn.BatchNorm2d(planes * block.expansion),
127 | )
128 |
129 | layers = []
130 | layers.append(block(self.inplanes, planes, stride, downsample))
131 | self.inplanes = planes * block.expansion
132 | for i in range(1, blocks):
133 | layers.append(block(self.inplanes, planes))
134 |
135 | return nn.Sequential(*layers)
136 |
137 | def forward(self, x):
138 | x = self.conv1(x)
139 | x = self.bn1(x)
140 | x = self.relu(x)
141 | x = self.maxpool(x)
142 |
143 | x = self.layer1(x)
144 | x = self.layer2(x)
145 | x = self.layer3(x)
146 | x = self.layer4(x)
147 |
148 | x = self.avgpool(x)
149 | x = x.view(x.size(0), -1)
150 | x = self.fc(x)
151 |
152 | return x
153 |
154 |
155 | def resnet18(pretrained=False):
156 | """Constructs a ResNet-18 model.
157 |
158 | Args:
159 | pretrained (bool): If True, returns a model pre-trained on ImageNet
160 | """
161 | model = ResNet(BasicBlock, [2, 2, 2, 2])
162 | if pretrained:
163 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
164 | return model
165 |
166 |
167 | def resnet34(pretrained=False):
168 | """Constructs a ResNet-34 model.
169 |
170 | Args:
171 | pretrained (bool): If True, returns a model pre-trained on ImageNet
172 | """
173 | model = ResNet(BasicBlock, [3, 4, 6, 3])
174 | if pretrained:
175 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
176 | return model
177 |
178 |
179 | def resnet50(pretrained=False):
180 | """Constructs a ResNet-50 model.
181 |
182 | Args:
183 | pretrained (bool): If True, returns a model pre-trained on ImageNet
184 | """
185 | model = ResNet(Bottleneck, [3, 4, 6, 3])
186 | if pretrained:
187 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
188 | return model
189 |
190 |
191 | def resnet101(pretrained=False):
192 | """Constructs a ResNet-101 model.
193 |
194 | Args:
195 | pretrained (bool): If True, returns a model pre-trained on ImageNet
196 | """
197 | model = ResNet(Bottleneck, [3, 4, 23, 3])
198 | if pretrained:
199 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
200 | return model
201 |
202 |
203 | def resnet152(pretrained=False):
204 | """Constructs a ResNet-152 model.
205 |
206 | Args:
207 | pretrained (bool): If True, returns a model pre-trained on ImageNet
208 | """
209 | model = ResNet(Bottleneck, [3, 8, 36, 3])
210 | if pretrained:
211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
212 | return model
--------------------------------------------------------------------------------
/model/resnet_mil.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 |
7 | class resnet_mil(nn.Module):
8 | def __init__(self, opt):
9 | super(resnet_mil, self).__init__()
10 | import model.resnet as resnet
11 | resnet = resnet.resnet101()
12 | resnet.load_state_dict(torch.load('/media/jxgu/d2tb/model/resnet/resnet101.pth'))
13 | self.conv = torch.nn.Sequential()
14 | self.conv.add_module("conv1", resnet.conv1)
15 | self.conv.add_module("bn1", resnet.bn1)
16 | self.conv.add_module("relu", resnet.relu)
17 | self.conv.add_module("maxpool", resnet.maxpool)
18 | self.conv.add_module("layer1", resnet.layer1)
19 | self.conv.add_module("layer2", resnet.layer2)
20 | self.conv.add_module("layer3", resnet.layer3)
21 | self.conv.add_module("layer4", resnet.layer4)
22 | self.l1 = nn.Sequential(nn.Linear(2048, 1000),
23 | nn.ReLU(True),
24 | nn.Dropout(0.5))
25 | self.att_size = 7
26 | self.pool_mil = nn.MaxPool2d(kernel_size=self.att_size, stride=0)
27 |
28 | def forward(self, img, att_size=14):
29 |         x0 = self.conv(img)                  # backbone feature map, e.g. (B, 2048, 7, 7) for 224x224 input
30 |         x = self.pool_mil(x0)                # max pool over the 7x7 grid -> (B, 2048, 1, 1)
31 |         x = x.squeeze(2).squeeze(2)          # (B, 2048)
32 |         x = self.l1(x)                       # per-word scores, (B, 1000)
33 |         x1 = torch.add(torch.mul(x.view(x.size(0), 1000, -1), -1), 1)  # 1 - p
34 |         cumprod = torch.cumprod(x1, 2)       # running product of (1 - p) over instances
35 |         out = torch.max(x, torch.add(torch.mul(cumprod[:, :, -1], -1), 1))  # max(p, noisy-OR)
36 | return out
37 |
--------------------------------------------------------------------------------
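
The last three lines of `forward` are the noisy-OR MIL combination: for one word with instance probabilities p_1..p_k, the image-level probability is 1 - prod_i(1 - p_i), and the code keeps the larger of that and the per-instance maximum. Note that in `resnet_mil` the preceding `pool_mil` already collapses the grid, so the instance dimension is 1 and the expression reduces to the max-pooled score; it is in `vgg_mil`, where (presumably) the same expression runs over the full spatial grid, that the combination does real work. A tiny numeric check:

```python
import numpy as np

p = np.array([0.2, 0.5, 0.1])        # instance probabilities for one word
noisy_or = 1 - np.prod(1 - p)        # 1 - 0.8 * 0.5 * 0.9 = 0.64
print(max(p.max(), noisy_or))        # forward() keeps the larger value: 0.64
```
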
/model/resnet_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 |
6 |
7 | class myResnet(nn.Module):
8 | def __init__(self, resnet):
9 | super(myResnet, self).__init__()
10 | self.resnet = resnet
11 |
12 | def forward(self, img, att_size=14):
13 | x = img.unsqueeze(0)
14 |
15 | x = self.resnet.conv1(x)
16 | x = self.resnet.bn1(x)
17 | x = self.resnet.relu(x)
18 | x = self.resnet.maxpool(x)
19 |
20 | x = self.resnet.layer1(x)
21 | x = self.resnet.layer2(x)
22 | x = self.resnet.layer3(x)
23 | x = self.resnet.layer4(x)
24 |
25 | fc = x.mean(3).mean(2)
26 | att = F.adaptive_avg_pool2d(x, [att_size, att_size]).squeeze().permute(1, 2, 0)
27 |
28 | return fc, att
29 |
--------------------------------------------------------------------------------
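
`myResnet` exposes two views of the same backbone activations: a mean-pooled `fc` vector and an `att` grid for attention models. A quick shape check, assuming a plain torchvision ResNet-101 (random weights are fine for checking shapes):

```python
import torch
from torch.autograd import Variable
import torchvision.models as models

from model.resnet_utils import myResnet

net = myResnet(models.resnet101())
img = Variable(torch.randn(3, 224, 224))  # forward() adds the batch dim itself
fc, att = net(img, att_size=7)
print(fc.size())   # (1, 2048): mean-pooled feature
print(att.size())  # (7, 7, 2048): spatial grid for attention
```
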
/model/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import collections
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 | import time
10 | from PIL import Image
11 | import os
12 | import os.path as osp
13 | import sys
14 | import platform
15 | import cPickle
16 | import urllib
17 | import cv2, numpy as np
18 | from scipy.interpolate import interp1d
19 | from matplotlib.pyplot import show
20 | import matplotlib.pyplot as plt
21 | import torch
22 | import torch.nn as nn
23 | import torch.optim as optim
24 | import torch.nn.functional as F
25 | import torch.backends.cudnn as cudnn
26 | import tensorflow as tf
27 | import torchvision
28 | import torchvision.transforms as transforms
29 | from six.moves import cPickle
30 | from model.vgg_mil import *
31 | from model.resnet_mil import *
32 | import itertools
33 |
34 | '''
35 | Learning-rate annealing schedule. We build a new optimizer for
36 | each epoch.
37 | 
38 | By default we follow a known recipe for a 55-epoch training. If
39 | the learning-rate command-line parameter has been specified,
40 | though, we trust the user is doing something manual, and will use
41 | the exact settings for all optimization.
42 | 
43 | Return values:
44 | - diff to apply to the optimizer state,
45 | - True iff this is the first epoch of a new regime.
46 | '''
47 | def set_lr(optimizer, lr):
48 | for group in optimizer.param_groups:
49 | group['lr'] = lr
50 |
51 | def set_weightDecay(optimizer, weightDecay):
52 | for group in optimizer.param_groups:
53 | group['weight_decay'] = weightDecay
54 |
55 | def update_lr(opt, epoch, optimizer):
56 | if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
57 | frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
58 | decay_factor = opt.learning_rate_decay_rate ** frac
59 | opt.current_lr = opt.learning_rate * decay_factor
60 | set_lr(optimizer, opt.current_lr) # set the decayed rate
61 | else:
62 | opt.current_lr = opt.learning_rate
63 |
64 | def paramsForEpoch(opt, epoch, optimizer):
65 | # start, end, LR, WD,
66 | regimes = [ [ 0, 5, 1e-2, 5e-4, ],
67 | [ 6, 10, 5e-3, 5e-4 ],
68 | [ 11, 20, 1e-3, 0 ],
69 |                 [ 21, 30, 5e-4, 0 ],
70 | [ 31, 1e8, 1e-4, 0 ]]
71 | for row in regimes:
72 | if epoch>=row[0] and epoch<= row[1]:
73 | learningRate = row[2]
74 | weightDecay = row[3]
75 | opt.learning_rate = learningRate
76 | opt.weight_decay = weightDecay
77 | set_lr(optimizer, learningRate)
78 | set_weightDecay(optimizer, weightDecay)
79 |
80 | def add_summary_value(writer, key, value, iteration):
81 | summary = tf.Summary(value=[tf.Summary.Value(tag=key, simple_value=value)])
82 | writer.add_summary(summary, iteration)
83 |
84 | def history_infos(opt):
85 | infos = {}
86 | if len(opt.start_from) != 0: # open old infos and check if models are compatible
87 | model_id = opt.start_from
88 | infos_id = model_id.replace('save/', '') + '.infos-best.pkl'
89 | with open(os.path.join(opt.start_from, infos_id)) as f:
90 | infos = cPickle.load(f)
91 | saved_model_opt = infos['opt']
92 |
93 | iteration = infos.get('iter', 0)
94 | epoch = infos.get('epoch', 0)
95 | val_result_history = infos.get('val_result_history', {})
96 | loss_history = infos.get('loss_history', {})
97 | lr_history = infos.get('lr_history', {})
98 | best_val_score = infos.get('best_val_score', None) if opt.load_best_score == 1 else 0
99 | val_loss = 0.0
100 | val_history = [val_result_history, best_val_score, val_loss]
101 | train_history = [loss_history, lr_history]
102 | return opt, infos, iteration, epoch, val_history, train_history
103 |
104 |
105 | def add_path(path):
106 | if path not in sys.path:
107 | sys.path.insert(0, path)
108 | print('added {}'.format(path))
109 |
110 | def save_variables(pickle_file_name, var, info, overwrite=False):
111 |     if os.path.exists(pickle_file_name) and not overwrite:
112 |         raise Exception('{:s} exists and overwrite is false.'.format(pickle_file_name))
113 |     # Construct a dictionary mapping each name in info to its variable
114 |     assert type(var) == list
115 |     assert type(info) == list
116 |     d = {}
117 |     for i in xrange(len(var)):
118 |         d[info[i]] = var[i]
119 |     with open(pickle_file_name, 'wb') as f:
120 |         cPickle.dump(d, f, cPickle.HIGHEST_PROTOCOL)
121 |
122 |
123 | def load_variables(pickle_file_name):
124 |     # Returns the dictionary of variables stored in the pickle file
125 |     if os.path.exists(pickle_file_name):
126 |         with open(pickle_file_name, 'rb') as f:
127 |             d = cPickle.load(f)
128 |         return d
129 |     else:
130 |         raise Exception('{:s} does not exist.'.format(pickle_file_name))
131 |
132 | def to_contiguous(tensor):
133 | if tensor.is_contiguous():
134 | return tensor
135 | else:
136 | return tensor.contiguous()
137 |
138 | def clip_gradient(optimizer, grad_clip):
139 |     for group in optimizer.param_groups:
140 |         for param in group['params']:
141 |             if param.grad is not None: param.grad.data.clamp_(-grad_clip, grad_clip)
142 |
143 | def url_to_image(url):
144 |     # download the image, convert it to a NumPy array, and then decode
145 |     # it into OpenCV (BGR) format
146 |     resp = urllib.urlopen(url)
147 |     image = np.asarray(bytearray(resp.read()), dtype="uint8")
148 |     image = cv2.imdecode(image, cv2.IMREAD_COLOR)
149 |     return image
153 |
154 | def upsample_image(im, sz):
155 |     h = im.shape[0]
156 |     w = im.shape[1]
157 |     s = np.float(max(h, w))
158 |     I_out = np.zeros((sz, sz, 3), dtype=np.float)
159 |     I = cv2.resize(im, None, None, fx=np.float(sz) / s, fy=np.float(sz) / s, interpolation=cv2.INTER_LINEAR)  # scale the longer side to sz
160 |     SZ = I.shape
161 |     I_out[0:I.shape[0], 0:I.shape[1], :] = I  # zero-pad the shorter side to make the output square
162 |     return I_out, I, SZ
163 |
164 | def compute_precision_score_mapping_torch(thresh, prec, score):
165 |     # Unfinished torch port of compute_precision_score_mapping below; currently unused and always returns None
166 |     thresh, ind_thresh = torch.sort(torch.from_numpy(thresh), 0, descending=False)
167 |     prec, ind_prec = torch.sort(torch.from_numpy(prec), 0, descending=False)
168 |     val = None
169 |     return val
170 |
171 | def compute_precision_mapping(pt):
172 |     # Build, for each of the 1000 vocabulary words, a monotonically
173 |     # non-decreasing precision-vs-threshold mapping from the eval file
174 |     thresh_all = []
175 |     prec_all = []
176 |     for jj in xrange(1000):
177 |         thresh = pt['details']['score'][:, jj]
178 |         prec = pt['details']['precision'][:, jj]
179 |         ind = np.argsort(thresh)
180 |         thresh = thresh[ind]
181 |         indexes = np.unique(thresh, return_index=True)[1]
182 |         indexes = np.sort(indexes)
183 |         thresh = thresh[indexes]
184 |         # Pad with sentinels so interpolation covers the whole score range
185 |         thresh = np.vstack((min(-1000, min(thresh) - 1), thresh[:, np.newaxis], max(1000, max(thresh) + 1)))
186 |         prec = prec[ind]
187 |         for i in xrange(1, len(prec)):
188 |             prec[i] = max(prec[i], prec[i - 1])  # enforce monotonicity
189 |         prec = prec[indexes]
190 |         prec = np.vstack((prec[0], prec[:, np.newaxis], prec[-1]))
191 |         thresh_all.append(thresh)
192 |         prec_all.append(prec)
193 |     precision_score = {'thresh': thresh_all, 'prec': prec_all}
194 |     return precision_score
195 |
196 | def compute_precision_score_mapping(thresh, prec, score):
197 |     # Map raw scores to precision values via monotone interpolation
198 |     ind = np.argsort(thresh)
199 |     thresh = thresh[ind]
200 |     indexes = np.unique(thresh, return_index=True)[1]
201 |     indexes = np.sort(indexes)
202 |     thresh = thresh[indexes]
203 |     # Pad with sentinels so interpolation covers the whole score range
204 |     thresh = np.vstack((min(-1000, min(thresh) - 1), thresh[:, np.newaxis], max(1000, max(thresh) + 1)))
205 | 
206 |     prec = prec[ind]
207 |     for i in xrange(1, len(prec)):
208 |         prec[i] = max(prec[i], prec[i - 1])  # enforce monotonicity
209 |     prec = prec[indexes]
210 |     prec = np.vstack((prec[0], prec[:, np.newaxis], prec[-1]))
211 | 
212 |     f = interp1d(thresh[:, 0], prec[:, 0])
213 |     val = f(score)
214 |     return val
215 |
216 | def load_vocabulary():
217 | # Load the vocabulary
218 | vocab_file = os.getcwd()+'/vocabs/vocab_train.pkl'
219 | vocab = load_variables(vocab_file)
220 |
221 |     # define functional (stop) words; note that is_functional is True for words NOT in this list
222 |     functional_words = ['a', 'on', 'of', 'the', 'in', 'with', 'and', 'is', 'to', 'an', 'two', 'at', 'next', 'are']
223 |     is_functional = np.array([x not in functional_words for x in vocab['words']])
224 |
225 | # load the score precision mapping file
226 | eval_file = os.getcwd()+'/model/coco_valid1_eval.pkl'
227 | pt = load_variables(eval_file)
228 | return vocab, functional_words, is_functional, pt
229 |
230 | def tic_toc_print(interval, string):  # print at most once every `interval` seconds
231 |     global tic_toc_print_time_old
232 |     if 'tic_toc_print_time_old' not in globals():
233 |         tic_toc_print_time_old = time.time()
234 |         print(string)
235 |     else:
236 |         new_time = time.time()
237 |         if new_time - tic_toc_print_time_old > interval:
238 |             tic_toc_print_time_old = new_time
239 |             print(string)
240 |
241 |
--------------------------------------------------------------------------------
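A note on the schedule above: the regime table in paramsForEpoch keys the learning rate and weight decay to epoch ranges. Below is a minimal sketch of how it drives an optimizer; the `opt` namespace and the SGD optimizer are stand-ins for what opts.py and model/models.py actually construct.

```python
# Sketch only: exercising paramsForEpoch from model/utils.py on a toy model.
import argparse
import torch.nn as nn
import torch.optim as optim
from model.utils import paramsForEpoch

opt = argparse.Namespace(learning_rate=1e-2, weight_decay=5e-4)
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=opt.learning_rate,
                      weight_decay=opt.weight_decay)

for epoch in [0, 6, 11, 21, 31]:  # first epoch of each regime
    paramsForEpoch(opt, epoch, optimizer)
    group = optimizer.param_groups[0]
    print(epoch, group['lr'], group['weight_decay'])
# -> (1e-2, 5e-4), (5e-3, 5e-4), (1e-3, 0), (5e-4, 0), (1e-4, 0)
```
--------------------------------------------------------------------------------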
/model/vgg_mil.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | from torch.autograd import Variable
5 | from torch.utils.serialization import load_lua
6 | import torch.nn.functional as F
7 |
8 | class vgg_mil(nn.Module):
9 | def __init__(self, opt):
10 | super(vgg_mil, self).__init__()
11 | self.conv = torch.nn.Sequential()
12 | self.conv.add_module("conv1_1", nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1))
13 | self.conv.add_module("relu_1_1", torch.nn.ReLU())
14 | self.conv.add_module("conv1_2", nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1))
15 | self.conv.add_module("relu_1_2", torch.nn.ReLU())
16 | self.conv.add_module("maxpool_1", torch.nn.MaxPool2d(kernel_size=2))
17 |
18 | self.conv.add_module("conv2_1", nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1))
19 | self.conv.add_module("relu_2_1", torch.nn.ReLU())
20 | self.conv.add_module("conv2_2", nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1))
21 | self.conv.add_module("relu_2_2", torch.nn.ReLU())
22 | self.conv.add_module("maxpool_2", torch.nn.MaxPool2d(kernel_size=2))
23 |
24 | self.conv.add_module("conv3_1", nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1))
25 | self.conv.add_module("relu_3_1", torch.nn.ReLU())
26 | self.conv.add_module("conv3_2", nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1))
27 | self.conv.add_module("relu_3_2", torch.nn.ReLU())
28 | self.conv.add_module("conv3_3", nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1))
29 | self.conv.add_module("relu_3_3", torch.nn.ReLU())
30 | self.conv.add_module("maxpool_3", torch.nn.MaxPool2d(kernel_size=2))
31 |
32 | self.conv.add_module("conv4_1", nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1))
33 | self.conv.add_module("relu_4_1", torch.nn.ReLU())
34 | self.conv.add_module("conv4_2", nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1))
35 | self.conv.add_module("relu_4_2", torch.nn.ReLU())
36 | self.conv.add_module("conv4_3", nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1))
37 | self.conv.add_module("relu_4_3", torch.nn.ReLU())
38 | self.conv.add_module("maxpool_4", torch.nn.MaxPool2d(kernel_size=2))
39 |
40 | self.conv.add_module("conv5_1", nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1))
41 | self.conv.add_module("relu_5_1", torch.nn.ReLU())
42 | self.conv.add_module("conv5_2", nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1))
43 | self.conv.add_module("relu_5_2", torch.nn.ReLU())
44 | self.conv.add_module("conv5_3", nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1))
45 | self.conv.add_module("relu_5_3", torch.nn.ReLU())
46 | self.conv.add_module("maxpool_5", torch.nn.MaxPool2d(kernel_size=2))
47 |
48 | self.conv.add_module("fc6_conv", nn.Conv2d(512, 4096, kernel_size=7, stride=1, padding=0))
49 | self.conv.add_module("relu_6_1", torch.nn.ReLU())
50 |
51 | self.conv.add_module("fc7_conv", nn.Conv2d(4096, 4096, kernel_size=1, stride=1, padding=0))
52 | self.conv.add_module("relu_7_1", torch.nn.ReLU())
53 |
54 | self.conv.add_module("fc8_conv", nn.Conv2d(4096, 1000, kernel_size=1, stride=1, padding=0))
55 | self.conv.add_module("sigmoid_8", torch.nn.Sigmoid())
56 |
57 | self.pool_mil = nn.MaxPool2d(kernel_size=11, stride=0)
58 |
59 |         self.weight_init(opt)
60 | 
61 |     def weight_init(self, opt):
62 |         self.cnn_weight = getattr(opt, 'cnn_weight', 'model/vgg16_full_conv_mil.pth')  # e.g. model/mil.pth, as passed from test.py
63 |         self.conv.load_state_dict(torch.load(self.cnn_weight))
64 |         print("Load pretrained CNN model from " + self.cnn_weight)
65 |
66 |     def forward(self, x):
67 |         x0 = self.conv.forward(x.float())  # (N, 1000, 11, 11) per-region word probabilities
68 |         x = self.pool_mil(x0)              # global max over the 11x11 grid of regions
69 |         x = x.squeeze(2).squeeze(2)        # -> (N, 1000)
70 |         x1 = torch.add(torch.mul(x0.view(x.size(0), 1000, -1), -1), 1)  # 1 - p for each region
71 |         cumprod = torch.cumprod(x1, 2)     # running product of (1 - p)
72 |         out = torch.max(x, torch.add(torch.mul(cumprod[:, :, -1], -1), 1))  # max(max-pool, noisy-OR)
73 |         out = F.softmax(out)
74 |         return out
75 |
76 | class MIL_Precision_Score_Mapping(nn.Module):
77 | def __init__(self):
78 | super(MIL_Precision_Score_Mapping, self).__init__()
79 | self.mil = nn.MaxPool2d(kernel_size=11, stride=0)
80 |
81 |     def forward(self, x, score, precision, mil_prob):
82 |         out = self.mil(x)  # only the max-pooled map is used; score/precision/mil_prob are currently ignored
83 |         return out
84 |
85 |
--------------------------------------------------------------------------------
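The pooling in vgg_mil.forward combines a global max over the 11x11 response grid with a noisy-OR over region probabilities: a word is predicted present if any region fires. A standalone sketch of the same computation on random data, with shapes as assumed from the network above:

```python
# Sketch of the MIL pooling used in vgg_mil.forward, on a toy tensor.
import torch

N, V = 2, 1000                                 # batch size, vocabulary size
x0 = torch.rand(N, V, 11, 11)                  # per-region word probabilities
flat = x0.view(N, V, -1)

max_pool = flat.max(dim=2)[0]                  # strongest single region
noisy_or = 1 - torch.cumprod(1 - flat, dim=2)[:, :, -1]  # 1 - prod(1 - p_region)
out = torch.max(max_pool, noisy_or)            # as in forward, before the softmax
```
--------------------------------------------------------------------------------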
/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os
4 |
5 |
6 | def parse_opt():
7 | parser = argparse.ArgumentParser()
8 | # RL setting
9 | parser.add_argument('--model', type=str, default='vgg19')
10 | parser.add_argument('--learning', type=str, default='mle')
11 | parser.add_argument('--start_from', type=str, default='')
12 | # Data input settings
13 | parser.add_argument('--fc_feat_size', type=int, default=2048) # '2048 for resnet, 4096 for vgg'
14 | parser.add_argument('--att_feat_size', type=int, default=2048) # '2048 for resnet, 512 for vgg'
15 | # Optimization: General
16 | parser.add_argument('--max_epochs', type=int, default=-1) # 'number of epochs'
17 | parser.add_argument('--batch_size', type=int, default=2) # 'minibatch size'
18 | parser.add_argument('--seq_per_img', type=int,default=5) # number of captions to sample for each image during training.
19 | # Optimization: for the Language Model
20 | parser.add_argument('--optim', type=str, default='adam') # rmsprop|sgd|sgdmom|adagrad|adam
21 | parser.add_argument('--learning_rate', type=float, default=4e-4) # 'learning rate'
22 |     parser.add_argument('--learning_rate_decay_start', type=int,default=0) # at what epoch to start decaying the learning rate? (-1 = don't)
23 |     parser.add_argument('--learning_rate_decay_every', type=int,default=5000) # every how many epochs thereafter to decay the LR
24 |     parser.add_argument('--learning_rate_decay_rate', type=float,default=0.8) # multiplicative factor applied at each decay step
25 | parser.add_argument('--optim_alpha', type=float, default=0.8) # alpha for adam
26 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M')
27 | parser.add_argument('--optim_beta', type=float, default=0.999) # beta used for adam
28 | parser.add_argument('--optim_epsilon', type=float, default=1e-8) # epsilon that goes into denominator for smoothing
29 | parser.add_argument('--weight_decay', type=float, default=1e-4) # weight_decay
30 | parser.add_argument('--grad_clip', type=float, default=0.1) # clip gradients at this value
31 | parser.add_argument('--drop_prob_lm', type=float, default=0.5) # strength of dropout in the Language Model RNN
32 | # Datasets
33 | parser.add_argument('--input_json', type=str, default='data/mscoco/cocotalk_karpathy.json')
34 | parser.add_argument('--input_im_h5', type=str, default='data/mscoco/cocotalk_karpathy.h5')
35 | parser.add_argument('--input_label_h5', type=str, default='data/mscoco/cocotalk_karpathy_label_semantic_words.h5')
36 | # Evaluation/Checkpointing
37 | parser.add_argument('--split', type=str, default='train') # Dataset split type
38 | parser.add_argument('--val_images_use', type=int, default=5000) # number of images for period validation (-1 = all)
39 | parser.add_argument('--save_checkpoint_every', type=int,default=100)
40 |     parser.add_argument('--checkpoint_path', type=str, default='save') # directory to store checkpointed models
41 | parser.add_argument('--losses_log_every', type=int, default=25) # How often do we snapshot losses, (0 = disable)
42 | parser.add_argument('--load_best_score', type=int, default=1) # load previous best score when resuming training.
43 | # misc
44 | parser.add_argument('--n_gpus', type=int, default=1)
45 | parser.add_argument('--train_only', type=int, default=0) # If true then use 80k, else use 110k
46 | parser.add_argument('--gpus', default=[0, 1], nargs='+', type=int) # Use CUDA on the listed devices
47 | parser.add_argument('--model_id', type=str, default='') # Id identifying this run/job.
48 |     # used in cross-val and appended when writing progress files
49 |
50 | args = parser.parse_args()
51 | # Check if args are valid
52 | assert args.batch_size > 0, "batch_size should be greater than 0"
53 | assert args.drop_prob_lm >= 0 and args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
54 | assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
55 | assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
56 | assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
57 | assert args.load_best_score == 0 or args.load_best_score == 1, "should be 0 or 1"
58 |
59 | # Update args
60 | args.gpus = range(args.n_gpus)
61 | last_name = os.path.basename(args.start_from)
62 | last_time = last_name[0:8]
63 | if len(args.start_from):
64 | args.model_id = last_name
65 | else:
66 | args.model_id = datetime.datetime.now().strftime("%m%d%H%M") + "_mil_" + args.model + '_' + args.learning
67 | args.checkpoint_path = args.checkpoint_path + '/' + args.model_id
68 | return args
69 |
--------------------------------------------------------------------------------
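For reference, the run identifier and checkpoint directory that parse_opt derives for a fresh run (empty --start_from); the timestamp below is illustrative only:

```python
# Sketch: how model_id and checkpoint_path are derived for a fresh run.
import datetime

model, learning, checkpoint_path = 'vgg19', 'mle', 'save'
model_id = datetime.datetime.now().strftime("%m%d%H%M") + "_mil_" + model + '_' + learning
print(checkpoint_path + '/' + model_id)  # e.g. save/05081230_mil_vgg19_mle
```
--------------------------------------------------------------------------------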
/scripts/convert_tf2pth.sh:
--------------------------------------------------------------------------------
1 | #! /bin/sh
2 | # pip install mmdnn
3 | convert_tf_2_pytorch()
4 | {
5 | if [ ! -d tmp ]; then
6 | mkdir tmp
7 | fi
8 | # Extract tf model files
9 | python -m mmdnn.conversion.examples.tensorflow.extract_model -n resnet_v1_101 -ckpt /home/jxgu/github/MIL.pytorch/model/oidv2_resnet_v1_101/oidv2-resnet_v1_101.ckpt
10 | # Convert tf to IR
11 | python -m mmdnn.conversion._script.convertToIR -f tensorflow -d kit_imagenet -n imagenet_resnet_v1_101.ckpt.meta --dstNodeName Squeeze -w imagenet_resnet_v1_101.ckpt
12 | # Convert IR to Pytorch
13 | python -m mmdnn.conversion._script.IRToCode -f pytorch --IRModelPath kit_imagenet.pb --dstModelPath kit_imagenet.py --IRWeightPath kit_imagenet.npy -dw tmp/kit_pytorch.npy
14 | # Dump the PyTorch model
15 | python -m mmdnn.conversion.examples.pytorch.imagenet_test --dump resnet.pth -n kit_imagenet.py -w tmp/kit_pytorch.npy
16 | }
17 |
18 | convert_tf_2_pytorch
--------------------------------------------------------------------------------
/scripts/graphs/events.out.tfevents.1525752019.jxgu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/scripts/graphs/events.out.tfevents.1525752019.jxgu
--------------------------------------------------------------------------------
/scripts/graphs/events.out.tfevents.1525752072.jxgu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/scripts/graphs/events.out.tfevents.1525752072.jxgu
--------------------------------------------------------------------------------
/scripts/graphs/events.out.tfevents.1525752099.jxgu:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/scripts/graphs/events.out.tfevents.1525752099.jxgu
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """Set up paths."""
2 | import os
3 | import os.path as osp
4 | import sys
5 | import platform
6 | import cPickle
7 | import cv2, numpy as np
8 | from matplotlib.pyplot import show
9 | import matplotlib.pyplot as plt
10 | from scipy.interpolate import interp1d
11 | import time
12 | import Image
13 | import torch
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | import torch.nn.functional as F
17 | import torch.backends.cudnn as cudnn
18 | import tensorflow as tf
19 | import torchvision
20 | import torchvision.transforms as transforms
21 | from six.moves import cPickle
22 | import gc
23 | import os
24 | import pickle
25 | import argparse
26 | from model.models import *
27 | from model.utils import *
28 | import coco_voc
29 |
30 | this_dir = osp.dirname(__file__)
31 | ##############################################################################################
32 |
33 | def test_img(im, net, base_image_size, means):
34 |     """
35 |     Runs the network on one image and returns the MIL word probabilities
36 |     """
37 | batch_size = 1
38 | # Resize image
39 | im_orig = im.astype(np.float32, copy=True)
40 | im_orig -= means
41 |
42 | im, gr, grr = upsample_image(im_orig, base_image_size)
43 | im = np.transpose(im, axes=(2, 0, 1))
44 | im = im[np.newaxis, :, :, :]
45 |
46 | # Pass into model
47 | mil_prob = net(Variable(torch.from_numpy(im), requires_grad=False).cuda())
48 | return mil_prob
49 |
50 |
51 | def output_words_image(threshold_metric, output_metric, min_words, threshold, vocab, is_functional):
52 |     # Keep all words above threshold, topped up to min_words content (non-functional) words
53 |     ind_output = np.argsort(threshold_metric)[::-1]  # highest score first
54 |     must_keep1 = threshold_metric[ind_output] >= threshold
55 |     must_keep2 = np.cumsum(is_functional[ind_output]) < 1 + min_words
56 |     ind_output = [ind for j, ind in enumerate(ind_output) if must_keep1[j] or must_keep2[j]]
57 |     out = [(vocab['words'][ind], output_metric[ind], threshold_metric[ind]) for ind in ind_output]
58 |     return out
59 |
60 | ##############################################################################################
61 |
62 | '''load vocabulary'''
63 | vocab, functional_words, is_functional, pt = load_vocabulary()
64 |
65 | parser = argparse.ArgumentParser(description='PyTorch MIL Training')
66 | parser.add_argument('--start_from', type=str, default='')
67 | parser.add_argument('--load_precision_score', type=str, default='')
68 | parser.add_argument('--cnn_weight', default='model/mil.pth',
69 | help='cnn weights')
70 | opt = parser.parse_args()
71 |
72 | mil_model = vgg_mil(opt)
73 | mil_model.cuda()
74 | mil_model.eval()
75 |
76 |
77 | '''Caffe-style preprocessing constants'''
78 | mean = np.array([[[103.939, 116.779, 123.68]]])  # BGR channel means (cv2 images are BGR)
79 | base_image_size = 565
80 |
81 | '''Load the image'''
82 | imageurl = 'http://img1.10bestmedia.com/Images/Photos/333810/Montrose_54_990x660.jpg'
83 | im = url_to_image(imageurl)
84 | im = cv2.resize(im, (base_image_size, base_image_size), interpolation=cv2.INTER_CUBIC)
85 |
86 | # Run the model
87 | mil_prob = test_img(im, mil_model, base_image_size, mean)
88 | mil_prob = mil_prob.data.cpu().float().numpy()
89 | # Compute precision mapping - slow in per image mode, much faster in batch mode
90 | prec = np.zeros(mil_prob.shape)
91 | if len(opt.load_precision_score) > 0:
92 | precision_score = pickle.load(open(opt.load_precision_score, 'rb'))
93 | else:
94 | precision_score = compute_precision_mapping(pt)
95 |
96 | for jj in xrange(prec.shape[1]):
97 |     f = interp1d(precision_score['thresh'][jj][:,0], precision_score['prec'][jj][:,0])
98 |     #prec[:, jj] = f(mil_prob[:, jj])  # precision mapping currently disabled; raw MIL probabilities are reported
99 |     prec[:, jj] = mil_prob[:, jj]
100 | mil_prec = prec
101 |
102 | #cv2.imshow('image', im)
103 | # Output words
104 | out = output_words_image(mil_prec[0, :], mil_prec[0, :], \
105 | min_words=10, threshold=0.0, vocab=vocab, is_functional=is_functional)
106 |
107 | plt.rcParams['figure.figsize'] = (10, 10)
108 | plt.imshow(im[:, :, [2, 1, 0]])
109 | plt.gca().set_axis_off()
110 | det_atts = []
111 | det_atts_w = []
112 | index = 0
113 | for (a, b, c) in out:
114 | if a not in functional_words:
115 | if index < 10:
116 | det_atts.append(a)
117 | det_atts_w.append(np.round(b, 2))
118 | index = index + 1
119 | # print '{:s} [{:.2f}, {:.2f}] '.format(a, np.round(b,2), np.round(c,2))
120 |
121 | print(det_atts)
122 | print(det_atts_w)
123 |
--------------------------------------------------------------------------------
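test.py fetches its example image from a URL; to run the same pipeline on a local file, only the loading step changes. A sketch, assuming the checkpoints from the README are in model/ and test.py has been run up to the point where mil_model exists:

```python
# Sketch: local-image variant of the test.py pipeline.
import cv2
import numpy as np

base_image_size = 565
mean = np.array([[[103.939, 116.779, 123.68]]])  # BGR means, as in test.py

im = cv2.imread('example.jpg')                   # cv2 loads BGR, matching the mean above
im = cv2.resize(im, (base_image_size, base_image_size), interpolation=cv2.INTER_CUBIC)
mil_prob = test_img(im, mil_model, base_image_size, mean)  # test_img / mil_model from test.py
```
--------------------------------------------------------------------------------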
/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | clear
4 | #-----------------------------------------------------------------------------------------------------------------------
5 | func_cls()
6 | {
7 | CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split coco_val2014
8 | CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split coco_test2014
9 | CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split coco_train2014
10 | #CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split coco_test2015
11 | #CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split chinese
12 | #CUDA_VISIBLE_DEVICES=1 python extract_cls_oid.py --split chinese_val
13 | }
14 |
15 | func_det()
16 | { export PYTHONPATH=$PYTHONPATH:/home/jxgu/github/MIL.pytorch/misc/models/research/object_detection
17 | CUDA_VISIBLE_DEVICES=1 python extract_det_oid.py --split coco_val2014
18 | CUDA_VISIBLE_DEVICES=1 python extract_det_oid.py --split coco_test2014
19 | CUDA_VISIBLE_DEVICES=1 python extract_det_oid.py --split coco_train2014
20 | CUDA_VISIBLE_DEVICES=1 python extract_det_oid.py --split coco_test2015
21 | }
22 | func_cls
23 |
--------------------------------------------------------------------------------
/test_v2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # Copyright 2017 The Open Images Authors. All Rights Reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # ==============================================================================
17 | r"""Classifier inference utility.
18 |
19 | This code takes a resnet_v1_101 checkpoint, runs the classifier on the image and
20 | prints predictions in human-readable form.
21 |
22 | -------------------------------
23 | Example command:
24 | -------------------------------
25 |
26 | # 0. Create directory for model/data
27 | WORK_PATH="/tmp/oidv2"
28 | mkdir -p "${WORK_PATH}"
29 | cd "${WORK_PATH}"
30 |
31 | # 1. Download the model, inference code, and sample image
32 | wget https://storage.googleapis.com/openimages/2017_07/classes-trainable.txt
33 | wget https://storage.googleapis.com/openimages/2017_07/class-descriptions.csv
34 | wget https://storage.googleapis.com/openimages/2017_07/oidv2-resnet_v1_101.ckpt.tar.gz
35 | wget https://raw.githubusercontent.com/openimages/dataset/master/tools/classify_oidv2.py
36 | tar -xzf oidv2-resnet_v1_101.ckpt.tar.gz
37 |
38 | wget -O cat.jpg https://farm6.staticflickr.com/5470/9372235876_d7d69f1790_b.jpg
39 |
40 | # 2. Run inference
41 | python classify_oidv2.py \
42 | --checkpoint_path='oidv2-resnet_v1_101.ckpt' \
43 | --labelmap='classes-trainable.txt' \
44 | --dict='class-descriptions.csv' \
45 | --image="cat.jpg" \
46 | --top_k=10 \
47 | --score_threshold=0.3
48 |
49 | # Sample output:
50 | Image: "cat.jpg"
51 |
52 | 3272: /m/068hy - Pet (score = 0.96)
53 | 1076: /m/01yrx - Cat (score = 0.95)
54 | 0708: /m/01l7qd - Whiskers (score = 0.90)
55 | 4755: /m/0jbk - Animal (score = 0.90)
56 | 2847: /m/04rky - Mammal (score = 0.89)
57 | 2036: /m/0307l - Felidae (score = 0.79)
58 | 3574: /m/07k6w8 - Small to medium-sized cats (score = 0.77)
59 | 4799: /m/0k0pj - Nose (score = 0.70)
60 | 1495: /m/02cqfm - Close-up (score = 0.55)
61 | 0036: /m/012c9l - Domestic short-haired cat (score = 0.40)
62 |
63 | -------------------------------
64 | Note on image preprocessing:
65 | -------------------------------
66 |
67 | This is the code used to perform preprocessing:
68 | --------
69 | from preprocessing import preprocessing_factory
70 |
71 | def PreprocessImage(image, network='resnet_v1_101', image_size=299):
72 | # If resolution is larger than 224 we need to adjust some internal resizing
73 | # parameters for vgg preprocessing.
74 | if any(network.startswith(x) for x in ['resnet', 'vgg']):
75 | preprocessing_kwargs = {
76 | 'resize_side_min': int(256 * image_size / 224),
77 | 'resize_side_max': int(512 * image_size / 224)
78 | }
79 | else:
80 | preprocessing_kwargs = {}
81 | preprocessing_fn = preprocessing_factory.get_preprocessing(
82 | name=network, is_training=False)
83 |
84 | height = image_size
85 | width = image_size
86 | image = preprocessing_fn(image, height, width, **preprocessing_kwargs)
87 | image.set_shape([height, width, 3])
88 | return image
89 | --------
90 |
91 | Note that there appears to be a small difference between the public version
92 | of the slim image processing library and the internal version (which the meta
93 | graph is based on). This yields results that are very close, but not exactly
94 | identical, to those of the metagraph.
95 | """
96 |
97 | from __future__ import absolute_import
98 | from __future__ import division
99 | from __future__ import print_function
100 |
101 | import tensorflow as tf
102 |
103 | flags = tf.app.flags
104 | FLAGS = flags.FLAGS
105 |
106 | flags.DEFINE_string('labelmap', 'classes-trainable.txt',
107 | 'Labels, one per line.')
108 |
109 | flags.DEFINE_string('dict', 'class-descriptions.csv',
110 | 'Descriptive string for each label.')
111 |
112 | flags.DEFINE_string('checkpoint_path', 'oidv2-resnet_v1_101.ckpt',
113 | 'Path to checkpoint file.')
114 |
115 | flags.DEFINE_string('image', '',
116 | 'Comma separated paths to image files on which to perform '
117 | 'inference.')
118 |
119 | flags.DEFINE_integer('top_k', 10, 'Maximum number of results to show.')
120 |
121 | flags.DEFINE_float('score_threshold', None, 'Score threshold.')
122 |
123 |
124 | def LoadLabelMap(labelmap_path, dict_path):
125 | """Load index->mid and mid->display name maps.
126 |
127 | Args:
128 | labelmap_path: path to the file with the list of mids, describing
129 | predictions.
130 | dict_path: path to the dict.csv that translates from mids to display names.
131 | Returns:
132 | labelmap: an index to mid list
133 | label_dict: mid to display name dictionary
134 | """
135 | labelmap = [line.rstrip() for line in tf.gfile.GFile(labelmap_path)]
136 |
137 | label_dict = {}
138 | for line in tf.gfile.GFile(dict_path):
139 | words = [word.strip(' "\n') for word in line.split(',', 1)]
140 | label_dict[words[0]] = words[1]
141 |
142 | return labelmap, label_dict
143 |
144 |
145 | def main(_):
146 | # Load labelmap and dictionary from disk.
147 | labelmap, label_dict = LoadLabelMap(FLAGS.labelmap, FLAGS.dict)
148 |
149 | g = tf.Graph()
150 | with g.as_default():
151 | with tf.Session() as sess:
152 | saver = tf.train.import_meta_graph(FLAGS.checkpoint_path + '.meta')
153 | saver.restore(sess, FLAGS.checkpoint_path)
154 |
155 | input_values = g.get_tensor_by_name('input_values:0')
156 | predictions = g.get_tensor_by_name('multi_predictions:0')
157 |
158 | for image_filename in FLAGS.image.split(','):
159 | compressed_image = tf.gfile.FastGFile(image_filename, 'rb').read()
160 | predictions_eval = sess.run(
161 | predictions, feed_dict={
162 | input_values: [compressed_image]
163 | })
164 | top_k = predictions_eval.argsort()[::-1] # indices sorted by score
165 | if FLAGS.top_k > 0:
166 | top_k = top_k[:FLAGS.top_k]
167 | if FLAGS.score_threshold is not None:
168 | top_k = [i for i in top_k
169 | if predictions_eval[i] >= FLAGS.score_threshold]
170 | print('Image: "%s"\n' % image_filename)
171 | for idx in top_k:
172 | mid = labelmap[idx]
173 | display_name = label_dict[mid]
174 | score = predictions_eval[idx]
175 | print('{:04d}: {} - {} (score = {:.2f})'.format(
176 | idx, mid, display_name, score))
177 |
178 |
179 | if __name__ == '__main__':
180 | tf.app.run()
--------------------------------------------------------------------------------
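For the default image_size=299, the resize parameters from the preprocessing note in the docstring above work out as follows:

```python
# The vgg/resnet preprocessing kwargs from the note above, for image_size=299.
image_size = 299
resize_side_min = int(256 * image_size / 224)  # = 341
resize_side_max = int(512 * image_size / 224)  # = 683
```
--------------------------------------------------------------------------------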
/train.py:
--------------------------------------------------------------------------------
1 | # Use tensorboard
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import functools
8 | import os
9 | import time
10 | from six.moves import cPickle
11 | from dataloader import *
12 | from model import *
13 | import tensorflow as tf
14 | import torch.nn as nn
15 | import torch.utils.model_zoo as model_zoo
16 | import numpy as np
17 | import math
18 | import torch
19 | import torch.nn.init as init
20 | import torch.nn.functional as F
21 | from torch.autograd import Variable
22 | import opts
23 | from model import eval_utils
24 | from model import utils
25 | from model import models
26 |
27 | rusage_denom = 1024
28 | printf = functools.partial(print, end="")
29 |
30 | def extract_fts(opt, data):
31 | images = Variable(torch.from_numpy(data['images']), volatile=False).cuda()
32 | mil_label = Variable(torch.from_numpy(data['mil_label']),volatile=False).cuda()
33 | return images, mil_label
34 |
35 | def record_training(opt, model, iteration, tf_summary_writer, current_record, history_record):
36 | [train_loss] = current_record
37 | [loss_history, lr_history] = history_record
38 | utils.add_summary_value(tf_summary_writer, 'train_lr', opt.learning_rate, iteration)
39 | utils.add_summary_value(tf_summary_writer, 'train_weight_decay', opt.weight_decay, iteration)
40 | utils.add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
41 | utils.add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
42 | tf_summary_writer.flush()
43 | loss_history[iteration] = train_loss
44 | lr_history[iteration] = opt.current_lr
45 | return history_record
46 |
47 | def record_ckpt(opt, infos, model, optimizer, best_flag):
48 | tag = '-best' if best_flag else ''
49 |     print("Saving model and optimizer checkpoints ...")
50 | checkpoint_path = os.path.join(opt.checkpoint_path, opt.model_id + '.model' + tag + '.pth')
51 | optimizer_path = os.path.join(opt.checkpoint_path, opt.model_id + '.optimizer' + tag + '.pth')
52 | torch.save(model.state_dict(), checkpoint_path)
53 | torch.save(optimizer.state_dict(), optimizer_path)
54 |     print("Saving infos ...")
55 | with open(os.path.join(opt.checkpoint_path, opt.model_id + '.infos' + tag + '.pkl'), 'wb') as f:
56 | cPickle.dump(infos, f)
57 | print("model saved to {}".format(checkpoint_path))
58 |
59 | def train(opt):
60 |     print("Load dataset with image features and labels\n")
61 | loader = DataLoader(opt)
62 | opt.vocab_size = loader.vocab_size
63 | opt.seq_length = loader.seq_length
64 |
65 | tf_summary_writer = tf.summary.FileWriter(opt.checkpoint_path)
66 |     print("Load information from infos.pkl ...")
67 | opt, infos, iteration, epoch, val_history, train_history = utils.history_infos(opt)
68 | [loss_history, lr_history] = train_history
69 | [val_result_history, best_val_score, val_loss] = val_history
70 |
71 | # Update dataloader info
72 | loader.iterators = infos.get('iterators', loader.iterators)
73 | loader.split_ix = infos.get('split_ix', loader.split_ix)
74 |
75 | print("Build image cnn model, and initialize it with pre-trained cnn model")
76 | model, crit, optimizer, infos = models.build_models(opt, infos)
77 |
78 | update_lr_flag = True
79 | while True:
80 | gc.collect() # collect cpu memory
81 | if update_lr_flag:
82 | utils.paramsForEpoch(opt, epoch, optimizer)
83 | utils.update_lr(opt, epoch, optimizer) # Assign the learning rate
84 |
85 | data = loader.get_batch('train') # Load data from train split (0)
86 | torch.cuda.synchronize()
87 | start = time.time()
88 |
89 | images, mil_label = extract_fts(opt, data)
90 | optimizer.zero_grad()
91 | crit_outputs = crit(model(images), mil_label)
92 | loss = crit_outputs[0]
93 | loss.backward()
94 | utils.clip_gradient(optimizer, opt.grad_clip)
95 | optimizer.step()
96 | torch.cuda.synchronize()
97 | train_loss = loss.data[0]
98 |
99 | last_name = os.path.basename(opt.model_id)
100 | last_time = last_name[0:8]
101 | print(
102 | "{}/{},{}/{},loss(t|{:.4f},v|{:.4f})|T/B({:.2f})" \
103 | .format(opt.model+'.'+last_time, iteration, epoch, opt.batch_size,
104 | train_loss, val_loss,
105 | time.time() - start))
106 |
107 | # Update the iteration and epoch
108 | iteration += 1
109 | if data['bounds']['wrapped']:
110 | epoch += 1
111 | update_lr_flag = True
112 |
113 | # Write the training loss summary
114 | if (iteration % opt.losses_log_every == 0):
115 | current_record = [train_loss]
116 | history_record = [loss_history, lr_history]
117 | history_record = record_training(opt, model, iteration, tf_summary_writer, current_record, history_record)
118 | [loss_history, lr_history] = history_record
119 |
120 |         # Evaluate on the validation set, and save the model
121 | if (iteration % opt.save_checkpoint_every == 0):
122 |             eval_kwargs = vars(opt).copy()
123 |             eval_kwargs['split'] = 'test'
124 |             eval_kwargs['dataset'] = opt.input_json
126 | val_loss = eval_utils.eval_split(opt, model, crit, loader, eval_kwargs)
127 |
128 | utils.add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
129 | tf_summary_writer.flush()
130 | val_result_history[iteration] = {'loss': val_loss}
131 |
132 | # Save model if is improving on validation result
133 |             current_score = - val_loss  # negate so that a higher score means a lower validation loss
134 |             best_flag = False
135 |             if True:  # always checkpoint; the -best tag tracks the best validation score
136 |                 if best_val_score is None or current_score > best_val_score:
137 |                     best_val_score = current_score
138 |                     best_flag = True
139 |                 # Dump miscellaneous information
140 | infos['iter'] = iteration
141 | infos['epoch'] = epoch
142 | infos['iterators'] = loader.iterators
143 | infos['split_ix'] = loader.split_ix
144 | infos['best_val_score'] = best_val_score
145 | infos['opt'] = opt
146 | infos['val_result_history'] = val_result_history
147 | infos['loss_history'] = loss_history
148 | infos['lr_history'] = lr_history
149 | infos['vocab'] = loader.get_vocab()
150 | # Dump checkpoint
151 | record_ckpt(opt, infos, model, optimizer, best_flag)
152 | # Stop if reaching max epochs
153 | if epoch >= opt.max_epochs and opt.max_epochs != -1:
154 | break
155 |
156 |
157 | '''
158 | Entry point: parse options and start training
159 | '''
160 | opt = opts.parse_opt()
161 | train(opt)
162 |
--------------------------------------------------------------------------------
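One piece not shown in this listing: crit comes out of models.build_models in model/models.py. For multi-label word prediction over the 1000-word vocabulary, an element-wise binary cross-entropy is the natural fit; the sketch below is a hypothetical stand-in, not the repository's actual criterion.

```python
# Hypothetical criterion sketch; the real one is built in model/models.py.
import torch.nn.functional as F

def mil_bce(mil_prob, mil_label):
    # mil_prob:  (batch, 1000) pooled word probabilities from the MIL model
    # mil_label: (batch, 1000) binary indicators of words in the reference captions
    return F.binary_cross_entropy(mil_prob, mil_label)
```
--------------------------------------------------------------------------------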
/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | clear
4 |
5 | case "$1" in
6 | 0)
7 | echo "run resnet 101 debug"
8 | CUDA_VISIBLE_DEVICES=0,1 python train.py --model 'resnet101' --n_gpus=2 --batch_size 50 --optim='sgd' --learning_rate_decay_start=0
9 | ;;
10 |
11 | 1)
12 | echo "run vgg19 debug"
13 | CUDA_VISIBLE_DEVICES=0,1 python train.py --model 'vgg19' --n_gpus=2 --batch_size 10 --optim='sgd' --learning_rate_decay_start=0
14 | ;;
15 |
16 | 2)
17 | echo "run resnet101 debug"
18 | CUDA_VISIBLE_DEVICES=1 python train.py --model 'resnet101' --batch_size 10 --learning_rate 1e-4
19 | ;;
20 |
21 | *)
22 | echo
23 | echo "No input"
24 | ;;
25 | esac
26 |
27 |
--------------------------------------------------------------------------------
/vocabs/vocab_train.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/vocabs/vocab_train.pkl
--------------------------------------------------------------------------------
/vocabs/vocab_words.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gujiuxiang/MIL.pytorch/d0d223b92097532b4c906e2a10113507eff18cac/vocabs/vocab_words.txt
--------------------------------------------------------------------------------