├── LICENSE ├── README.md ├── dataset ├── process_CUHK_data.py └── utils │ └── read_write_data.py ├── run ├── test.bash └── train.bash └── src ├── data ├── dataloader.py └── dataset.py ├── loss ├── Id_loss.py └── RankingLoss.py ├── model └── model.py ├── option └── options.py ├── test.py ├── test_during_train.py ├── train.py ├── transforms ├── __init__.py ├── functional.py └── transforms.py └── utils └── read_write_data.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Suo-Wei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![LICENSE](https://img.shields.io/badge/license-MIT-green)](https://github.com/taksau/GPS-Net/blob/master/LICENSE) 2 | [![Python](https://img.shields.io/badge/python-3.7-blue.svg)](https://www.python.org/) 3 | ![PyTorch](https://img.shields.io/badge/pytorch-1.4.0-%237732a8) 4 | 5 | # A Simple and Robust Correlation Filtering method for text-based person search 6 | We provide the code for reproducing results of our ECCV 2022 paper [A Simple and Robust Correlation Filtering method for text-based person search](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136950719.pdf). 7 | Compared with the original paper, we obtain better performance due to some modifications. Following our global response map, we also add the same mutual-exclusion-loss to separate body part response map. Meanwhile, we merge the global filter and dictionary filter module. The adjusted method achieves a new state-of-the-art performance and it improves to 64.88 on Top-1 *without* [Re-Rank](https://github.com/TencentYoutuResearch/PersonReID-NAFS?utm_source=catalyzex.com) (CUHK-PEDES). 8 | ## Getting Started 9 | ### Requirements 10 | - [PyTorch](https://pytorch.org/) 1.4 or higher 11 | - [transformers](https://huggingface.co/docs/transformers/index) (install with `pip install transformers`) 12 | - numpy, torchvision 13 | 14 | ### Dataset Preparation 15 | 16 | Organize them in `dataset` folder as follows: 17 | 18 | ~~~ 19 | |-- dataset/ 20 | | |-- / 21 | | |-- imgs 22 | |-- cam_a 23 | |-- cam_b 24 | |-- ... 
25 | | |-- reid_raw.json 26 | 27 | ~~~ 28 | Download the CUHK-PEDES dataset from [here](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description) and then run the `process_CUHK_data.py` as follow: 29 | ~~~ 30 | cd SRCF 31 | python ./dataset/process_CUHK_data.py 32 | ~~~ 33 | 34 | ### Building BERT 35 | ~~~ 36 | mkdir bert_weight 37 | ~~~ 38 | 39 | Downland the [weight and config](https://huggingface.co/bert-base-uncased/tree/main), put them into SRCF/bert_weight 40 | 41 | ### Training and Testing 42 | ~~~ 43 | bash run/train.bash 44 | ~~~ 45 | ### Evaluation 46 | ~~~ 47 | bash run/test.bash 48 | ~~~ 49 | 50 | ## Results on CUHK-PEDES 51 | 52 | |CUHK-PEDES | performance | 53 | |------|------| 54 | | `Top-1` | 64.88 | 55 | | `Top-5` | 83.02 | 56 | | `Top-10` | 88.56 | 57 | 58 | ## Citation 59 | 60 | If this work is helpful for your research, please cite our work: 61 | 62 | ~~~ 63 | @InProceedings{Suo_ECCV_A, 64 | author = {Suo, Wei and Sun, MengYang and Niu, Kai, et.al}, 65 | title = {A Simple and Robust Correlation Filtering method for text-based person search}, 66 | booktitle = {The European Conference on Computer Vision (ECCV)}, 67 | month = {August}, 68 | year = {2022} 69 | } 70 | ~~~ 71 | 72 | ### References 73 | [SSAN](https://github.com/zifyloo/SSAN/) 74 | -------------------------------------------------------------------------------- /dataset/process_CUHK_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | processes the CUHK_PEDES/reid_raw.json, output the train_data, val_data, test_data, 4 | 5 | For example: 6 | train_data.pkl contains the data in dict format 7 | 'id': id for the caption-image pair 8 | 'img_path': the image in the caption-image pair 9 | 'same_id_index': the id number in the dict of captions of other images from same id 10 | 'lstm_caption_id': the code of per caption for bi-lstm 11 | 'captions': the caption in the caption-image pair 12 | 13 | 14 | @author: zifyloo 15 | """ 16 | 17 | from utils.read_write_data import read_json, makedir, save_dict, write_txt 18 | import argparse 19 | import os 20 | import numpy as np 21 | 22 | 23 | class Word2Index(object): 24 | 25 | def __init__(self, vocab): 26 | self._vocab = {w: index + 1 for index, w in enumerate(vocab)} 27 | self.unk_id = len(vocab) + 1 28 | 29 | def __call__(self, word): 30 | if word not in self._vocab: 31 | return self.unk_id 32 | return self._vocab[word] 33 | 34 | 35 | def parse_args(): 36 | parser = argparse.ArgumentParser(description='Command for data pre_processing') 37 | parser.add_argument('--json_root', default='CUHK-PEDES/reid_raw.json', type=str) 38 | parser.add_argument('--out_root', default='CUHK-PEDES/processed_data', type=str) 39 | parser.add_argument('--min_word_count', default='2', type=int) 40 | args = parser.parse_args() 41 | return args 42 | 43 | 44 | def split_json(args): 45 | """ 46 | has 40206 image in reid_raw_data 47 | has 13003 id 48 | every id has several images and every image has several caption 49 | data's structure in reid_raw_data is dict ['split', 'captions', 'file_path', 'processed_tokens', 'id'] 50 | """ 51 | reid_raw_data = read_json(args.json_root) 52 | 53 | train_json = [] 54 | test_json = [] 55 | val_json = [] 56 | for data in reid_raw_data: 57 | data_save = { 58 | 'img_path': 'imgs/'+data['file_path'], 59 | 'id': data['id'], 60 | 'tokens': data['processed_tokens'], 61 | 'captions': data['captions'] 62 | } 63 | 64 | split = data['split'].lower() 65 | if split == 'train': 66 | 
train_json.append(data_save) 67 | elif split == 'test': 68 | test_json.append(data_save) 69 | else: 70 | val_json.append(data_save) 71 | return train_json, test_json, val_json 72 | 73 | 74 | def build_vocabulary(train_json, args): 75 | 76 | word_count = {} 77 | for data in train_json: 78 | for caption in data['tokens']: 79 | for word in caption: 80 | word_count[word.lower()] = word_count.get(word.lower(), 0) + 1 81 | 82 | word_count_list = [[v, k] for v, k in word_count.items()] 83 | word_count_list.sort(key=lambda x: x[1], reverse=True) # from high to low 84 | 85 | good_vocab = [v for v, k in word_count.items() if k >= args.min_word_count] 86 | 87 | print('top-10 highest frequency words:') 88 | for w, n in word_count_list[0:10]: 89 | print(w, n) 90 | 91 | good_count = len(good_vocab) 92 | all_count = len(word_count_list) 93 | good_word_rate = good_count * 100.0 / all_count 94 | st = 'good words: %d, total_words: %d, good_word_rate: %f%%' % (good_count, all_count, good_word_rate) 95 | write_txt(st, os.path.join(args.out_root, 'data_message')) 96 | print(st) 97 | word2Ind = Word2Index(good_vocab) 98 | 99 | save_dict(good_vocab, os.path.join(args.out_root, 'ind2word')) 100 | return word2Ind 101 | 102 | 103 | def generate_captionid(data_json, word2Ind, data_name, args): 104 | 105 | id_save = [] 106 | lstm_caption_id_save = [] 107 | img_path_save = [] 108 | caption_save = [] 109 | same_id_index_save = [] 110 | un_idx = word2Ind.unk_id 111 | data_save_by_id = {} 112 | for data in data_json: 113 | 114 | if data['id'] in [1369, 4116, 6116]: # CR need two images for per id at least, these ids have only one image, 115 | continue 116 | if data['id'] > 6116: 117 | id_new = data['id'] - 4 118 | elif data['id'] > 4116: 119 | id_new = data['id'] - 3 120 | elif data['id'] > 1369: 121 | id_new = data['id'] - 2 122 | else: 123 | id_new = data['id'] - 1 124 | 125 | data_save_i = { 126 | 'img_path': data['img_path'], 127 | 'id': id_new, 128 | 'tokens': data['tokens'], 129 | 'captions': data['captions'] 130 | } 131 | if id_new not in data_save_by_id.keys(): 132 | data_save_by_id[id_new] = [] 133 | 134 | data_save_by_id[id_new].append(data_save_i) 135 | 136 | data_order = 0 137 | for id_new, data_save_by_id_i in data_save_by_id.items(): 138 | 139 | caption_length = 0 140 | for data_save_by_id_i_i in data_save_by_id_i: 141 | caption_length += len(data_save_by_id_i_i['captions']) 142 | 143 | data_order_i = data_order + np.arange(caption_length) 144 | data_order_i_begin = 0 145 | 146 | for data_save_by_id_i_i in data_save_by_id_i: 147 | caption_length_i = len(data_save_by_id_i_i['captions']) 148 | data_order_i_end = data_order_i_begin + caption_length_i 149 | data_order_i_select = np.delete(data_order_i, np.arange(data_order_i_begin, data_order_i_end)) 150 | data_order_i_begin = data_order_i_end 151 | 152 | for j in range(len(data_save_by_id_i_i['tokens'])): 153 | tokens_j = data_save_by_id_i_i['tokens'][j] 154 | lstm_caption_id = [] 155 | for word in tokens_j: 156 | lstm_caption_id.append(word2Ind(word)) 157 | if un_idx in lstm_caption_id: 158 | lstm_caption_id = list(filter(lambda x: x != un_idx, lstm_caption_id)) 159 | 160 | caption_j = data_save_by_id_i_i['captions'][j] 161 | 162 | id_save.append(data_save_by_id_i_i['id']) 163 | img_path_save.append(data_save_by_id_i_i['img_path']) 164 | same_id_index_save.append(data_order_i_select) 165 | 166 | lstm_caption_id_save.append(lstm_caption_id) 167 | caption_save.append(caption_j) 168 | 169 | data_order = data_order + caption_length 170 | 171 | data_save = { 172 | 
'id': id_save, 173 | 'img_path': img_path_save, 174 | 'same_id_index': same_id_index_save, 175 | 176 | 'lstm_caption_id': lstm_caption_id_save, 177 | 'captions': caption_save, 178 | } 179 | 180 | img_num = len(set(img_path_save)) 181 | id_num = len(set(id_save)) 182 | caption_num = len(lstm_caption_id_save) 183 | 184 | st = '%s_img_num: %d, %s_id_num: %d, %s_caption_num: %d, ' \ 185 | % (data_name, img_num, data_name, id_num, data_name, caption_num) 186 | write_txt(st, os.path.join(args.out_root, 'data_message')) 187 | 188 | return data_save 189 | 190 | 191 | def generate_test_val_caption_id(data_json, word2Ind, data_name, args): 192 | id_save = [] 193 | lstm_caption_id_save = [] 194 | caption_save = [] 195 | img_path_save = [] 196 | img_caption_index_save = [] 197 | caption_matching_img_index_save = [] 198 | caption_label_save = [] 199 | 200 | un_idx = word2Ind.unk_id 201 | 202 | img_caption_index_i = 0 203 | caption_matching_img_index_i = 0 204 | for data in data_json: 205 | id_save.append(data['id']) 206 | img_path_save.append(data['img_path']) 207 | 208 | for j in range(len(data['tokens'])): 209 | 210 | tokens_j = data['tokens'][j] 211 | lstm_caption_id = [] 212 | for word in tokens_j: 213 | lstm_caption_id.append(word2Ind(word)) 214 | if un_idx in lstm_caption_id: 215 | lstm_caption_id = list(filter(lambda x: x != un_idx, lstm_caption_id)) 216 | 217 | caption_j = data['captions'][j] 218 | 219 | caption_matching_img_index_save.append(caption_matching_img_index_i) 220 | lstm_caption_id_save.append(lstm_caption_id) 221 | caption_save.append(caption_j) 222 | 223 | caption_label_save.append(data['id']) 224 | img_caption_index_save.append([img_caption_index_i, img_caption_index_i+len(data['captions'])-1]) 225 | img_caption_index_i += len(data['captions']) 226 | caption_matching_img_index_i += 1 227 | 228 | data_save = { 229 | 'id': id_save, 230 | 'img_path': img_path_save, 231 | 'img_caption_index': img_caption_index_save, 232 | 233 | 'caption_matching_img_index': caption_matching_img_index_save, 234 | 'caption_label': caption_label_save, 235 | 'lstm_caption_id': lstm_caption_id_save, 236 | 'captions': caption_save, 237 | } 238 | 239 | img_num = len(set(img_path_save)) 240 | id_num = len(set(id_save)) 241 | caption_num = len(lstm_caption_id_save) 242 | st = '%s_img_num: %d, %s_id_num: %d, %s_caption_num: %d, ' % ( 243 | data_name, img_num, data_name, id_num, data_name, caption_num) 244 | write_txt(st, os.path.join(args.out_root, 'data_message')) 245 | 246 | return data_save 247 | 248 | 249 | def main(args): 250 | train_json, test_json, val_json = split_json(args) 251 | 252 | word2Ind = build_vocabulary(train_json, args) 253 | 254 | train_save = generate_captionid(train_json, word2Ind, 'train', args) 255 | test_save = generate_test_val_caption_id(test_json, word2Ind, 'test', args) 256 | val_save = generate_captionid(val_json, word2Ind, 'val', args) 257 | 258 | save_dict(train_save, os.path.join(args.out_root, 'train_save')) 259 | save_dict(test_save, os.path.join(args.out_root, 'test_save')) 260 | save_dict(val_save, os.path.join(args.out_root, 'val_save')) 261 | 262 | 263 | if __name__ == '__main__': 264 | 265 | args = parse_args() 266 | 267 | makedir(args.out_root) 268 | main(args) 269 | -------------------------------------------------------------------------------- /dataset/utils/read_write_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | the tool to read or write the data. Have a good luck! 
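A usage sketch (illustrative only; the path assumes the pickle files written by
dataset/process_CUHK_data.py under its default --out_root):

    train_save = read_dict('CUHK-PEDES/processed_data/train_save.pkl')
    write_txt('loaded %d captions' % len(train_save['captions']),
              'CUHK-PEDES/processed_data/data_message')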
4 | 5 | Created on Thurs., Aug. 1(st), 2019 at 20:15. 6 | 7 | Updated on Thurs., Aug. 9(th), 2021 at 11:12 8 | 9 | @author: zifyloo 10 | """ 11 | 12 | import os 13 | import json 14 | import pickle 15 | 16 | 17 | def makedir(root): 18 | if not os.path.exists(root): 19 | os.makedirs(root) 20 | 21 | 22 | def write_json(data, root): 23 | with open(root, 'w') as f: 24 | json.dump(data, f) 25 | 26 | 27 | def read_json(root): 28 | with open(root, 'r') as f: 29 | data = json.load(f) 30 | 31 | return data 32 | 33 | 34 | def read_dict(root): 35 | with open(root, 'rb') as f: 36 | data = pickle.load(f) 37 | 38 | return data 39 | 40 | 41 | def save_dict(data, name): 42 | with open(name + '.pkl', 'wb') as f: 43 | pickle.dump(data, f, pickle.HIGHEST_PROTOCOL) 44 | 45 | 46 | def write_txt(data, name): 47 | with open(name, 'a') as f: 48 | f.write(data) 49 | f.write('\n') 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /run/test.bash: -------------------------------------------------------------------------------- 1 | python src/test.py --model_name 'SRCF' \ 2 | --GPU_id 0 \ 3 | --part 6 \ 4 | --lr 0.001 \ 5 | --dataset 'CUHK-PEDES' \ 6 | --dataroot '../dataset/CUHK-PEDES/' \ 7 | --vocab_size 5000 \ 8 | --feature_length 1024 \ 9 | --mode 'test' 10 | -------------------------------------------------------------------------------- /run/train.bash: -------------------------------------------------------------------------------- 1 | save_name=SRCF 2 | gpu=0 3 | CUDA_VISIBLE_DEVICES=$gpu nohup python -u src/train.py \ 4 | --model_name $save_name \ 5 | --part 6 \ 6 | --lr 0.0005 \ 7 | --dataset CUHK-PEDES \ 8 | --epoch 60 \ 9 | --dataroot dataset/CUHK-PEDES/ \ 10 | --class_num 11000 \ 11 | --vocab_size 5000 \ 12 | --feature_length 1024 \ 13 | --mode train \ 14 | --batch_size 32 \ 15 | --cr_beta 0.1 16 | 17 | -------------------------------------------------------------------------------- /src/data/dataloader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | from torchvision import transforms 7 | from PIL import Image 8 | import torch 9 | from data.dataset import CUHKPEDEDataset, CUHKPEDE_img_dateset, CUHKPEDE_txt_dateset 10 | 11 | 12 | def get_dataloader(opt): 13 | """ 14 | tranforms the image, downloads the image with the id by data.DataLoader 15 | """ 16 | 17 | if opt.mode == 'train': 18 | transform_list = [ 19 | transforms.RandomHorizontalFlip(), 20 | transforms.Resize((384, 128), Image.BICUBIC), # interpolation 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.5, 0.5, 0.5), 23 | (0.5, 0.5, 0.5))] 24 | tran = transforms.Compose(transform_list) 25 | 26 | dataset = CUHKPEDEDataset(opt,tran) 27 | 28 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, 29 | shuffle=True, drop_last=True, num_workers=12) 30 | print('{}-{} has {} pohtos'.format(opt.dataset, opt.mode, len(dataset))) 31 | 32 | return dataloader 33 | 34 | else: 35 | tran = transforms.Compose([ 36 | transforms.Resize((384, 128), Image.BICUBIC), # interpolation 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.5, 0.5, 0.5), 39 | (0.5, 0.5, 0.5))] 40 | ) 41 | 42 | img_dataset = CUHKPEDE_img_dateset(opt,tran) 43 | 44 | img_dataloader = torch.utils.data.DataLoader(img_dataset, batch_size=opt.batch_size, 45 | shuffle=False, drop_last=False, num_workers=12) 46 | 47 | txt_dataset = CUHKPEDE_txt_dateset(opt) 48 | 49 | txt_dataloader = 
torch.utils.data.DataLoader(txt_dataset, batch_size=opt.batch_size, 50 | shuffle=False, drop_last=False, num_workers=12) 51 | 52 | print('{}-{} has {} pohtos, {} text'.format(opt.dataset, opt.mode, len(img_dataset), len(txt_dataset))) 53 | 54 | return img_dataloader, txt_dataloader 55 | 56 | def collate_fn4train(batch): 57 | imgs = [] 58 | label = [] 59 | label_swap = [] 60 | law_swap = [] 61 | caption = [] 62 | caption_mask = [] 63 | caption_cr = [] 64 | caption_cr_mask =[] 65 | 66 | for sample in batch: 67 | imgs.append(sample[0]) 68 | imgs.append(sample[1]) 69 | label.append(sample[2]) 70 | label.append(sample[2]) 71 | 72 | label_swap.append(sample[2]) 73 | label_swap.append(sample[3]) 74 | 75 | law_swap.append(sample[4]) 76 | law_swap.append(sample[5]) 77 | # img_name.append(sample[-1]) 78 | caption.append(sample[6]) 79 | caption_mask.append(sample[7]) 80 | caption_cr.append(sample[8]) 81 | caption_cr_mask.append(sample[9]) 82 | return torch.stack(imgs, 0), label, label_swap, law_swap, caption,caption_mask,caption_cr,caption_cr_mask -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | import torch 7 | import torch.utils.data as data 8 | import numpy as np 9 | from PIL import Image 10 | import os 11 | from utils.read_write_data import read_dict 12 | from transforms import transforms 13 | import cv2 14 | # import torchvision.transforms.functional as F 15 | import random 16 | import re 17 | from pytorch_pretrained_bert.tokenization import BertTokenizer 18 | from PIL import ImageStat 19 | import copy 20 | 21 | def fliplr(img, dim): 22 | """ 23 | flip horizontal 24 | :param img: 25 | :return: 26 | """ 27 | inv_idx = torch.arange(img.size(dim) - 1, -1, -1).long() # N x C x H x W 28 | img_flip = img.index_select(dim, inv_idx) 29 | return img_flip 30 | 31 | def read_examples(input_line, unique_id): 32 | """Read a list of `InputExample`s from an input file.""" 33 | examples = [] 34 | # unique_id = 0 35 | line = input_line #reader.readline() 36 | # if not line: 37 | # break 38 | line = line.strip() 39 | text_a = None 40 | text_b = None 41 | m = re.match(r"^(.*) \|\|\| (.*)$", line) 42 | if m is None: 43 | text_a = line 44 | else: 45 | text_a = m.group(1) 46 | text_b = m.group(2) 47 | examples.append( 48 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 49 | # unique_id += 1 50 | return examples 51 | 52 | class InputExample(object): 53 | def __init__(self, unique_id, text_a, text_b): 54 | self.unique_id = unique_id 55 | self.text_a = text_a 56 | self.text_b = text_b 57 | 58 | class InputFeatures(object): 59 | """A single set of features of data.""" 60 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 61 | self.unique_id = unique_id 62 | self.tokens = tokens 63 | self.input_ids = input_ids 64 | self.input_mask = input_mask 65 | self.input_type_ids = input_type_ids 66 | 67 | def convert_examples_to_features(examples, seq_length, tokenizer): 68 | """Loads a data file into a list of `InputBatch`s.""" 69 | features = [] 70 | for (ex_index, example) in enumerate(examples): 71 | tokens_a = tokenizer.tokenize(example.text_a) 72 | 73 | tokens_b = None 74 | if example.text_b: 75 | tokens_b = tokenizer.tokenize(example.text_b) 76 | 77 | if tokens_b: 78 | # Modifies `tokens_a` and `tokens_b` in place so that the total 79 | # length is less than the specified 
length. 80 | # Account for [CLS], [SEP], [SEP] with "- 3" 81 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 82 | else: 83 | # Account for [CLS] and [SEP] with "- 2" 84 | if len(tokens_a) > seq_length - 2: 85 | tokens_a = tokens_a[0:(seq_length - 2)] 86 | tokens = [] 87 | input_type_ids = [] 88 | tokens.append("[CLS]") 89 | input_type_ids.append(0) 90 | for token in tokens_a: 91 | tokens.append(token) 92 | input_type_ids.append(0) 93 | tokens.append("[SEP]") 94 | input_type_ids.append(0) 95 | 96 | if tokens_b: 97 | for token in tokens_b: 98 | tokens.append(token) 99 | input_type_ids.append(1) 100 | tokens.append("[SEP]") 101 | input_type_ids.append(1) 102 | 103 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 104 | 105 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 106 | # tokens are attended to. 107 | input_mask = [1] * len(input_ids) 108 | 109 | # Zero-pad up to the sequence length. 110 | while len(input_ids) < seq_length: 111 | input_ids.append(0) 112 | input_mask.append(0) 113 | input_type_ids.append(0) 114 | 115 | assert len(input_ids) == seq_length 116 | assert len(input_mask) == seq_length 117 | assert len(input_type_ids) == seq_length 118 | features.append( 119 | InputFeatures( 120 | unique_id=example.unique_id, 121 | tokens=tokens, 122 | input_ids=input_ids, 123 | input_mask=input_mask, 124 | input_type_ids=input_type_ids)) 125 | return features 126 | 127 | 128 | def load_data_transformers(resize_reso, crop_reso, swap_num=(12, 4)): 129 | center_resize = 600 130 | Normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 131 | data_transforms = { 132 | 'swap': transforms.Compose([ 133 | # transforms.RandomHorizontalFlip(), 134 | # transforms.Resize((384, 128), Image.BICUBIC), # interpolation 135 | # transforms.RandomRotation(degrees=15), 136 | # transforms.RandomCrop((crop_reso[0], crop_reso[1])), 137 | transforms.Randomswap((swap_num[0], swap_num[1])), 138 | ]), 139 | 'common_aug': transforms.Compose([ 140 | # transforms.Resize((resize_reso[0], resize_reso[1]),Image.BICUBIC), 141 | # transforms.RandomHorizontalFlip(), 142 | transforms.RandomHorizontalFlip(), 143 | transforms.Resize((384, 128), Image.BICUBIC), # interpolation 144 | transforms.RandomRotation(degrees=15), 145 | transforms.RandomCrop((crop_reso[0], crop_reso[1])), 146 | # transforms.ToTensor(), 147 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 148 | ]), 149 | 'train_totensor': transforms.Compose([ 150 | # transforms.Resize((crop_reso[0], crop_reso[1]),Image.BICUBIC), 151 | # ImageNetPolicy(), 152 | transforms.ToTensor(), 153 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 154 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 155 | 156 | ]), 157 | 'test_totensor': transforms.Compose([ 158 | transforms.Resize((crop_reso[0], crop_reso[1]),Image.BICUBIC), 159 | transforms.CenterCrop((crop_reso[0], crop_reso[1])), 160 | transforms.ToTensor(), 161 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 162 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 163 | ]), 164 | 'None': None, 165 | } 166 | return data_transforms 167 | 168 | class CUHKPEDEDataset(data.Dataset): 169 | def __init__(self, opt,tran): 170 | 171 | self.opt = opt 172 | self.flip_flag = (self.opt.mode == 'train') 173 | 174 | data_save = read_dict(os.path.join(opt.dataroot, 'processed_data', opt.mode + '_save.pkl')) 175 | 176 | self.img_path = [os.path.join(opt.dataroot, img_path) for img_path in data_save['img_path']] 177 | 178 | 
self.label = data_save['id'] 179 | 180 | self.caption_code = data_save['lstm_caption_id'] 181 | 182 | self.same_id_index = data_save['same_id_index'] 183 | 184 | self.caption = data_save['captions'] 185 | 186 | self.transform = tran 187 | 188 | self.num_data = len(self.img_path) 189 | 190 | self.tokenizer = BertTokenizer.from_pretrained('saved_models/bert-base-uncased-vocab.txt', do_lower_case=True) 191 | 192 | self.transformers = load_data_transformers([384,128], [384,128], [4,6]) 193 | 194 | self.swap_size = [4,6] 195 | def crop_image(self, image, cropnum): 196 | width, high = image.size 197 | crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)] 198 | crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)] 199 | im_list = [] 200 | for j in range(len(crop_y) - 1): 201 | for i in range(len(crop_x) - 1): 202 | im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high)))) 203 | return im_list 204 | 205 | def __getitem__(self, index): 206 | """ 207 | :param index: 208 | :return: image and its label 209 | """ 210 | image = Image.open(self.img_path[index]) 211 | img_unswaps = self.transformers['common_aug'](image) 212 | img_unswaps = self.transformers["train_totensor"](img_unswaps) 213 | label = torch.from_numpy(np.array([self.label[index]],dtype='int32')).long() 214 | 215 | phrase = self.caption[index] 216 | examples = read_examples(phrase, index) 217 | features = convert_examples_to_features( 218 | examples=examples, seq_length=65, tokenizer=self.tokenizer) 219 | caption_code = features[0].input_ids 220 | caption_length = features[0].input_mask 221 | 222 | same_id_index = np.random.randint(len(self.same_id_index[index])) 223 | same_id_index = self.same_id_index[index][same_id_index] 224 | phrase = self.caption[same_id_index] 225 | examples = read_examples(phrase, index) 226 | features = convert_examples_to_features( 227 | examples=examples, seq_length=65, tokenizer=self.tokenizer) 228 | same_id_caption_code = features[0].input_ids 229 | same_id_caption_length = features[0].input_mask 230 | 231 | return img_unswaps, label, np.array(caption_code,dtype=int), np.array(caption_length,dtype=int), \ 232 | np.array(same_id_caption_code,dtype=int), np.array(same_id_caption_length,dtype=int) 233 | 234 | def get_data(self, index, img=True): 235 | if img: 236 | image = Image.open(self.img_path[index]) 237 | image = self.transform(image) 238 | else: 239 | image = 0 240 | 241 | label = torch.from_numpy(np.array([self.label[index]])).long() 242 | 243 | caption_code, caption_length = self.caption_mask(self.caption_code[index]) 244 | 245 | return image, label, caption_code, caption_length 246 | 247 | def caption_mask(self, caption): 248 | caption_length = len(caption) 249 | caption = torch.from_numpy(np.array(caption)).view(-1).long() 250 | 251 | if caption_length < self.opt.caption_length_max: 252 | zero_padding = torch.zeros(self.opt.caption_length_max - caption_length).long() 253 | caption = torch.cat([caption, zero_padding], 0) 254 | else: 255 | caption = caption[:self.opt.caption_length_max] 256 | caption_length = self.opt.caption_length_max 257 | 258 | return caption, caption_length 259 | 260 | def __len__(self): 261 | return self.num_data 262 | 263 | 264 | class CUHKPEDE_img_dateset(data.Dataset): 265 | def __init__(self, opt,tran): 266 | 267 | self.opt = opt 268 | if opt.mode=='train': 269 | path = 'dataset/CUHK-PEDES/processed_data/train_save.pkl' 270 | elif opt.mode=='test': 271 | path = 
'dataset/CUHK-PEDES/processed_data/test_save.pkl' 272 | # data_save = read_dict(os.path.join(opt.dataroot, 'processed_data', opt.mode + '_save.pkl')) 273 | data_save = read_dict(path) 274 | 275 | self.img_path = [os.path.join(opt.dataroot, img_path) for img_path in data_save['img_path']] 276 | 277 | self.label = data_save['id'] 278 | 279 | self.transform = tran 280 | 281 | self.num_data = len(self.img_path) 282 | 283 | self.transformers = load_data_transformers([384, 128], [384, 128], [12, 4]) 284 | 285 | def __getitem__(self, index): 286 | """ 287 | :param index: 288 | :return: image and its label 289 | """ 290 | 291 | image = Image.open(self.img_path[index]) 292 | image_path = self.img_path[index] 293 | raw_image = cv2.imread(image_path) 294 | # raw_image = cv2.resize(raw_image, (128, 384), interpolation=cv2.INTER_CUBIC) 295 | # # image = self.transform(image) 296 | image = self.transformers["test_totensor"](image) 297 | 298 | label = torch.from_numpy(np.array([self.label[index]])).long() 299 | 300 | return image, label 301 | 302 | def __len__(self): 303 | return self.num_data 304 | 305 | 306 | class CUHKPEDE_txt_dateset(data.Dataset): 307 | def __init__(self, opt): 308 | 309 | self.opt = opt 310 | 311 | data_save = read_dict(os.path.join(opt.dataroot, 'processed_data', opt.mode + '_save.pkl')) 312 | 313 | self.label = data_save['caption_label'] 314 | self.caption_code = data_save['lstm_caption_id'] 315 | self.caption = data_save['captions'] 316 | self.num_data = len(self.caption_code) 317 | self.tokenizer = BertTokenizer.from_pretrained('saved_models/bert-base-uncased-vocab.txt', do_lower_case=True) 318 | 319 | def __getitem__(self, index): 320 | """ 321 | :param index: 322 | :return: image and its label 323 | """ 324 | 325 | label = torch.from_numpy(np.array([self.label[index]])).long() 326 | 327 | # caption_code, caption_length = self.caption_mask(self.caption_code[index]) 328 | phrase = self.caption[index] 329 | examples = read_examples(phrase, index) 330 | features = convert_examples_to_features( 331 | examples=examples, seq_length=65, tokenizer=self.tokenizer) 332 | caption_code = features[0].input_ids 333 | caption_length = features[0].input_mask 334 | fea_tokens = (features[0].tokens + ['0'] * (65 - len(features[0].tokens))) 335 | return label, np.array(caption_code,dtype=int), np.array(caption_length,dtype=int),fea_tokens 336 | 337 | def caption_mask(self, caption): 338 | caption_length = len(caption) 339 | caption = torch.from_numpy(np.array(caption)).view(-1).float() 340 | if caption_length < self.opt.caption_length_max: 341 | zero_padding = torch.zeros(self.opt.caption_length_max - caption_length) 342 | caption = torch.cat([caption, zero_padding], 0) 343 | else: 344 | caption = caption[:self.opt.caption_length_max] 345 | caption_length = self.opt.caption_length_max 346 | 347 | return caption, caption_length 348 | 349 | def __len__(self): 350 | return self.num_data 351 | 352 | 353 | 354 | 355 | 356 | -------------------------------------------------------------------------------- /src/loss/Id_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.parameter import Parameter 9 | from torch.nn import init 10 | 11 | 12 | def weights_init_classifier(m): 13 | classname = m.__class__.__name__ 14 | if classname.find('Linear') != -1: 15 | init.normal_(m.weight.data, std=0.001) 16 | init.constant_(m.bias.data, 0.0) 17 | 18 | 
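# The two classes below implement the identity (ID) classification loss: `classifier` is a single
# linear layer, and `Id_Loss` keeps one such classifier per body part, applying it to both the
# image and the text local embeddings and averaging the cross-entropy over parts.
# A minimal usage sketch (variable names are placeholders; local embeddings are assumed to have
# shape [batch, feature_length, part], as produced by the model):
#
#   id_loss_fn = Id_Loss(opt, opt.part, opt.feature_length).to(opt.device)
#   loss_id = id_loss_fn(img_local_embedding, txt_local_embedding, label)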
19 | class classifier(nn.Module): 20 | 21 | def __init__(self, input_dim, output_dim): 22 | super(classifier, self).__init__() 23 | 24 | self.block = nn.Linear(input_dim, output_dim) 25 | self.block.apply(weights_init_classifier) 26 | 27 | def forward(self, x): 28 | x = self.block(x) 29 | return x 30 | 31 | 32 | class Id_Loss(nn.Module): 33 | 34 | def __init__(self, opt, part, feature_length): 35 | super(Id_Loss, self).__init__() 36 | 37 | self.opt = opt 38 | self.part = part 39 | 40 | W = [] 41 | for i in range(part): 42 | W.append(classifier(feature_length, opt.class_num)) 43 | self.W = nn.Sequential(*W) 44 | 45 | self.global_swap = classifier(feature_length,opt.class_num*2) 46 | 47 | def calculate_IdLoss(self, image_embedding_local, text_embedding_local, label): 48 | 49 | label = label.view(label.size(0)) 50 | 51 | criterion = nn.CrossEntropyLoss(reduction='mean') 52 | 53 | Lipt_local = 0 54 | Ltpi_local = 0 55 | 56 | for i in range(self.part): 57 | 58 | score_i2t_local_i = self.W[i](image_embedding_local[:, :, i]) 59 | score_t2i_local_i = self.W[i](text_embedding_local[:, :, i]) 60 | 61 | Lipt_local += criterion(score_i2t_local_i, label) 62 | Ltpi_local += criterion(score_t2i_local_i, label) 63 | 64 | loss = (Lipt_local + Ltpi_local) / self.part 65 | 66 | return loss 67 | 68 | def forward(self, image_embedding_local, text_embedding_local, label,gs = False): 69 | criterion = nn.CrossEntropyLoss(reduction='mean') 70 | if gs: 71 | a = torch.cat((image_embedding_local,text_embedding_local),dim=0).squeeze() 72 | predict = self.global_swap(a) 73 | loss = criterion(predict,label.squeeze()) 74 | else: 75 | loss = self.calculate_IdLoss(image_embedding_local, text_embedding_local, label) 76 | 77 | return loss 78 | 79 | -------------------------------------------------------------------------------- /src/loss/RankingLoss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 4 | @author: zifyloo 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | import torch.nn.functional as F 11 | 12 | 13 | def calculate_similarity(image_embedding, text_embedding): 14 | image_embedding = image_embedding.view(image_embedding.size(0), -1) 15 | text_embedding = text_embedding.view(text_embedding.size(0), -1) 16 | image_embedding_norm = image_embedding / (image_embedding.norm(dim=1, keepdim=True) + 1e-8) 17 | text_embedding_norm = text_embedding / (text_embedding.norm(dim=1, keepdim=True) + 1e-8) 18 | 19 | similarity = torch.mm(image_embedding_norm, text_embedding_norm.t()) 20 | 21 | similarity_match = torch.sum(image_embedding_norm * text_embedding_norm, dim=1) 22 | 23 | return similarity, similarity_match 24 | 25 | 26 | def calculate_margin_cr(similarity_match_cr, similarity_match, auto_margin_flag, margin): 27 | if auto_margin_flag: 28 | lambda_cr = abs(similarity_match_cr.detach()) / abs(similarity_match.detach()) 29 | ones = torch.ones_like(lambda_cr) 30 | data = torch.ge(ones, lambda_cr).float() 31 | data_2 = torch.ge(lambda_cr, ones).float() 32 | lambda_cr = data * lambda_cr + data_2 33 | 34 | lambda_cr = lambda_cr.detach().cpu().numpy() 35 | margin_cr = ((lambda_cr + 1) * margin) / 2.0 36 | else: 37 | margin_cr = margin / 2.0 38 | 39 | return margin_cr 40 | 41 | 42 | class CRLoss(nn.Module): 43 | 44 | def __init__(self, opt): 45 | super(CRLoss, self).__init__() 46 | 47 | self.device = opt.device 48 | self.margin = np.array([opt.margin]).repeat(opt.batch_size) 49 | self.double_margin = 
np.array([opt.margin]).repeat(opt.batch_size*2*opt.part) 50 | self.beta = opt.cr_beta 51 | # self.margin_local = np.array([opt.margin]).repeat(opt.batch_size*opt.part) 52 | 53 | def semi_hard_negative(self, loss, margin): 54 | negative_index = np.where(np.logical_and(loss < margin, loss > 0))[0] 55 | return np.random.choice(negative_index) if len(negative_index) > 0 else None 56 | 57 | def get_triplets(self, similarity, labels, auto_margin_flag, margin): 58 | 59 | similarity = similarity.cpu().data.numpy() 60 | 61 | labels = labels.cpu().data.numpy() 62 | triplets = [] 63 | 64 | for idx, label in enumerate(labels): # same class calculate together 65 | if margin[idx] >= 0.16 or auto_margin_flag is False: 66 | negative = np.where(labels != label)[0] 67 | 68 | ap_sim = similarity[idx, idx] 69 | 70 | loss = similarity[idx, negative] - ap_sim + margin[idx] 71 | 72 | negetive_index = self.semi_hard_negative(loss, margin[idx]) 73 | 74 | if negetive_index is not None: 75 | triplets.append([idx, idx, negative[negetive_index]]) 76 | 77 | if len(triplets) == 0: 78 | triplets.append([idx, idx, negative[0]]) 79 | 80 | triplets = torch.LongTensor(np.array(triplets)) 81 | 82 | return_margin = torch.FloatTensor(np.array(margin[triplets[:, 0]])).to(self.device) 83 | 84 | return triplets, return_margin 85 | 86 | def calculate_loss(self, similarity, label, auto_margin_flag, margin): 87 | 88 | image_triplets, img_margin = self.get_triplets(similarity, label, auto_margin_flag, margin) 89 | text_triplets, txt_margin = self.get_triplets(similarity.t(), label, auto_margin_flag, margin) 90 | 91 | image_anchor_loss = F.relu(img_margin 92 | - similarity[image_triplets[:, 0], image_triplets[:, 1]] 93 | + similarity[image_triplets[:, 0], image_triplets[:, 2]]) 94 | 95 | similarity = similarity.t() 96 | text_anchor_loss = F.relu(txt_margin 97 | - similarity[text_triplets[:, 0], text_triplets[:, 1]] 98 | + similarity[text_triplets[:, 0], text_triplets[:, 2]]) 99 | 100 | loss = torch.sum(image_anchor_loss) + torch.sum(text_anchor_loss) 101 | 102 | return loss 103 | 104 | def forward(self, img, txt, txt_cr, labels, auto_margin_flag,local=False): 105 | # if local: 106 | # similarity, similarity_match = calculate_similarity(img, txt) 107 | # similarity_cr, similarity_cr_match = calculate_similarity(img, txt_cr) 108 | # margin_cr = calculate_margin_cr(similarity_cr_match, similarity_match, auto_margin_flag, self.double_margin) 109 | # 110 | # cr_loss = self.calculate_loss(similarity, labels, auto_margin_flag, self.double_margin) \ 111 | # + self.beta * self.calculate_loss(similarity_cr, labels, auto_margin_flag, margin_cr) 112 | # else: 113 | similarity, similarity_match = calculate_similarity(img, txt) 114 | similarity_cr, similarity_cr_match = calculate_similarity(img, txt_cr) 115 | margin_cr = calculate_margin_cr(similarity_cr_match, similarity_match, auto_margin_flag, self.margin) 116 | 117 | cr_loss = self.calculate_loss(similarity, labels, auto_margin_flag, self.margin) \ 118 | + self.beta * self.calculate_loss(similarity_cr, labels, auto_margin_flag, margin_cr) 119 | 120 | return cr_loss 121 | 122 | -------------------------------------------------------------------------------- /src/model/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import transformers as ppb 4 | import math 5 | from torch import nn 6 | from torchvision import models 7 | import torch 8 | from torch.nn import init 9 | from torch.nn import functional as F 10 | from 
torch.autograd import Variable 11 | from pytorch_pretrained_bert.modeling import BertModel 12 | import numpy as np 13 | # import keyboard 14 | number = 0 15 | import matplotlib.pyplot as plt 16 | import seaborn as sns 17 | def l2norm(X, dim, eps=1e-8): 18 | """L2-normalize columns of X 19 | """ 20 | norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps 21 | X = torch.div(X, norm) 22 | return X 23 | 24 | def weights_init_kaiming(m): 25 | classname = m.__class__.__name__ 26 | if classname.find('Conv2d') != -1: 27 | init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') 28 | elif classname.find('Linear') != -1: 29 | init.kaiming_normal(m.weight.data, a=0, mode='fan_out') 30 | init.constant_(m.bias.data, 0.0) 31 | elif classname.find('BatchNorm1d') != -1: 32 | init.normal(m.weight.data, 1.0, 0.02) 33 | init.constant_(m.bias.data, 0.0) 34 | elif classname.find('BatchNorm2d') != -1: 35 | init.constant_(m.weight.data, 1) 36 | init.constant_(m.bias.data, 0) 37 | 38 | def weights_init_classifier(m): 39 | classname = m.__class__.__name__ 40 | if classname.find('Linear') != -1: 41 | init.normal(m.weight.data, std=0.001) 42 | init.constant(m.bias.data, 0.0) 43 | 44 | class conv(nn.Module): 45 | 46 | def __init__(self, input_dim, output_dim, relu=False, BN=False): 47 | super(conv, self).__init__() 48 | 49 | block = [] 50 | block += [nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False)] 51 | 52 | if BN: 53 | block += [nn.BatchNorm2d(output_dim)] 54 | if relu: 55 | block += [nn.LeakyReLU(0.25, inplace=True)] 56 | 57 | self.block = nn.Sequential(*block) 58 | self.block.apply(weights_init_kaiming) 59 | 60 | def forward(self, x): 61 | x = self.block(x) 62 | x = x.squeeze(3).squeeze(2) 63 | return x 64 | 65 | class NonLocalNet(nn.Module): 66 | def __init__(self, opt, dim_cut=8): 67 | super(NonLocalNet, self).__init__() 68 | self.opt = opt 69 | 70 | up_dim_conv = [] 71 | part_sim_conv = [] 72 | cur_sim_conv = [] 73 | conv_local_att = [] 74 | for i in range(opt.part): 75 | up_dim_conv.append(conv(opt.feature_length//dim_cut, 1024, relu=True, BN=True)) 76 | part_sim_conv.append(conv(opt.feature_length, opt.feature_length // dim_cut, relu=True, BN=False)) 77 | cur_sim_conv.append(conv(opt.feature_length, opt.feature_length // dim_cut, relu=True, BN=False)) 78 | conv_local_att.append(conv(opt.feature_length, 512)) 79 | 80 | self.up_dim_conv = nn.Sequential(*up_dim_conv) 81 | self.part_sim_conv = nn.Sequential(*part_sim_conv) 82 | self.cur_sim_conv = nn.Sequential(*cur_sim_conv) 83 | self.conv_local_att = nn.Sequential(*conv_local_att) 84 | 85 | self.zero_eye = (torch.eye(opt.part, opt.part) * -1e6).unsqueeze(0).to(opt.device) 86 | 87 | self.lambda_softmax = 1 88 | 89 | def forward(self, embedding): 90 | embedding = embedding.unsqueeze(3) 91 | embedding_part_sim = [] 92 | embedding_cur_sim = [] 93 | 94 | for i in range(self.opt.part): 95 | embedding_i = embedding[:, :, i, :].unsqueeze(2) 96 | 97 | embedding_part_sim_i = self.part_sim_conv[i](embedding_i).unsqueeze(2)#b,512 98 | embedding_part_sim.append(embedding_part_sim_i) 99 | 100 | embedding_cur_sim_i = self.cur_sim_conv[i](embedding_i).unsqueeze(2)#b,512 101 | embedding_cur_sim.append(embedding_cur_sim_i) 102 | 103 | embedding_part_sim = torch.cat(embedding_part_sim, dim=2)#b,512,6 104 | embedding_cur_sim = torch.cat(embedding_cur_sim, dim=2)#b,512,6 105 | 106 | embedding_part_sim_norm = l2norm(embedding_part_sim, dim=1) # N*D*n 107 | embedding_cur_sim_norm = l2norm(embedding_cur_sim, dim=1) # N*D*n 108 | self_att = 
torch.bmm(embedding_part_sim_norm.transpose(1, 2), embedding_cur_sim_norm) #b,6,6 # N*n*n 109 | self_att = self_att + self.zero_eye.repeat(self_att.size(0), 1, 1) 110 | self_att = F.softmax(self_att * self.lambda_softmax, dim=1) #64,6,6 # .transpose(1, 2).contiguous() 111 | embedding_att = torch.bmm(embedding_part_sim_norm, self_att).unsqueeze(3)#b,512,6 112 | 113 | embedding_att_up_dim = [] 114 | for i in range(self.opt.part): 115 | embedding_att_up_dim_i = embedding_att[:, :, i, :].unsqueeze(2) 116 | embedding_att_up_dim_i = self.up_dim_conv[i](embedding_att_up_dim_i).unsqueeze(2) 117 | embedding_att_up_dim.append(embedding_att_up_dim_i) 118 | embedding_att_up_dim = torch.cat(embedding_att_up_dim, dim=2).unsqueeze(3) 119 | 120 | embedding_att = embedding + embedding_att_up_dim#cancha 121 | 122 | embedding_local_att = [] 123 | for i in range(self.opt.part): 124 | embedding_att_i = embedding_att[:, :, i, :].unsqueeze(2) 125 | embedding_att_i = self.conv_local_att[i](embedding_att_i).unsqueeze(2) 126 | embedding_local_att.append(embedding_att_i) 127 | 128 | embedding_local_att = torch.cat(embedding_local_att, 2) 129 | 130 | return embedding_local_att.squeeze() 131 | 132 | 133 | 134 | 135 | class ResNet_image_50(nn.Module): 136 | def __init__(self): 137 | super(ResNet_image_50, self).__init__() 138 | resnet50 = models.resnet50(pretrained=True) 139 | resnet50.layer4[0].downsample[0].stride = (1, 1) 140 | resnet50.layer4[0].conv2.stride = (1, 1) 141 | self.base1 = nn.Sequential( 142 | resnet50.conv1, 143 | resnet50.bn1, 144 | resnet50.relu, 145 | resnet50.maxpool, 146 | resnet50.layer1, # 256 64 32 147 | ) 148 | self.base2 = nn.Sequential( 149 | resnet50.layer2, # 512 32 16 150 | ) 151 | self.base3 = nn.Sequential( 152 | resnet50.layer3, # 1024 16 8 153 | ) 154 | self.base4 = nn.Sequential( 155 | resnet50.layer4 # 2048 16 8 156 | ) 157 | 158 | def forward(self, x): 159 | x1 = self.base1(x) 160 | x2 = self.base2(x1) 161 | x3 = self.base3(x2) 162 | x4 = self.base4(x3) 163 | return x1, x2, x3, x4 164 | 165 | class TextImgPersonReidNet(nn.Module): 166 | def __init__(self, opt): 167 | super(TextImgPersonReidNet, self).__init__() 168 | self.opt = opt 169 | self.ImageExtract = ResNet_image_50() 170 | model_class, tokenizer_class, pretrained_weights = (ppb.BertModel, ppb.BertTokenizer, 'bert-base-uncased') 171 | self.text_embed = model_class.from_pretrained('bert_weight') 172 | self.text_embed.eval() 173 | self.BERT = True 174 | for p in self.text_embed.parameters(): 175 | p.requires_grad = False 176 | self.global_avgpool = nn.AdaptiveMaxPool2d((1, 1)) 177 | self.local_avgpool = nn.AdaptiveMaxPool2d((opt.part, 1)) 178 | 179 | conv_local = [] 180 | for i in range(opt.part): 181 | conv_local.append(conv(2048, opt.feature_length)) 182 | self.conv_local = nn.Sequential(*conv_local) 183 | 184 | self.conv_global = conv(2048, opt.feature_length) 185 | self.conv_global_qiyu = conv(2048, opt.feature_length) 186 | 187 | self.non_local_net = NonLocalNet(opt, dim_cut=2) 188 | self.leaky_relu = nn.LeakyReLU(0.25, inplace=True) 189 | 190 | 191 | txt_change = [] 192 | for i in range(self.opt.part): 193 | txt_change.append(nn.Linear(2048, 2048)) 194 | self.txt_change = nn.Sequential(*txt_change) 195 | 196 | self.start_list_img = nn.Parameter(torch.randn(2048,2)) 197 | self.start_list_txt = nn.Parameter(torch.randn(2048,2)) 198 | 199 | 200 | self.model_txt = ResNet_text_50() 201 | self.number = 0 202 | self.Convmask = nn.Conv2d(2048, 1, 1, stride=1, padding=0, bias=True) 203 | self.max_pool = 
nn.AdaptiveMaxPool2d((1,1)) 204 | 205 | visual_dictionary = torch.randn(2048, self.opt.part) 206 | self.register_buffer('end_img', visual_dictionary) 207 | nn.init.normal_(self.end_img) 208 | self.end_img.requires_grad=False 209 | txt_dictionary = torch.randn(2048, self.opt.part) 210 | self.register_buffer('end_txt', txt_dictionary) 211 | nn.init.normal_(self.end_txt) 212 | self.end_txt.requires_grad=False 213 | self.adapt_max_pool = nn.AdaptiveMaxPool2d((self.opt.part,1)) 214 | self.adapt_max_pool1D = nn.AdaptiveMaxPool1d((1)) 215 | self.temp = 1 216 | 217 | def forward(self, image, caption_id, text_length,epoch=None): 218 | if self.training: 219 | img_global, img_local, img_non_local,img_part_response,img_global_response = self.img_embedding(image,epoch=epoch) 220 | txt_global, txt_local, txt_non_local,txt_part_response,txt_global_response = self.txt_embedding(caption_id, text_length,epoch=epoch) 221 | else: 222 | img_global, img_local, img_non_local = self.img_embedding(image,epoch=epoch) 223 | txt_global, txt_local, txt_non_local = self.txt_embedding(caption_id, text_length,epoch=epoch) 224 | if self.training: 225 | return img_global, img_local, img_non_local, txt_global, txt_local, txt_non_local,\ 226 | img_part_response,txt_part_response,img_global_response,txt_global_response 227 | else: 228 | return img_global, img_local, img_non_local, txt_global, txt_local, txt_non_local 229 | 230 | def compute_global_local(self,img4,image=True,txt_lang=None,train=True,epoch=None): 231 | 232 | fine_tune_start = 25 233 | if image: 234 | feat_list = [] 235 | img4_new = img4.permute(0, 2, 3, 1) 236 | img4_new = img4_new.flatten(1,2) 237 | foreground_background_response = (F.normalize(img4_new,dim=-1)*5) @ (F.normalize(self.start_list_img,dim=0)*5) 238 | foreground_background_response_soft = torch.softmax(foreground_background_response/self.temp,dim=-1) 239 | foreground_background_response_soft_mutual = F.normalize(foreground_background_response_soft,dim=1).permute(0,2,1) @ F.normalize(foreground_background_response_soft,dim=1) 240 | foreground = (foreground_background_response_soft[:,:,0].unsqueeze(-1))*img4_new 241 | foreground_max = self.adapt_max_pool(foreground.contiguous().view(self.opt.batch_size,24,8,-1).permute(0,3,1,2)) 242 | if train and epoch<=fine_tune_start: 243 | part_axis = [] 244 | for i in range(self.opt.part): 245 | head = torch.sum(foreground_max[:,:,i,:].squeeze(),dim=0)/self.opt.batch_size 246 | part_axis.append(head.unsqueeze(0)) 247 | part_axis = torch.cat(part_axis,dim=0).t() 248 | embed = 0.99*self.end_img.detach() + 0.01*part_axis 249 | self.end_img = embed.detach() 250 | part_response = (F.normalize(foreground,dim=-1)*5) @ (F.normalize(embed,dim=0)*5) 251 | else: 252 | part_response = (F.normalize(foreground,dim=-1)*5) @ (F.normalize(self.end_img.detach(),dim=0)*5) 253 | part_response = F.softmax(part_response, dim=1) 254 | for i in range(self.opt.part): 255 | select = torch.sum(part_response[:, :, i].unsqueeze(-1) * foreground, dim=1) 256 | feat_list.append(select.unsqueeze(-1).unsqueeze(-1)) 257 | return feat_list,part_response,foreground_background_response_soft_mutual 258 | else: 259 | feat_list = [] 260 | for i in range(self.opt.part): 261 | feat_list.append([]) 262 | txt = img4.permute(1,0,2) 263 | part_mutual_list = [] 264 | gloabl_mutual_list = [] 265 | for j in range(img4.size(1)): 266 | foreground_background_response = (F.normalize(txt[j,1:txt_lang[j]-1,:], dim=-1)*5) @ (F.normalize(self.start_list_txt, dim=0)*5) 267 | foreground_background_response_soft = 
torch.softmax(foreground_background_response/self.temp,dim=-1) 268 | foreground_background_response_soft_mutual = F.normalize(foreground_background_response_soft,dim=0).t() @ F.normalize(foreground_background_response_soft,dim=0) 269 | gloabl_mutual_list.append(foreground_background_response_soft_mutual.unsqueeze(0)) 270 | foreground = (foreground_background_response_soft[:,0].unsqueeze(-1)) * txt[j,1:txt_lang[j]-1,:] 271 | if train and epoch<=fine_tune_start: 272 | part_axis = [] 273 | for i in range(self.opt.part): 274 | head = self.adapt_max_pool1D(self.txt_change[i](foreground).unsqueeze(0).permute(0,2,1)).squeeze() 275 | part_axis.append(head.unsqueeze(0)) 276 | part_axis = torch.cat(part_axis,dim=0).t() 277 | weights = 0.99 + (((self.opt.batch_size - 1) / self.opt.batch_size) * 0.01) 278 | embed = weights * self.end_txt.detach() + (1-weights) * part_axis 279 | self.end_txt = embed.detach() 280 | part_response = (F.normalize(foreground,dim=-1)*5) @ (F.normalize(embed,dim=0)*5) 281 | else: 282 | part_response = (F.normalize(foreground,dim=-1)*5) @ (F.normalize(self.end_txt.detach(),dim=0)*5) 283 | part_response_soft = F.softmax(part_response, dim=0,) 284 | for i in range(self.opt.part): 285 | select = torch.sum(part_response_soft[:, i].unsqueeze(-1) * foreground, dim=0) 286 | feat_list[i].append(select.unsqueeze(0)) 287 | part_response_soft_norm = F.normalize(part_response_soft,dim=0) 288 | part_mutual_list.append((part_response_soft_norm.t() @ part_response_soft_norm).unsqueeze(0)) 289 | feat_list = [torch.cat(feat_list[i], dim=0).unsqueeze(-1) for i in range(self.opt.part)] 290 | 291 | if train: 292 | return feat_list,torch.cat(part_mutual_list,dim=0), torch.cat(gloabl_mutual_list,0) 293 | else: 294 | return feat_list,torch.cat(gloabl_mutual_list,0) 295 | 296 | def img_embedding(self, image,epoch=None): 297 | _,_,imgf3, image_feature = self.ImageExtract(image)#b,2048,12,4 298 | image_feature_global = self.global_avgpool(image_feature) # b,2048,1 299 | image_global = self.conv_global(image_feature_global).unsqueeze(2) # b,1024 300 | if self.training: 301 | image_feature_local,part_response,global_mutual = self.compute_global_local(torch.cat([image_feature],dim=0),image=True,train=self.training,epoch=epoch) 302 | else: 303 | image_feature_local,part_response,global_mutual = self.compute_global_local(image_feature,image=True,train=self.training) 304 | 305 | image_feature_local = torch.cat(image_feature_local,dim=2) 306 | image_local = [] 307 | for i in range(self.opt.part): 308 | image_feature_local_i = image_feature_local[:, :, i, :] 309 | image_feature_local_i = image_feature_local_i.unsqueeze(2) 310 | image_embedding_local_i = self.conv_local[i](image_feature_local_i).unsqueeze(2) 311 | image_local.append(image_embedding_local_i) 312 | 313 | image_local = torch.cat(image_local, 2)#b,1024,6 314 | 315 | image_non_local = self.leaky_relu(image_local) 316 | image_non_local = self.non_local_net(image_non_local) 317 | 318 | if self.training: 319 | return image_global, image_local, image_non_local,part_response,global_mutual 320 | else: 321 | return image_global, image_local, image_non_local 322 | 323 | def txt_embedding(self, caption_id, text_length,epoch=None): 324 | with torch.no_grad(): 325 | txt = self.text_embed(caption_id, attention_mask=text_length) 326 | txt = txt[0] 327 | 328 | _, fword_1 = self.model_txt(txt) 329 | text_feature_l = fword_1.squeeze().unsqueeze(-1) 330 | text_length = torch.sum(text_length, dim=-1) 331 | text_global = self.global_avgpool(text_feature_l) # 64,2048 
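        # text_feature_l is the frozen BERT output passed through the ResNet_text_50 branch,
        # giving a [batch, 2048, seq_len, 1] word-level feature map. Despite its name,
        # global_avgpool is an AdaptiveMaxPool2d, so text_global above is a max over words;
        # conv_global below projects it from 2048 to opt.feature_length, mirroring the image branch.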
332 | text_global = self.conv_global(text_global).unsqueeze(2) # 64,1024 333 | if self.training: 334 | text_feature_local,part_response,global_mutual = self.compute_global_local(text_feature_l.squeeze().permute(2,0,1),image=False,txt_lang=text_length,train=self.training,epoch=epoch) 335 | else: 336 | text_feature_local,_ = self.compute_global_local(text_feature_l.squeeze().permute(2,0,1),image=False,txt_lang=text_length,train=self.training,epoch=epoch) 337 | 338 | text_feature_local = torch.cat(text_feature_local, dim=-1)#b,2048,6 339 | text_local = [] 340 | for p in range(self.opt.part): 341 | text_feature_local_conv_p = text_feature_local[:, :, p].unsqueeze(2).unsqueeze(2) 342 | text_feature_local_conv_p = self.conv_local[p](text_feature_local_conv_p).unsqueeze(2) 343 | text_local.append(text_feature_local_conv_p) 344 | text_local = torch.cat(text_local, dim=2)#b,1024,6 345 | text_non_local = self.leaky_relu(text_local) 346 | text_non_local = self.non_local_net(text_non_local)#b,512,6 347 | if self.training: 348 | return text_global, text_local, text_non_local,part_response,global_mutual 349 | else: 350 | return text_global, text_local, text_non_local 351 | 352 | 353 | 354 | class ResNet_text_50(nn.Module): 355 | 356 | def __init__(self, zero_init_residual=False, 357 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 358 | norm_layer=None): 359 | super(ResNet_text_50, self).__init__() 360 | if norm_layer is None: 361 | norm_layer = nn.BatchNorm2d 362 | self._norm_layer = norm_layer 363 | 364 | self.inplanes = 768 365 | 366 | if replace_stride_with_dilation is None: 367 | # each element in the tuple indicates if we should replace 368 | # the 2x2 stride with a dilated convolution instead 369 | replace_stride_with_dilation = [False, False, False] 370 | if len(replace_stride_with_dilation) != 3: 371 | raise ValueError("replace_stride_with_dilation should be None " 372 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 373 | 374 | 375 | self.conv1 = conv1x1(self.inplanes, 1024) 376 | self.bn1 = norm_layer(1024) 377 | self.relu = nn.ReLU(inplace=True) 378 | 379 | downsample = nn.Sequential( 380 | conv1x1(1024, 2048), 381 | norm_layer(2048), 382 | ) 383 | 384 | # 3, 4, 6, 3 385 | 386 | self.branch1 = nn.Sequential( 387 | Bottleneck(inplanes=1024, planes=2048, width=512, downsample=downsample), 388 | Bottleneck(inplanes=2048, planes=2048, width=512), 389 | Bottleneck(inplanes=2048, planes=2048, width=512) 390 | ) 391 | 392 | for m in self.modules(): 393 | if isinstance(m, nn.Conv2d): 394 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 395 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 396 | nn.init.constant_(m.weight, 1) 397 | nn.init.constant_(m.bias, 0) 398 | 399 | if zero_init_residual: 400 | for m in self.modules(): 401 | if isinstance(m, Bottleneck): 402 | nn.init.constant_(m.bn3.weight, 0) 403 | 404 | 405 | def forward(self, x): 406 | x = x.permute(0,2,1).unsqueeze(2).contiguous() 407 | x1 = self.conv1(x) # 1024 1 64 408 | x1 = self.bn1(x1) 409 | x1 = self.relu(x1) 410 | x21 = self.branch1(x1) 411 | return x1, x21 412 | 413 | class Bottleneck(nn.Module): 414 | 415 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 416 | width=64, dilation=1, norm_layer=None): 417 | super(Bottleneck, self).__init__() 418 | if norm_layer is None: 419 | norm_layer = nn.BatchNorm2d 420 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 421 | self.conv1 = conv1x1(inplanes, width) 422 | 
self.bn1 = norm_layer(width) 423 | self.conv2 = conv1x3(width, width, stride, groups, dilation) 424 | self.bn2 = norm_layer(width) 425 | self.conv3 = conv1x1(width, planes) 426 | self.bn3 = norm_layer(planes) 427 | self.relu = nn.ReLU(inplace=True) 428 | self.downsample = downsample 429 | self.stride = stride 430 | 431 | def forward(self, x): 432 | identity = x 433 | 434 | out = self.conv1(x) 435 | out = self.bn1(out) 436 | out = self.relu(out) 437 | 438 | out = self.conv2(out) 439 | out = self.bn2(out) 440 | out = self.relu(out) 441 | 442 | out = self.conv3(out) 443 | out = self.bn3(out) 444 | 445 | if self.downsample is not None: 446 | identity = self.downsample(x) 447 | 448 | out += identity 449 | out = self.relu(out) 450 | 451 | return out 452 | 453 | def conv1x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 454 | """3x3 convolution with padding""" 455 | return nn.Conv2d(in_planes, out_planes, kernel_size=(1,3), stride=stride, 456 | padding=(0,1), groups=groups, bias=False, dilation=dilation) 457 | 458 | def conv1x1(in_planes, out_planes, stride=1): 459 | """1x1 convolution""" 460 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 461 | -------------------------------------------------------------------------------- /src/option/options.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | import argparse 7 | import torch 8 | import logging 9 | import os 10 | from utils.read_write_data import makedir 11 | 12 | logger = logging.getLogger() 13 | logger.setLevel(logging.INFO) 14 | 15 | 16 | class options(): 17 | def __init__(self): 18 | self._par = argparse.ArgumentParser(description='options for Deep Cross Modal') 19 | 20 | self._par.add_argument('--model_name', type=str, help='experiment name') 21 | self._par.add_argument('--mode', type=str, default='', help='choose mode [train or test]') 22 | 23 | self._par.add_argument('--epoch', type=int, default=60, help='train epoch') 24 | self._par.add_argument('--epoch_decay', type=list, default=[20,40,50,55], help='decay epoch') 25 | self._par.add_argument('--epoch_begin', type=int, default=5, help='when calculate the auto margin') 26 | 27 | self._par.add_argument('--batch_size', type=int, default=64, help='batch size') 28 | self._par.add_argument('--adam_alpha', type=float, default=0.9, help='momentum term of adam') 29 | self._par.add_argument('--adam_beta', type=float, default=0.999, help='momentum term of adam') 30 | self._par.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam') 31 | self._par.add_argument('--margin', type=float, default=0.2, help='ranking loss margin') 32 | self._par.add_argument('--cr_beta', type=float, default=0.1, help='ranking loss margin') 33 | 34 | self._par.add_argument('--vocab_size', type=int, default=5000, help='the size of vocab') 35 | self._par.add_argument('--feature_length', type=int, default=512, help='the length of feature') 36 | self._par.add_argument('--class_num', type=int, default=11000, 37 | help='num of class for StarGAN training on second dataset') 38 | self._par.add_argument('--part', type=int, default=6, help='the num of image part') 39 | self._par.add_argument('--caption_length_max', type=int, default=100, help='the max length of caption') 40 | 41 | self._par.add_argument('--save_path', type=str, default='./checkpoints/test', 42 | help='save the result during training') 43 | self._par.add_argument('--GPU_id', type=str, 
default='0', help='choose GPU ID') 44 | self._par.add_argument('--device', type=str, default='', help='cuda devie') 45 | self._par.add_argument('--dataset', type=str, help='choose the dataset ') 46 | self._par.add_argument('--dataroot', type=str, help='data root of the Data') 47 | 48 | self.opt = self._par.parse_args() 49 | 50 | self.opt.device = torch.device('cuda:{}'.format(self.opt.GPU_id[0])) 51 | 52 | 53 | def config(opt): 54 | 55 | log_config(opt) 56 | model_root = os.path.join(opt.save_path, 'model') 57 | if os.path.exists(model_root) is False: 58 | makedir(model_root) 59 | 60 | 61 | def log_config(opt): 62 | logroot = os.path.join(opt.save_path, 'log') 63 | if os.path.exists(logroot) is False: 64 | makedir(logroot) 65 | filename = os.path.join(logroot, opt.mode + '.log') 66 | handler = logging.FileHandler(filename) 67 | handler.setLevel(logging.INFO) 68 | formatter = logging.Formatter('%(message)s') 69 | handler.setFormatter(formatter) 70 | logger.addHandler(logging.StreamHandler()) 71 | logger.addHandler(handler) 72 | if opt.mode != 'test': 73 | logger.info(opt) 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | 7 | from option.options import options 8 | from data.dataloader import get_dataloader 9 | import torch 10 | from model.model import TextImgPersonReidNet 11 | import os 12 | from test_during_train import test 13 | 14 | def save_checkpoint(state, opt): 15 | filename = os.path.join(opt.save_path, 'model/best.pth.tar') 16 | torch.save(state, filename) 17 | 18 | def load_checkpoint(opt,time=None): 19 | if time is None: 20 | filename = os.path.join(opt.save_path, 'model/best.pth.tar') 21 | state = torch.load(filename) 22 | print('Load the {} epoch parameter successfully'.format(state['epoch'])) 23 | else: 24 | filename = os.path.join(opt.save_path, 'model/best_'+str(time)+'.pth.tar') 25 | state = torch.load(filename) 26 | print('Load the {} epoch parameter successfully'.format(state['epoch'])) 27 | return state 28 | 29 | def main(opt): 30 | opt.device = torch.device('cuda:{}'.format(opt.GPU_id)) 31 | 32 | opt.save_path = './checkpoints/{}/'.format(opt.dataset) + opt.model_name 33 | 34 | test_img_dataloader, test_txt_dataloader = get_dataloader(opt) 35 | 36 | network = TextImgPersonReidNet(opt).to(opt.device) 37 | 38 | test_best = 0 39 | state = load_checkpoint(opt) 40 | network.load_state_dict(state['network']) 41 | epoch = state['epoch'] 42 | 43 | print(opt.model_name) 44 | network.eval() 45 | test(opt, epoch + 1, network, test_img_dataloader, test_txt_dataloader, test_best, False) 46 | network.train() 47 | 48 | if __name__ == '__main__': 49 | opt = options().opt 50 | main(opt) 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /src/test_during_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @author: zifyloo 4 | """ 5 | 6 | import tqdm 7 | import torch 8 | import numpy as np 9 | import os 10 | import torch.nn.functional as F 11 | from utils.read_write_data import write_txt 12 | from time import time 13 | 14 | def calculate_similarity(image_feature_local, text_feature_local): 15 | 16 | image_feature_local = image_feature_local.view(image_feature_local.size(0), -1) 17 | image_feature_local = image_feature_local / 
image_feature_local.norm(dim=1, keepdim=True) 18 | 19 | text_feature_local = text_feature_local.view(text_feature_local.size(0), -1) 20 | text_feature_local = text_feature_local / text_feature_local.norm(dim=1, keepdim=True) 21 | 22 | similarity = torch.mm(image_feature_local, text_feature_local.t()) 23 | 24 | return similarity.cpu() 25 | 26 | 27 | def calculate_ap(similarity, label_query, label_gallery): 28 | """ 29 | calculate the similarity, and rank the distance, according to the distance, calculate the ap, cmc 30 | :param label_query: the id of query [1] 31 | :param label_gallery:the id of gallery [N] 32 | :return: ap, cmc 33 | """ 34 | 35 | index = np.argsort(similarity)[::-1] # the index of the similarity from huge to small 36 | good_index = np.argwhere(label_gallery == label_query) # the index of the same label in gallery 37 | 38 | cmc = np.zeros(index.shape) 39 | 40 | mask = np.in1d(index, good_index) # get the flag the if index[i] is in the good_index 41 | 42 | precision_result = np.argwhere(mask == True) # get the situation of the good_index in the index 43 | 44 | precision_result = precision_result.reshape(precision_result.shape[0]) 45 | 46 | if precision_result.shape[0] != 0: 47 | cmc[int(precision_result[0]):] = 1 # get the cmc 48 | 49 | d_recall = 1.0 / len(precision_result) 50 | ap = 0 51 | 52 | for i in range(len(precision_result)): # ap is to calculate the PR area 53 | precision = (i + 1) * 1.0 / (precision_result[i] + 1) 54 | 55 | if precision_result[i] != 0: 56 | old_precision = i * 1.0 / precision_result[i] 57 | else: 58 | old_precision = 1.0 59 | 60 | ap += d_recall * (old_precision + precision) / 2 61 | 62 | return ap, cmc 63 | else: 64 | return None, None 65 | 66 | 67 | def evaluate(similarity, label_query, label_gallery): 68 | similarity = similarity.numpy() 69 | label_query = label_query.numpy() 70 | label_gallery = label_gallery.numpy() 71 | 72 | cmc = np.zeros(label_gallery.shape) 73 | ap = 0 74 | for i in range(len(label_query)): 75 | ap_i, cmc_i = calculate_ap(similarity[i, :], label_query[i], label_gallery) 76 | cmc += cmc_i 77 | ap += ap_i 78 | """ 79 | cmc_i is the vector [0,0,...1,1,..1], the first 1 is the first right prediction n, 80 | rank-n and the rank-k after it all add one right prediction, therefore all of them's index mark 1 81 | Through the add all the vector and then divive the n_query, we can get the rank-k accuracy cmc 82 | cmc[k-1] is the rank-k accuracy 83 | """ 84 | cmc = cmc / len(label_query) 85 | map = ap / len(label_query) # map = sum(ap) / n_query 86 | 87 | # print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f' % (cmc[0], cmc[4], cmc[9], map)) 88 | 89 | return cmc, map 90 | 91 | 92 | def evaluate_without_matching_image(similarity, label_query, label_gallery, txt_img_index): 93 | similarity = similarity.numpy() 94 | label_query = label_query.numpy() 95 | label_gallery = label_gallery.numpy() 96 | 97 | cmc = np.zeros(label_gallery.shape[0] - 1) 98 | ap = 0 99 | count = 0 100 | for i in range(len(label_query)): 101 | 102 | similarity_i = similarity[i, :] 103 | similarity_i = np.delete(similarity_i, txt_img_index[i]) 104 | label_gallery_i = np.delete(label_gallery, txt_img_index[i]) 105 | ap_i, cmc_i = calculate_ap(similarity_i, label_query[i], label_gallery_i) 106 | if ap_i is not None: 107 | cmc += cmc_i 108 | ap += ap_i 109 | else: 110 | count += 1 111 | """ 112 | cmc_i is the vector [0,0,...1,1,..1], the first 1 is the first right prediction n, 113 | rank-n and the rank-k after it all add one right prediction, therefore all of them's index 
mark 1 114 | Through the add all the vector and then divive the n_query, we can get the rank-k accuracy cmc 115 | cmc[k-1] is the rank-k accuracy 116 | """ 117 | cmc = cmc / (len(label_query) - count) 118 | map = ap / (len(label_query) - count) # map = sum(ap) / n_query 119 | # print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f' % (cmc[0], cmc[4], cmc[9], map)) 120 | 121 | return cmc, map 122 | 123 | 124 | def load_checkpoint(model_root, model_name): 125 | filename = os.path.join(model_root, 'model', model_name) 126 | state = torch.load(filename, map_location='cpu') 127 | 128 | return state 129 | 130 | 131 | def write_result(similarity, img_labels, txt_labels, name, txt_root, best_txt_root, epoch, best): 132 | write_txt(name, txt_root) 133 | print(name) 134 | t2i_cmc, t2i_map = evaluate(similarity.t(), txt_labels, img_labels) 135 | str = "t2i: @R1: {:.4}, @R5: {:.4}, @R10: {:.4}, map: {:.4}".format(t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_map) 136 | write_txt(str, txt_root) 137 | write_txt(str, txt_root) 138 | print(str) 139 | 140 | if t2i_cmc[0] > best: 141 | str = "Testing Epoch: {}".format(epoch) 142 | write_txt(str, best_txt_root) 143 | write_txt(name, best_txt_root) 144 | str = "t2i: @R1: {:.4}, @R5: {:.4}, @R10: {:.4}, map: {:.4}".format(t2i_cmc[0], t2i_cmc[4], t2i_cmc[9], t2i_map) 145 | write_txt(str, best_txt_root) 146 | 147 | return t2i_cmc[0] 148 | else: 149 | return best 150 | 151 | 152 | def test(opt, epoch, network, img_dataloader, txt_dataloader, best, return_flag=True,img_visualization=False,text_visualization=False): 153 | 154 | best_txt_root = os.path.join(opt.save_path, 'log', 'best_test.log') 155 | txt_root = os.path.join(opt.save_path, 'log', 'test_separate.log') 156 | 157 | str_epoch = "Testing Epoch: {}".format(epoch) 158 | write_txt(str_epoch, txt_root) 159 | print(str_epoch) 160 | 161 | image_feature_global = [] 162 | image_feature_local = [] 163 | image_feature_non_local = [] 164 | img_labels = [] 165 | ###image####3 166 | a = time() 167 | 168 | for times, [image, label] in tqdm.tqdm(enumerate(img_dataloader)): 169 | 170 | image = image.to(opt.device) 171 | label = label.to(opt.device) 172 | 173 | with torch.no_grad(): 174 | img_global_i, img_local_i, img_non_local_i = network.img_embedding(image) 175 | 176 | image_feature_global.append(img_global_i) 177 | image_feature_local.append(img_local_i) 178 | image_feature_non_local.append(img_non_local_i) 179 | img_labels.append(label.view(-1)) 180 | 181 | image_feature_local = torch.cat(image_feature_local, 0) 182 | image_feature_global = torch.cat(image_feature_global, 0) 183 | image_feature_non_local = torch.cat(image_feature_non_local, 0) 184 | img_labels = torch.cat(img_labels, 0) 185 | 186 | ####text### 187 | text_feature_local = [] 188 | text_feature_global = [] 189 | text_feature_non_local = [] 190 | txt_labels = [] 191 | text_sum = [] 192 | names = locals() 193 | for i in range(opt.part): 194 | names['txt_top_part' + str(i)] = {} 195 | for times, [label, caption_code, caption_length,tokens] in tqdm.tqdm(enumerate(txt_dataloader)): 196 | label = label.to(opt.device) 197 | caption_code = caption_code.to(opt.device).long() 198 | caption_length = caption_length.to(opt.device) 199 | 200 | with torch.no_grad(): 201 | text_global_i, text_local_i, text_non_local_i = network.txt_embedding(caption_code, caption_length) 202 | text_feature_local.append(text_local_i) 203 | text_feature_global.append(text_global_i) 204 | text_feature_non_local.append(text_non_local_i) 205 | # text_sum.append(feat_sum) 206 | 
txt_labels.append(label.view(-1)) 207 | 208 | 209 | # print('ha') 210 | text_feature_local = torch.cat(text_feature_local, 0) 211 | text_feature_global = torch.cat(text_feature_global, 0) 212 | text_feature_non_local = torch.cat(text_feature_non_local, 0) 213 | # text_sum = torch.cat(text_sum,dim=0) 214 | txt_labels = torch.cat(txt_labels, 0) 215 | 216 | similarity_local = calculate_similarity(image_feature_local, text_feature_local) 217 | similarity_global = calculate_similarity(image_feature_global, text_feature_global) 218 | similarity_non_local = calculate_similarity(image_feature_non_local, text_feature_non_local) 219 | similarity_all = similarity_local + similarity_global + similarity_non_local 220 | similarity_nonlocal_global = similarity_global + similarity_non_local 221 | img_labels = img_labels.cpu() 222 | txt_labels = txt_labels.cpu() 223 | b = time() 224 | print (b-a) 225 | best = write_result(similarity_global, img_labels, txt_labels, 'similarity_global:', 226 | txt_root, best_txt_root, epoch, best) 227 | 228 | best = write_result(similarity_local, img_labels, txt_labels, 'similarity_local:', 229 | txt_root, best_txt_root, epoch, best) 230 | 231 | best = write_result(similarity_non_local, img_labels, txt_labels, 'similarity_non_local:', 232 | txt_root, best_txt_root, epoch, best) 233 | 234 | best = write_result(similarity_all, img_labels, txt_labels, 'similarity_all:', 235 | txt_root, best_txt_root, epoch, best) 236 | 237 | best = write_result(similarity_nonlocal_global,img_labels,txt_labels,'similarity_nonlocal_global:', 238 | txt_root, best_txt_root, epoch, best) 239 | if return_flag: 240 | return best 241 | 242 | 243 | 244 | 245 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from option.options import options, config 4 | from data.dataloader import get_dataloader 5 | import torch.nn.functional as F 6 | import torch 7 | import torch.nn as nn 8 | from model.model import TextImgPersonReidNet 9 | from loss.Id_loss import Id_Loss 10 | from loss.RankingLoss import CRLoss 11 | from torch import optim 12 | import logging 13 | import os 14 | from test_during_train import test 15 | from torch.autograd import Variable 16 | import numpy as np 17 | import random 18 | import time 19 | # SEED = 0 20 | # torch.manual_seed(SEED) 21 | # torch.cuda.manual_seed(SEED) 22 | def setup_seed(seed): 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | np.random.seed(seed) 26 | random.seed(seed) 27 | torch.backends.cudnn.deterministic = True 28 | setup_seed(27) 29 | logger = logging.getLogger() 30 | logger.setLevel(logging.INFO) 31 | 32 | 33 | def save_checkpoint(state, opt,times=None): 34 | if times is None: 35 | filename = os.path.join(opt.save_path, 'model/best.pth.tar') 36 | torch.save(state, filename) 37 | else: 38 | filename = os.path.join(opt.save_path, 'model/best_'+str(times)+'.pth.tar') 39 | torch.save(state, filename) 40 | 41 | 42 | def train(opt): 43 | opt.device = torch.device('cuda:{}'.format(opt.GPU_id)) 44 | 45 | opt.save_path = './checkpoints/{}/'.format(opt.dataset) + opt.model_name 46 | 47 | config(opt) 48 | train_dataloader = get_dataloader(opt) 49 | opt.mode = 'test' 50 | test_img_dataloader, test_txt_dataloader = get_dataloader(opt) 51 | opt.mode = 'train' 52 | 53 | id_loss_fun_global = Id_Loss(opt, 1, opt.feature_length).to(opt.device) 54 | #id_loss_fun_global_f3 = 
Id_Loss(opt,1,opt.feature_length).to(opt.device) 55 | id_loss_fun_local = Id_Loss(opt, opt.part, opt.feature_length).to(opt.device) 56 | id_loss_fun_non_local = Id_Loss(opt, opt.part, 512).to(opt.device) 57 | cr_loss_fun = CRLoss(opt) 58 | network = TextImgPersonReidNet(opt).to(opt.device) 59 | 60 | cnn_params = list(map(id, network.ImageExtract.parameters())) 61 | other_params = filter(lambda p: id(p) not in cnn_params, network.parameters()) 62 | other_params = list(other_params) 63 | other_params.extend(list(id_loss_fun_global.parameters())) 64 | other_params.extend(list(id_loss_fun_local.parameters())) 65 | other_params.extend(list(id_loss_fun_non_local.parameters())) 66 | # other_params.extend(list(id_loss_fun_global_f3.parameters())) 67 | param_groups = [{'params': other_params, 'lr': opt.lr}, 68 | {'params': network.ImageExtract.parameters(), 'lr': opt.lr * 0.1}] 69 | 70 | optimizer = optim.Adam(param_groups, betas=(opt.adam_alpha, opt.adam_beta)) 71 | 72 | test_best = 0 73 | test_history = 0 74 | 75 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, opt.epoch_decay) 76 | 77 | add_loss = nn.MSELoss() 78 | 79 | for epoch in range(opt.epoch): 80 | for param in optimizer.param_groups: 81 | logging.info('lr:{}'.format(param['lr'])) 82 | start = time.time() 83 | for times, [image, label, caption_code, caption_length, caption_code_cr, caption_length_cr] in enumerate(train_dataloader): 84 | 85 | 86 | image = Variable(image.to(opt.device)) 87 | label = Variable(label.to(opt.device)) 88 | caption_code = Variable(caption_code.to(opt.device).long()) 89 | caption_length = caption_length.to(opt.device) 90 | caption_code_cr = Variable(caption_code_cr.to(opt.device).long()) 91 | caption_length_cr = caption_length_cr.to(opt.device) 92 | 93 | 94 | img_global, img_local, img_non_local, txt_global, txt_local, txt_non_local,\ 95 | img_part_response,txt_part_response,img_global_response,txt_global_response = network(image, caption_code, caption_length,epoch=epoch) 96 | 97 | 98 | txt_global_cr, txt_local_cr, txt_non_local_cr,txt_part_response_cr, txt_global_response_cr= network.txt_embedding(caption_code_cr, caption_length_cr,epoch=epoch) 99 | img_part = F.normalize(img_part_response,dim=1) 100 | img_part = img_part.permute(0,2,1) @ img_part 101 | img_part_loss = add_loss(img_part,torch.eye((opt.part)).repeat(opt.batch_size,1,1).to(opt.device)) 102 | txt_part_loss = add_loss(txt_part_response, torch.eye((opt.part)).repeat(opt.batch_size, 1, 1).to(opt.device)) 103 | txt_part_loss_cr = add_loss(txt_part_response_cr, torch.eye((opt.part)).repeat(opt.batch_size, 1, 1).to(opt.device)) 104 | 105 | img_global_loss = add_loss(img_global_response,torch.eye((2)).repeat(opt.batch_size,1,1).to(opt.device)) 106 | txt_global_loss = add_loss(txt_global_response,torch.eye((2)).repeat(opt.batch_size,1,1).to(opt.device)) 107 | txt_global_loss_cr = add_loss(txt_global_response_cr,torch.eye((2)).repeat(opt.batch_size,1,1).to(opt.device)) 108 | 109 | 110 | id_loss_global = id_loss_fun_global(img_global, txt_global, label) 111 | id_loss_local = id_loss_fun_local(img_local,txt_local,label) 112 | id_loss_non_local = id_loss_fun_non_local(img_non_local,txt_non_local,label) 113 | id_loss = id_loss_global + (id_loss_local + id_loss_non_local) * 0.5 114 | 115 | 116 | cr_loss_global = cr_loss_fun(img_global, txt_global, txt_global_cr, label, epoch >= opt.epoch_begin) 117 | cr_loss_local = cr_loss_fun(img_local, txt_local, txt_local_cr, label, epoch >= opt.epoch_begin,) 118 | cr_loss_non_local = cr_loss_fun(img_non_local, 
txt_non_local, txt_non_local_cr, label, epoch >= opt.epoch_begin,) 119 | ranking_loss = cr_loss_global + (cr_loss_local + cr_loss_non_local)*0.5 \ 120 | + (img_part_loss+txt_part_loss+txt_part_loss_cr+img_global_loss+txt_global_loss+txt_global_loss_cr) 121 | 122 | optimizer.zero_grad() 123 | loss = (id_loss + ranking_loss) 124 | loss.backward() 125 | optimizer.step() 126 | 127 | if (times+1) % 200== 0: 128 | logging.info("Epoch: %d/%d Setp: %d, ranking_loss: %.2f, id_loss: %.2f" 129 | % (epoch + 1, opt.epoch, times + 1, ranking_loss, id_loss)) 130 | 131 | end = time.time() 132 | print ((end-start)//60) 133 | logging.info('time:' + str((end-start)//60)) 134 | 135 | 136 | print(opt.model_name) 137 | 138 | network.eval() 139 | test_best = test(opt, epoch + 1, network, test_img_dataloader, test_txt_dataloader, test_best) 140 | network.train() 141 | 142 | if test_best > test_history: 143 | test_history = test_best 144 | state = { 145 | 'network': network.cpu().state_dict(), 146 | 'test_best': test_best, 147 | 'epoch': epoch, 148 | 'WN': id_loss_fun_non_local.cpu().state_dict(), 149 | 'WL': id_loss_fun_local.cpu().state_dict(), 150 | } 151 | save_checkpoint(state, opt) 152 | network.to(opt.device) 153 | id_loss_fun_non_local.to(opt.device) 154 | id_loss_fun_local.to(opt.device) 155 | 156 | scheduler.step() 157 | 158 | logging.info('Training Done') 159 | 160 | 161 | if __name__ == '__main__': 162 | opt = options().opt 163 | train(opt) 164 | -------------------------------------------------------------------------------- /src/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | -------------------------------------------------------------------------------- /src/transforms/functional.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps, ImageEnhance, PILLOW_VERSION 6 | 7 | try: 8 | import accimage 9 | except ImportError: 10 | accimage = None 11 | import numpy as np 12 | import numbers 13 | import types 14 | import collections 15 | import warnings 16 | 17 | 18 | def _is_pil_image(img): 19 | if accimage is not None: 20 | return isinstance(img, (Image.Image, accimage.Image)) 21 | else: 22 | return isinstance(img, Image.Image) 23 | 24 | 25 | def _is_tensor_image(img): 26 | return torch.is_tensor(img) and img.ndimension() == 3 27 | 28 | 29 | def _is_numpy_image(img): 30 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 31 | 32 | 33 | def to_tensor(pic): 34 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 35 | 36 | See ``ToTensor`` for more details. 37 | 38 | Args: 39 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 40 | 41 | Returns: 42 | Tensor: Converted image. 43 | """ 44 | if not (_is_pil_image(pic) or _is_numpy_image(pic)): 45 | raise TypeError('pic should be PIL Image or ndarray. 
Got {}'.format(type(pic))) 46 | 47 | if isinstance(pic, np.ndarray): 48 | # handle numpy array 49 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 50 | # backward compatibility 51 | if isinstance(img, torch.ByteTensor): 52 | return img.float().div(255) 53 | else: 54 | return img 55 | 56 | if accimage is not None and isinstance(pic, accimage.Image): 57 | nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) 58 | pic.copyto(nppic) 59 | return torch.from_numpy(nppic) 60 | 61 | # handle PIL Image 62 | if pic.mode == 'I': 63 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 64 | elif pic.mode == 'I;16': 65 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 66 | elif pic.mode == 'F': 67 | img = torch.from_numpy(np.array(pic, np.float32, copy=False)) 68 | elif pic.mode == '1': 69 | img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False)) 70 | else: 71 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 72 | # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK 73 | if pic.mode == 'YCbCr': 74 | nchannel = 3 75 | elif pic.mode == 'I;16': 76 | nchannel = 1 77 | else: 78 | nchannel = len(pic.mode) 79 | img = img.view(pic.size[1], pic.size[0], nchannel) 80 | # put it from HWC to CHW format 81 | # yikes, this transpose takes 80% of the loading time/CPU 82 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 83 | if isinstance(img, torch.ByteTensor): 84 | return img.float().div(255) 85 | else: 86 | return img 87 | 88 | 89 | def to_pil_image(pic, mode=None): 90 | """Convert a tensor or an ndarray to PIL Image. 91 | 92 | See :class:`~torchvision.transforms.ToPIlImage` for more details. 93 | 94 | Args: 95 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 96 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 97 | 98 | .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 99 | 100 | Returns: 101 | PIL Image: Image converted to PIL Image. 102 | """ 103 | if not (_is_numpy_image(pic) or _is_tensor_image(pic)): 104 | raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) 105 | 106 | npimg = pic 107 | if isinstance(pic, torch.FloatTensor): 108 | pic = pic.mul(255).byte() 109 | if torch.is_tensor(pic): 110 | npimg = np.transpose(pic.numpy(), (1, 2, 0)) 111 | 112 | if not isinstance(npimg, np.ndarray): 113 | raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + 114 | 'not {}'.format(type(npimg))) 115 | 116 | if npimg.shape[2] == 1: 117 | expected_mode = None 118 | npimg = npimg[:, :, 0] 119 | if npimg.dtype == np.uint8: 120 | expected_mode = 'L' 121 | elif npimg.dtype == np.int16: 122 | expected_mode = 'I;16' 123 | elif npimg.dtype == np.int32: 124 | expected_mode = 'I' 125 | elif npimg.dtype == np.float32: 126 | expected_mode = 'F' 127 | if mode is not None and mode != expected_mode: 128 | raise ValueError("Incorrect mode ({}) supplied for input type {}. 
Should be {}" 129 | .format(mode, np.dtype, expected_mode)) 130 | mode = expected_mode 131 | 132 | elif npimg.shape[2] == 4: 133 | permitted_4_channel_modes = ['RGBA', 'CMYK'] 134 | if mode is not None and mode not in permitted_4_channel_modes: 135 | raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) 136 | 137 | if mode is None and npimg.dtype == np.uint8: 138 | mode = 'RGBA' 139 | else: 140 | permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] 141 | if mode is not None and mode not in permitted_3_channel_modes: 142 | raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) 143 | if mode is None and npimg.dtype == np.uint8: 144 | mode = 'RGB' 145 | 146 | if mode is None: 147 | raise TypeError('Input type {} is not supported'.format(npimg.dtype)) 148 | 149 | return Image.fromarray(npimg, mode=mode) 150 | 151 | 152 | def normalize(tensor, mean, std): 153 | """Normalize a tensor image with mean and standard deviation. 154 | 155 | See ``Normalize`` for more details. 156 | 157 | Args: 158 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 159 | mean (sequence): Sequence of means for each channel. 160 | std (sequence): Sequence of standard deviations for each channely. 161 | 162 | Returns: 163 | Tensor: Normalized Tensor image. 164 | """ 165 | if not _is_tensor_image(tensor): 166 | raise TypeError('tensor is not a torch image.') 167 | # TODO: make efficient 168 | for t, m, s in zip(tensor, mean, std): 169 | t.sub_(m).div_(s) 170 | return tensor 171 | 172 | 173 | def resize(img, size, interpolation=Image.BILINEAR): 174 | """Resize the input PIL Image to the given size. 175 | 176 | Args: 177 | img (PIL Image): Image to be resized. 178 | size (sequence or int): Desired output size. If size is a sequence like 179 | (h, w), the output size will be matched to this. If size is an int, 180 | the smaller edge of the image will be matched to this number maintaing 181 | the aspect ratio. i.e, if height > width, then image will be rescaled to 182 | (size * height / width, size) 183 | interpolation (int, optional): Desired interpolation. Default is 184 | ``PIL.Image.BILINEAR`` 185 | 186 | Returns: 187 | PIL Image: Resized image. 188 | """ 189 | if not _is_pil_image(img): 190 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 191 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)): 192 | raise TypeError('Got inappropriate size arg: {}'.format(size)) 193 | 194 | if isinstance(size, int): 195 | w, h = img.size 196 | if (w <= h and w == size) or (h <= w and h == size): 197 | return img 198 | if w < h: 199 | ow = size 200 | oh = int(size * h / w) 201 | return img.resize((ow, oh), interpolation) 202 | else: 203 | oh = size 204 | ow = int(size * w / h) 205 | return img.resize((ow, oh), interpolation) 206 | else: 207 | return img.resize(size[::-1], interpolation) 208 | 209 | 210 | def scale(*args, **kwargs): 211 | warnings.warn("The use of the transforms.Scale transform is deprecated, " + 212 | "please use transforms.Resize instead.") 213 | return resize(*args, **kwargs) 214 | 215 | 216 | def pad(img, padding, fill=0, padding_mode='constant'): 217 | """Pad the given PIL Image on all sides with speficified padding mode and fill value. 218 | 219 | Args: 220 | img (PIL Image): Image to be padded. 221 | padding (int or tuple): Padding on each border. If a single int is provided this 222 | is used to pad all borders. 
If tuple of length 2 is provided this is the padding 223 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 224 | this is the padding for the left, top, right and bottom borders 225 | respectively. 226 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 227 | length 3, it is used to fill R, G, B channels respectively. 228 | This value is only used when the padding_mode is constant 229 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 230 | constant: pads with a constant value, this value is specified with fill 231 | edge: pads with the last value on the edge of the image 232 | reflect: pads with reflection of image (without repeating the last value on the edge) 233 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 234 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 235 | symmetric: pads with reflection of image (repeating the last value on the edge) 236 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 237 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 238 | 239 | Returns: 240 | PIL Image: Padded image. 241 | """ 242 | if not _is_pil_image(img): 243 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 244 | 245 | if not isinstance(padding, (numbers.Number, tuple)): 246 | raise TypeError('Got inappropriate padding arg') 247 | if not isinstance(fill, (numbers.Number, str, tuple)): 248 | raise TypeError('Got inappropriate fill arg') 249 | if not isinstance(padding_mode, str): 250 | raise TypeError('Got inappropriate padding_mode arg') 251 | 252 | if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: 253 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 254 | "{} element tuple".format(len(padding))) 255 | 256 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ 257 | 'Padding mode should be either constant, edge, reflect or symmetric' 258 | 259 | if padding_mode == 'constant': 260 | return ImageOps.expand(img, border=padding, fill=fill) 261 | else: 262 | if isinstance(padding, int): 263 | pad_left = pad_right = pad_top = pad_bottom = padding 264 | if isinstance(padding, collections.Sequence) and len(padding) == 2: 265 | pad_left = pad_right = padding[0] 266 | pad_top = pad_bottom = padding[1] 267 | if isinstance(padding, collections.Sequence) and len(padding) == 4: 268 | pad_left = padding[0] 269 | pad_top = padding[1] 270 | pad_right = padding[2] 271 | pad_bottom = padding[3] 272 | 273 | img = np.asarray(img) 274 | # RGB image 275 | if len(img.shape) == 3: 276 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode) 277 | # Grayscale image 278 | if len(img.shape) == 2: 279 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) 280 | 281 | return Image.fromarray(img) 282 | 283 | 284 | def crop(img, i, j, h, w): 285 | """Crop the given PIL Image. 286 | 287 | Args: 288 | img (PIL Image): Image to be cropped. 289 | i: Upper pixel coordinate. 290 | j: Left pixel coordinate. 291 | h: Height of the cropped image. 292 | w: Width of the cropped image. 293 | 294 | Returns: 295 | PIL Image: Cropped image. 296 | """ 297 | if not _is_pil_image(img): 298 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 299 | 300 | return img.crop((j, i, j + w, i + h)) 301 | 302 | 303 | def center_crop(img, output_size): 304 | if isinstance(output_size, numbers.Number): 305 | output_size = (int(output_size), int(output_size)) 306 | w, h = img.size 307 | th, tw = output_size 308 | i = int(round((h - th) / 2.)) 309 | j = int(round((w - tw) / 2.)) 310 | return crop(img, i, j, th, tw) 311 | 312 | 313 | def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): 314 | """Crop the given PIL Image and resize it to desired size. 315 | 316 | Notably used in RandomResizedCrop. 317 | 318 | Args: 319 | img (PIL Image): Image to be cropped. 320 | i: Upper pixel coordinate. 321 | j: Left pixel coordinate. 322 | h: Height of the cropped image. 323 | w: Width of the cropped image. 324 | size (sequence or int): Desired output size. Same semantics as ``scale``. 325 | interpolation (int, optional): Desired interpolation. Default is 326 | ``PIL.Image.BILINEAR``. 327 | Returns: 328 | PIL Image: Cropped image. 329 | """ 330 | assert _is_pil_image(img), 'img should be PIL Image' 331 | img = crop(img, i, j, h, w) 332 | img = resize(img, size, interpolation) 333 | return img 334 | 335 | 336 | def hflip(img): 337 | """Horizontally flip the given PIL Image. 338 | 339 | Args: 340 | img (PIL Image): Image to be flipped. 341 | 342 | Returns: 343 | PIL Image: Horizontall flipped image. 344 | """ 345 | if not _is_pil_image(img): 346 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 347 | 348 | return img.transpose(Image.FLIP_LEFT_RIGHT) 349 | 350 | 351 | def vflip(img): 352 | """Vertically flip the given PIL Image. 353 | 354 | Args: 355 | img (PIL Image): Image to be flipped. 356 | 357 | Returns: 358 | PIL Image: Vertically flipped image. 359 | """ 360 | if not _is_pil_image(img): 361 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 362 | 363 | return img.transpose(Image.FLIP_TOP_BOTTOM) 364 | 365 | 366 | def swap(img, crop): 367 | def crop_image(image, cropnum): 368 | width, high = image.size 369 | crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)] 370 | crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)] 371 | im_list = [] 372 | for j in range(len(crop_y) - 1): 373 | for i in range(len(crop_x) - 1): 374 | im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high)))) 375 | return im_list 376 | 377 | widthcut, highcut = img.size 378 | img = img.crop((10, 10, widthcut - 10, highcut - 10)) 379 | images = crop_image(img, crop) 380 | pro = 5 381 | if pro >= 5: 382 | tmpx = [] 383 | tmpy = [] 384 | count_x = 0 385 | count_y = 0 386 | k = 1 387 | RAN = 2 388 | for i in range(crop[1] * crop[0]): 389 | tmpx.append(images[i]) 390 | count_x += 1 391 | if len(tmpx) >= k: 392 | tmp = tmpx[count_x - RAN:count_x] 393 | random.shuffle(tmp) 394 | tmpx[count_x - RAN:count_x] = tmp 395 | if count_x == crop[0]: 396 | tmpy.append(tmpx) 397 | count_x = 0 398 | count_y += 1 399 | tmpx = [] 400 | if len(tmpy) >= k: 401 | tmp2 = tmpy[count_y - RAN:count_y] 402 | random.shuffle(tmp2) 403 | tmpy[count_y - RAN:count_y] = tmp2 404 | random_im = [] 405 | for line in tmpy: 406 | random_im.extend(line) 407 | 408 | # random.shuffle(images) 409 | width, high = img.size 410 | iw = int(width / crop[0]) 411 | ih = int(high / crop[1]) 412 | toImage = Image.new('RGB', (iw * crop[0], ih * crop[1])) 413 | x = 0 414 | y = 0 415 | for i in random_im: 416 | i = i.resize((iw, ih), Image.ANTIALIAS) 417 | toImage.paste(i, (x * iw, y * ih)) 418 | x += 1 419 | if x == crop[0]: 420 | x = 0 421 | y += 1 422 | else: 423 | toImage = img 424 | toImage = toImage.resize((widthcut, highcut)) 425 | return toImage 426 | 427 | 428 | def five_crop(img, size): 429 | """Crop the given PIL Image into four corners and the central crop. 430 | 431 | .. Note:: 432 | This transform returns a tuple of images and there may be a 433 | mismatch in the number of inputs and targets your ``Dataset`` returns. 434 | 435 | Args: 436 | size (sequence or int): Desired output size of the crop. If size is an 437 | int instead of sequence like (h, w), a square crop (size, size) is 438 | made. 439 | Returns: 440 | tuple: tuple (tl, tr, bl, br, center) corresponding top left, 441 | top right, bottom left, bottom right and center crop. 442 | """ 443 | if isinstance(size, numbers.Number): 444 | size = (int(size), int(size)) 445 | else: 446 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 447 | 448 | w, h = img.size 449 | crop_h, crop_w = size 450 | if crop_w > w or crop_h > h: 451 | raise ValueError("Requested crop size {} is bigger than input size {}".format(size, 452 | (h, w))) 453 | tl = img.crop((0, 0, crop_w, crop_h)) 454 | tr = img.crop((w - crop_w, 0, w, crop_h)) 455 | bl = img.crop((0, h - crop_h, crop_w, h)) 456 | br = img.crop((w - crop_w, h - crop_h, w, h)) 457 | center = center_crop(img, (crop_h, crop_w)) 458 | return (tl, tr, bl, br, center) 459 | 460 | 461 | def ten_crop(img, size, vertical_flip=False): 462 | """Crop the given PIL Image into four corners and the central crop plus the 463 | flipped version of these (horizontal flipping is used by default). 464 | 465 | .. Note:: 466 | This transform returns a tuple of images and there may be a 467 | mismatch in the number of inputs and targets your ``Dataset`` returns. 
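        A common way to handle this is to stack the ten returned crops into a single batch,
        run the model on that batch, and average its outputs over the crops.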
468 | 469 | Args: 470 | size (sequence or int): Desired output size of the crop. If size is an 471 | int instead of sequence like (h, w), a square crop (size, size) is 472 | made. 473 | vertical_flip (bool): Use vertical flipping instead of horizontal 474 | 475 | Returns: 476 | tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, 477 | br_flip, center_flip) corresponding top left, top right, 478 | bottom left, bottom right and center crop and same for the 479 | flipped image. 480 | """ 481 | if isinstance(size, numbers.Number): 482 | size = (int(size), int(size)) 483 | else: 484 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 485 | 486 | first_five = five_crop(img, size) 487 | 488 | if vertical_flip: 489 | img = vflip(img) 490 | else: 491 | img = hflip(img) 492 | 493 | second_five = five_crop(img, size) 494 | return first_five + second_five 495 | 496 | 497 | def adjust_brightness(img, brightness_factor): 498 | """Adjust brightness of an Image. 499 | 500 | Args: 501 | img (PIL Image): PIL Image to be adjusted. 502 | brightness_factor (float): How much to adjust the brightness. Can be 503 | any non negative number. 0 gives a black image, 1 gives the 504 | original image while 2 increases the brightness by a factor of 2. 505 | 506 | Returns: 507 | PIL Image: Brightness adjusted image. 508 | """ 509 | if not _is_pil_image(img): 510 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 511 | 512 | enhancer = ImageEnhance.Brightness(img) 513 | img = enhancer.enhance(brightness_factor) 514 | return img 515 | 516 | 517 | def adjust_contrast(img, contrast_factor): 518 | """Adjust contrast of an Image. 519 | 520 | Args: 521 | img (PIL Image): PIL Image to be adjusted. 522 | contrast_factor (float): How much to adjust the contrast. Can be any 523 | non negative number. 0 gives a solid gray image, 1 gives the 524 | original image while 2 increases the contrast by a factor of 2. 525 | 526 | Returns: 527 | PIL Image: Contrast adjusted image. 528 | """ 529 | if not _is_pil_image(img): 530 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 531 | 532 | enhancer = ImageEnhance.Contrast(img) 533 | img = enhancer.enhance(contrast_factor) 534 | return img 535 | 536 | 537 | def adjust_saturation(img, saturation_factor): 538 | """Adjust color saturation of an image. 539 | 540 | Args: 541 | img (PIL Image): PIL Image to be adjusted. 542 | saturation_factor (float): How much to adjust the saturation. 0 will 543 | give a black and white image, 1 will give the original image while 544 | 2 will enhance the saturation by a factor of 2. 545 | 546 | Returns: 547 | PIL Image: Saturation adjusted image. 548 | """ 549 | if not _is_pil_image(img): 550 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 551 | 552 | enhancer = ImageEnhance.Color(img) 553 | img = enhancer.enhance(saturation_factor) 554 | return img 555 | 556 | 557 | def adjust_hue(img, hue_factor): 558 | """Adjust hue of an image. 559 | 560 | The image hue is adjusted by converting the image to HSV and 561 | cyclically shifting the intensities in the hue channel (H). 562 | The image is then converted back to original image mode. 563 | 564 | `hue_factor` is the amount of shift in H channel and must be in the 565 | interval `[-0.5, 0.5]`. 566 | 567 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 568 | 569 | Args: 570 | img (PIL Image): PIL Image to be adjusted. 571 | hue_factor (float): How much to shift the hue channel. 
Should be in 572 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 573 | HSV space in positive and negative direction respectively. 574 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 575 | with complementary colors while 0 gives the original image. 576 | 577 | Returns: 578 | PIL Image: Hue adjusted image. 579 | """ 580 | if not (-0.5 <= hue_factor <= 0.5): 581 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 582 | 583 | if not _is_pil_image(img): 584 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 585 | 586 | input_mode = img.mode 587 | if input_mode in {'L', '1', 'I', 'F'}: 588 | return img 589 | 590 | h, s, v = img.convert('HSV').split() 591 | 592 | np_h = np.array(h, dtype=np.uint8) 593 | # uint8 addition take cares of rotation across boundaries 594 | with np.errstate(over='ignore'): 595 | np_h += np.uint8(hue_factor * 255) 596 | h = Image.fromarray(np_h, 'L') 597 | 598 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 599 | return img 600 | 601 | 602 | def adjust_gamma(img, gamma, gain=1): 603 | """Perform gamma correction on an image. 604 | 605 | Also known as Power Law Transform. Intensities in RGB mode are adjusted 606 | based on the following equation: 607 | 608 | I_out = 255 * gain * ((I_in / 255) ** gamma) 609 | 610 | See https://en.wikipedia.org/wiki/Gamma_correction for more details. 611 | 612 | Args: 613 | img (PIL Image): PIL Image to be adjusted. 614 | gamma (float): Non negative real number. gamma larger than 1 make the 615 | shadows darker, while gamma smaller than 1 make dark regions 616 | lighter. 617 | gain (float): The constant multiplier. 618 | """ 619 | if not _is_pil_image(img): 620 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 621 | 622 | if gamma < 0: 623 | raise ValueError('Gamma should be a non-negative real number') 624 | 625 | input_mode = img.mode 626 | img = img.convert('RGB') 627 | 628 | gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 629 | img = img.point(gamma_map) # use PIL's point-function to accelerate this part 630 | 631 | img = img.convert(input_mode) 632 | return img 633 | 634 | 635 | def rotate(img, angle, resample=False, expand=False, center=None): 636 | """Rotate the image by angle. 637 | 638 | 639 | Args: 640 | img (PIL Image): PIL Image to be rotated. 641 | angle ({float, int}): In degrees degrees counter clockwise order. 642 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 643 | An optional resampling filter. 644 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 645 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 646 | expand (bool, optional): Optional expansion flag. 647 | If true, expands the output image to make it large enough to hold the entire rotated image. 648 | If false or omitted, make the output image the same size as the input image. 649 | Note that the expand flag assumes rotation around the center and no translation. 650 | center (2-tuple, optional): Optional center of rotation. 651 | Origin is the upper left corner. 652 | Default is the center of the image. 653 | """ 654 | 655 | if not _is_pil_image(img): 656 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 657 | 658 | return img.rotate(angle, resample, expand, center) 659 | 660 | 661 | def _get_inverse_affine_matrix(center, angle, translate, scale, shear): 662 | # Helper method to compute inverse matrix for affine transformation 663 | 664 | # As it is explained in PIL.Image.rotate 665 | # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1 666 | # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] 667 | # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] 668 | # RSS is rotation with scale and shear matrix 669 | # RSS(a, scale, shear) = [ cos(a)*scale -sin(a + shear)*scale 0] 670 | # [ sin(a)*scale cos(a + shear)*scale 0] 671 | # [ 0 0 1] 672 | # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1 673 | 674 | angle = math.radians(angle) 675 | shear = math.radians(shear) 676 | scale = 1.0 / scale 677 | 678 | # Inverted rotation matrix with scale and shear 679 | d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle) 680 | matrix = [ 681 | math.cos(angle + shear), math.sin(angle + shear), 0, 682 | -math.sin(angle), math.cos(angle), 0 683 | ] 684 | matrix = [scale / d * m for m in matrix] 685 | 686 | # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 687 | matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1]) 688 | matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1]) 689 | 690 | # Apply center translation: C * RSS^-1 * C^-1 * T^-1 691 | matrix[2] += center[0] 692 | matrix[5] += center[1] 693 | return matrix 694 | 695 | 696 | def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None): 697 | """Apply affine transformation on the image keeping image center invariant 698 | 699 | Args: 700 | img (PIL Image): PIL Image to be rotated. 701 | angle ({float, int}): rotation angle in degrees between -180 and 180, clockwise direction. 702 | translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation) 703 | scale (float): overall scale 704 | shear (float): shear angle value in degrees between -180 to 180, clockwise direction. 705 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 706 | An optional resampling filter. 707 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 708 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 709 | fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) 710 | """ 711 | if not _is_pil_image(img): 712 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 713 | 714 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 715 | "Argument translate should be a list or tuple of length 2" 716 | 717 | assert scale > 0.0, "Argument scale should be positive" 718 | 719 | output_size = img.size 720 | center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5) 721 | matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) 722 | kwargs = {"fillcolor": fillcolor} if PILLOW_VERSION[0] == '5' else {} 723 | return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs) 724 | 725 | 726 | def to_grayscale(img, num_output_channels=1): 727 | """Convert image to grayscale version of image. 728 | 729 | Args: 730 | img (PIL Image): Image to be converted to grayscale. 
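        num_output_channels (int): number of channels of the output image; must be 1 or 3. Default is 1.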
731 | 732 | Returns: 733 | PIL Image: Grayscale version of the image. 734 | if num_output_channels == 1 : returned image is single channel 735 | if num_output_channels == 3 : returned image is 3 channel with r == g == b 736 | """ 737 | if not _is_pil_image(img): 738 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 739 | 740 | if num_output_channels == 1: 741 | img = img.convert('L') 742 | elif num_output_channels == 3: 743 | img = img.convert('L') 744 | np_img = np.array(img, dtype=np.uint8) 745 | np_img = np.dstack([np_img, np_img, np_img]) 746 | img = Image.fromarray(np_img, 'RGB') 747 | else: 748 | raise ValueError('num_output_channels should be either 1 or 3') 749 | 750 | return img 751 | -------------------------------------------------------------------------------- /src/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps, ImageEnhance 6 | try: 7 | import accimage 8 | except ImportError: 9 | accimage = None 10 | import numpy as np 11 | import numbers 12 | import types 13 | import collections 14 | import warnings 15 | 16 | from . import functional as F 17 | 18 | __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", 19 | "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", 20 | "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", 21 | "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "Randomswap"] 22 | 23 | _pil_interpolation_to_str = { 24 | Image.NEAREST: 'PIL.Image.NEAREST', 25 | Image.BILINEAR: 'PIL.Image.BILINEAR', 26 | Image.BICUBIC: 'PIL.Image.BICUBIC', 27 | Image.LANCZOS: 'PIL.Image.LANCZOS', 28 | } 29 | 30 | 31 | class Compose(object): 32 | """Composes several transforms together. 33 | 34 | Args: 35 | transforms (list of ``Transform`` objects): list of transforms to compose. 36 | 37 | Example: 38 | >>> transforms.Compose([ 39 | >>> transforms.CenterCrop(10), 40 | >>> transforms.ToTensor(), 41 | >>> ]) 42 | """ 43 | 44 | def __init__(self, transforms): 45 | self.transforms = transforms 46 | 47 | def __call__(self, img): 48 | for t in self.transforms: 49 | img = t(img) 50 | return img 51 | 52 | def __repr__(self): 53 | format_string = self.__class__.__name__ + '(' 54 | for t in self.transforms: 55 | format_string += '\n' 56 | format_string += ' {0}'.format(t) 57 | format_string += '\n)' 58 | return format_string 59 | 60 | 61 | class ToTensor(object): 62 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 63 | 64 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 65 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 66 | """ 67 | 68 | def __call__(self, pic): 69 | """ 70 | Args: 71 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 72 | 73 | Returns: 74 | Tensor: Converted image. 75 | """ 76 | return F.to_tensor(pic) 77 | 78 | def __repr__(self): 79 | return self.__class__.__name__ + '()' 80 | 81 | 82 | class ToPILImage(object): 83 | """Convert a tensor or an ndarray to PIL Image. 84 | 85 | Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape 86 | H x W x C to a PIL Image while preserving the value range. 87 | 88 | Args: 89 | mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). 
90 | If ``mode`` is ``None`` (default) there are some assumptions made about the input data: 91 | 1. If the input has 3 channels, the ``mode`` is assumed to be ``RGB``. 92 | 2. If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``. 93 | 3. If the input has 1 channel, the ``mode`` is determined by the data type (i,e, 94 | ``int``, ``float``, ``short``). 95 | 96 | .. _PIL.Image mode: http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#modes 97 | """ 98 | def __init__(self, mode=None): 99 | self.mode = mode 100 | 101 | def __call__(self, pic): 102 | """ 103 | Args: 104 | pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. 105 | 106 | Returns: 107 | PIL Image: Image converted to PIL Image. 108 | 109 | """ 110 | return F.to_pil_image(pic, self.mode) 111 | 112 | def __repr__(self): 113 | format_string = self.__class__.__name__ + '(' 114 | if self.mode is not None: 115 | format_string += 'mode={0}'.format(self.mode) 116 | format_string += ')' 117 | return format_string 118 | 119 | 120 | class Normalize(object): 121 | """Normalize a tensor image with mean and standard deviation. 122 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 123 | will normalize each channel of the input ``torch.*Tensor`` i.e. 124 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 125 | 126 | Args: 127 | mean (sequence): Sequence of means for each channel. 128 | std (sequence): Sequence of standard deviations for each channel. 129 | """ 130 | 131 | def __init__(self, mean, std): 132 | self.mean = mean 133 | self.std = std 134 | 135 | def __call__(self, tensor): 136 | """ 137 | Args: 138 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 139 | 140 | Returns: 141 | Tensor: Normalized Tensor image. 142 | """ 143 | return F.normalize(tensor, self.mean, self.std) 144 | 145 | def __repr__(self): 146 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 147 | 148 | 149 | class Randomswap(object): 150 | def __init__(self, size): 151 | self.size = size 152 | if isinstance(size, numbers.Number): 153 | self.size = (int(size), int(size)) 154 | else: 155 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 156 | self.size = size 157 | 158 | def __call__(self, img): 159 | return F.swap(img, self.size) 160 | 161 | def __repr__(self): 162 | return self.__class__.__name__ + '(size={0})'.format(self.size) 163 | 164 | 165 | class Resize(object): 166 | """Resize the input PIL Image to the given size. 167 | 168 | Args: 169 | size (sequence or int): Desired output size. If size is a sequence like 170 | (h, w), output size will be matched to this. If size is an int, 171 | smaller edge of the image will be matched to this number. 172 | i.e, if height > width, then image will be rescaled to 173 | (size * height / width, size) 174 | interpolation (int, optional): Desired interpolation. Default is 175 | ``PIL.Image.BILINEAR`` 176 | """ 177 | 178 | def __init__(self, size, interpolation=Image.BILINEAR): 179 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 180 | self.size = size 181 | self.interpolation = interpolation 182 | 183 | def __call__(self, img): 184 | """ 185 | Args: 186 | img (PIL Image): Image to be scaled. 187 | 188 | Returns: 189 | PIL Image: Rescaled image. 
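        Example (illustrative values):
            >>> # match the shorter edge to 256 px, keeping the aspect ratio
            >>> resized_img = transforms.Resize(256)(img)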
190 | """ 191 | return F.resize(img, self.size, self.interpolation) 192 | 193 | def __repr__(self): 194 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 195 | return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str) 196 | 197 | 198 | class Scale(Resize): 199 | """ 200 | Note: This transform is deprecated in favor of Resize. 201 | """ 202 | def __init__(self, *args, **kwargs): 203 | warnings.warn("The use of the transforms.Scale transform is deprecated, " + 204 | "please use transforms.Resize instead.") 205 | super(Scale, self).__init__(*args, **kwargs) 206 | 207 | 208 | class CenterCrop(object): 209 | """Crops the given PIL Image at the center. 210 | 211 | Args: 212 | size (sequence or int): Desired output size of the crop. If size is an 213 | int instead of sequence like (h, w), a square crop (size, size) is 214 | made. 215 | """ 216 | 217 | def __init__(self, size): 218 | if isinstance(size, numbers.Number): 219 | self.size = (int(size), int(size)) 220 | else: 221 | self.size = size 222 | 223 | def __call__(self, img): 224 | """ 225 | Args: 226 | img (PIL Image): Image to be cropped. 227 | 228 | Returns: 229 | PIL Image: Cropped image. 230 | """ 231 | return F.center_crop(img, self.size) 232 | 233 | def __repr__(self): 234 | return self.__class__.__name__ + '(size={0})'.format(self.size) 235 | 236 | 237 | class Pad(object): 238 | """Pad the given PIL Image on all sides with the given "pad" value. 239 | 240 | Args: 241 | padding (int or tuple): Padding on each border. If a single int is provided this 242 | is used to pad all borders. If tuple of length 2 is provided this is the padding 243 | on left/right and top/bottom respectively. If a tuple of length 4 is provided 244 | this is the padding for the left, top, right and bottom borders 245 | respectively. 246 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 247 | length 3, it is used to fill R, G, B channels respectively. 248 | This value is only used when the padding_mode is constant 249 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 250 | constant: pads with a constant value, this value is specified with fill 251 | edge: pads with the last value at the edge of the image 252 | reflect: pads with reflection of image (without repeating the last value on the edge) 253 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 254 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 255 | symmetric: pads with reflection of image (repeating the last value on the edge) 256 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 257 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 258 | """ 259 | 260 | def __init__(self, padding, fill=0, padding_mode='constant'): 261 | assert isinstance(padding, (numbers.Number, tuple)) 262 | assert isinstance(fill, (numbers.Number, str, tuple)) 263 | assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] 264 | if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]: 265 | raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + 266 | "{} element tuple".format(len(padding))) 267 | 268 | self.padding = padding 269 | self.fill = fill 270 | self.padding_mode = padding_mode 271 | 272 | def __call__(self, img): 273 | """ 274 | Args: 275 | img (PIL Image): Image to be padded. 276 | 277 | Returns: 278 | PIL Image: Padded image. 
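        Example (illustrative values):
            >>> # add a 4-pixel reflected border on every side
            >>> padded_img = transforms.Pad(4, padding_mode='reflect')(img)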
279 | """ 280 | return F.pad(img, self.padding, self.fill, self.padding_mode) 281 | 282 | def __repr__(self): 283 | return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ 284 | format(self.padding, self.fill, self.padding_mode) 285 | 286 | 287 | class Lambda(object): 288 | """Apply a user-defined lambda as a transform. 289 | 290 | Args: 291 | lambd (function): Lambda/function to be used for transform. 292 | """ 293 | 294 | def __init__(self, lambd): 295 | assert isinstance(lambd, types.LambdaType) 296 | self.lambd = lambd 297 | 298 | def __call__(self, img): 299 | return self.lambd(img) 300 | 301 | def __repr__(self): 302 | return self.__class__.__name__ + '()' 303 | 304 | 305 | class RandomTransforms(object): 306 | """Base class for a list of transformations with randomness 307 | 308 | Args: 309 | transforms (list or tuple): list of transformations 310 | """ 311 | 312 | def __init__(self, transforms): 313 | assert isinstance(transforms, (list, tuple)) 314 | self.transforms = transforms 315 | 316 | def __call__(self, *args, **kwargs): 317 | raise NotImplementedError() 318 | 319 | def __repr__(self): 320 | format_string = self.__class__.__name__ + '(' 321 | for t in self.transforms: 322 | format_string += '\n' 323 | format_string += ' {0}'.format(t) 324 | format_string += '\n)' 325 | return format_string 326 | 327 | 328 | class RandomApply(RandomTransforms): 329 | """Apply randomly a list of transformations with a given probability 330 | 331 | Args: 332 | transforms (list or tuple): list of transformations 333 | p (float): probability 334 | """ 335 | 336 | def __init__(self, transforms, p=0.5): 337 | super(RandomApply, self).__init__(transforms) 338 | self.p = p 339 | 340 | def __call__(self, img): 341 | if self.p < random.random(): 342 | return img 343 | for t in self.transforms: 344 | img = t(img) 345 | return img 346 | 347 | def __repr__(self): 348 | format_string = self.__class__.__name__ + '(' 349 | format_string += '\n p={}'.format(self.p) 350 | for t in self.transforms: 351 | format_string += '\n' 352 | format_string += ' {0}'.format(t) 353 | format_string += '\n)' 354 | return format_string 355 | 356 | 357 | class RandomOrder(RandomTransforms): 358 | """Apply a list of transformations in a random order 359 | """ 360 | def __call__(self, img): 361 | order = list(range(len(self.transforms))) 362 | random.shuffle(order) 363 | for i in order: 364 | img = self.transforms[i](img) 365 | return img 366 | 367 | 368 | class RandomChoice(RandomTransforms): 369 | """Apply single transformation randomly picked from a list 370 | """ 371 | def __call__(self, img): 372 | t = random.choice(self.transforms) 373 | return t(img) 374 | 375 | 376 | class RandomCrop(object): 377 | """Crop the given PIL Image at a random location. 378 | 379 | Args: 380 | size (sequence or int): Desired output size of the crop. If size is an 381 | int instead of sequence like (h, w), a square crop (size, size) is 382 | made. 383 | padding (int or sequence, optional): Optional padding on each border 384 | of the image. Default is 0, i.e no padding. If a sequence of length 385 | 4 is provided, it is used to pad left, top, right, bottom borders 386 | respectively. 387 | pad_if_needed (boolean): It will pad the image if smaller than the 388 | desired size to avoid raising an exception. 
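    Example (illustrative values):
        >>> # pad each border by 10 px, then take a random 224x224 crop
        >>> crop = transforms.RandomCrop(224, padding=10)
        >>> cropped_img = crop(img)  # img is a PIL Image at least 224x224 after padding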
389 | """ 390 | 391 | def __init__(self, size, padding=0, pad_if_needed=False): 392 | if isinstance(size, numbers.Number): 393 | self.size = (int(size), int(size)) 394 | else: 395 | self.size = size 396 | self.padding = padding 397 | self.pad_if_needed = pad_if_needed 398 | 399 | @staticmethod 400 | def get_params(img, output_size): 401 | """Get parameters for ``crop`` for a random crop. 402 | 403 | Args: 404 | img (PIL Image): Image to be cropped. 405 | output_size (tuple): Expected output size of the crop. 406 | 407 | Returns: 408 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 409 | """ 410 | w, h = img.size 411 | th, tw = output_size 412 | if w == tw and h == th: 413 | return 0, 0, h, w 414 | 415 | i = random.randint(0, h - th) 416 | j = random.randint(0, w - tw) 417 | return i, j, th, tw 418 | 419 | def __call__(self, img): 420 | """ 421 | Args: 422 | img (PIL Image): Image to be cropped. 423 | 424 | Returns: 425 | PIL Image: Cropped image. 426 | """ 427 | if self.padding > 0: 428 | img = F.pad(img, self.padding) 429 | 430 | # pad the width if needed 431 | if self.pad_if_needed and img.size[0] < self.size[1]: 432 | img = F.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0)) 433 | # pad the height if needed 434 | if self.pad_if_needed and img.size[1] < self.size[0]: 435 | img = F.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2))) 436 | 437 | i, j, h, w = self.get_params(img, self.size) 438 | 439 | return F.crop(img, i, j, h, w) 440 | 441 | def __repr__(self): 442 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) 443 | 444 | 445 | class RandomHorizontalFlip(object): 446 | """Horizontally flip the given PIL Image randomly with a given probability. 447 | 448 | Args: 449 | p (float): probability of the image being flipped. Default value is 0.5 450 | """ 451 | 452 | def __init__(self, p=0.5): 453 | self.p = p 454 | 455 | def __call__(self, img): 456 | """ 457 | Args: 458 | img (PIL Image): Image to be flipped. 459 | 460 | Returns: 461 | PIL Image: Randomly flipped image. 462 | """ 463 | if random.random() < self.p: 464 | return F.hflip(img) 465 | return img 466 | 467 | def __repr__(self): 468 | return self.__class__.__name__ + '(p={})'.format(self.p) 469 | 470 | 471 | class RandomVerticalFlip(object): 472 | """Vertically flip the given PIL Image randomly with a given probability. 473 | 474 | Args: 475 | p (float): probability of the image being flipped. Default value is 0.5 476 | """ 477 | 478 | def __init__(self, p=0.5): 479 | self.p = p 480 | 481 | def __call__(self, img): 482 | """ 483 | Args: 484 | img (PIL Image): Image to be flipped. 485 | 486 | Returns: 487 | PIL Image: Randomly flipped image. 488 | """ 489 | if random.random() < self.p: 490 | return F.vflip(img) 491 | return img 492 | 493 | def __repr__(self): 494 | return self.__class__.__name__ + '(p={})'.format(self.p) 495 | 496 | 497 | class RandomResizedCrop(object): 498 | """Crop the given PIL Image to random size and aspect ratio. 499 | 500 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 501 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 502 | is finally resized to given size. 503 | This is popularly used to train the Inception networks. 
504 | 505 | Args: 506 | size: expected output size of each edge 507 | scale: range of size of the origin size cropped 508 | ratio: range of aspect ratio of the origin aspect ratio cropped 509 | interpolation: Default: PIL.Image.BILINEAR 510 | """ 511 | 512 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 513 | self.size = (size, size) 514 | self.interpolation = interpolation 515 | self.scale = scale 516 | self.ratio = ratio 517 | 518 | @staticmethod 519 | def get_params(img, scale, ratio): 520 | """Get parameters for ``crop`` for a random sized crop. 521 | 522 | Args: 523 | img (PIL Image): Image to be cropped. 524 | scale (tuple): range of size of the origin size cropped 525 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 526 | 527 | Returns: 528 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 529 | sized crop. 530 | """ 531 | for attempt in range(10): 532 | area = img.size[0] * img.size[1] 533 | target_area = random.uniform(*scale) * area 534 | aspect_ratio = random.uniform(*ratio) 535 | 536 | w = int(round(math.sqrt(target_area * aspect_ratio))) 537 | h = int(round(math.sqrt(target_area / aspect_ratio))) 538 | 539 | if random.random() < 0.5: 540 | w, h = h, w 541 | 542 | if w <= img.size[0] and h <= img.size[1]: 543 | i = random.randint(0, img.size[1] - h) 544 | j = random.randint(0, img.size[0] - w) 545 | return i, j, h, w 546 | 547 | # Fallback 548 | w = min(img.size[0], img.size[1]) 549 | i = (img.size[1] - w) // 2 550 | j = (img.size[0] - w) // 2 551 | return i, j, w, w 552 | 553 | def __call__(self, img): 554 | """ 555 | Args: 556 | img (PIL Image): Image to be cropped and resized. 557 | 558 | Returns: 559 | PIL Image: Randomly cropped and resized image. 560 | """ 561 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 562 | return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) 563 | 564 | def __repr__(self): 565 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 566 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 567 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 568 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 569 | format_string += ', interpolation={0})'.format(interpolate_str) 570 | return format_string 571 | 572 | 573 | class RandomSizedCrop(RandomResizedCrop): 574 | """ 575 | Note: This transform is deprecated in favor of RandomResizedCrop. 576 | """ 577 | def __init__(self, *args, **kwargs): 578 | warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + 579 | "please use transforms.RandomResizedCrop instead.") 580 | super(RandomSizedCrop, self).__init__(*args, **kwargs) 581 | 582 | 583 | class FiveCrop(object): 584 | """Crop the given PIL Image into four corners and the central crop 585 | 586 | .. Note:: 587 | This transform returns a tuple of images and there may be a mismatch in the number of 588 | inputs and targets your Dataset returns. See below for an example of how to deal with 589 | this. 590 | 591 | Args: 592 | size (sequence or int): Desired output size of the crop. If size is an ``int`` 593 | instead of sequence like (h, w), a square crop of size (size, size) is made. 
594 | 595 | Example: 596 | >>> transform = Compose([ 597 | >>> FiveCrop(size), # this is a list of PIL Images 598 | >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor 599 | >>> ]) 600 | >>> #In your test loop you can do the following: 601 | >>> input, target = batch # input is a 5d tensor, target is 2d 602 | >>> bs, ncrops, c, h, w = input.size() 603 | >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops 604 | >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops 605 | """ 606 | 607 | def __init__(self, size): 608 | self.size = size 609 | if isinstance(size, numbers.Number): 610 | self.size = (int(size), int(size)) 611 | else: 612 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 613 | self.size = size 614 | 615 | def __call__(self, img): 616 | return F.five_crop(img, self.size) 617 | 618 | def __repr__(self): 619 | return self.__class__.__name__ + '(size={0})'.format(self.size) 620 | 621 | 622 | class TenCrop(object): 623 | """Crop the given PIL Image into four corners and the central crop plus the flipped version of 624 | these (horizontal flipping is used by default) 625 | 626 | .. Note:: 627 | This transform returns a tuple of images and there may be a mismatch in the number of 628 | inputs and targets your Dataset returns. See below for an example of how to deal with 629 | this. 630 | 631 | Args: 632 | size (sequence or int): Desired output size of the crop. If size is an 633 | int instead of sequence like (h, w), a square crop (size, size) is 634 | made. 635 | vertical_flip(bool): Use vertical flipping instead of horizontal 636 | 637 | Example: 638 | >>> transform = Compose([ 639 | >>> TenCrop(size), # this is a list of PIL Images 640 | >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor 641 | >>> ]) 642 | >>> #In your test loop you can do the following: 643 | >>> input, target = batch # input is a 5d tensor, target is 2d 644 | >>> bs, ncrops, c, h, w = input.size() 645 | >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops 646 | >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops 647 | """ 648 | 649 | def __init__(self, size, vertical_flip=False): 650 | self.size = size 651 | if isinstance(size, numbers.Number): 652 | self.size = (int(size), int(size)) 653 | else: 654 | assert len(size) == 2, "Please provide only two dimensions (h, w) for size." 655 | self.size = size 656 | self.vertical_flip = vertical_flip 657 | 658 | def __call__(self, img): 659 | return F.ten_crop(img, self.size, self.vertical_flip) 660 | 661 | def __repr__(self): 662 | return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) 663 | 664 | 665 | class LinearTransformation(object): 666 | """Transform a tensor image with a square transformation matrix computed 667 | offline. 668 | 669 | Given transformation_matrix, will flatten the torch.*Tensor, compute the dot 670 | product with the transformation matrix and reshape the tensor to its 671 | original shape. 672 | 673 | Applications: 674 | - whitening: zero-center the data, compute the data covariance matrix 675 | [D x D] with np.dot(X.T, X), perform SVD on this matrix and 676 | pass it as transformation_matrix. 
677 | 678 | Args: 679 | transformation_matrix (Tensor): tensor [D x D], D = C x H x W 680 | """ 681 | 682 | def __init__(self, transformation_matrix): 683 | if transformation_matrix.size(0) != transformation_matrix.size(1): 684 | raise ValueError("transformation_matrix should be square. Got " + 685 | "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) 686 | self.transformation_matrix = transformation_matrix 687 | 688 | def __call__(self, tensor): 689 | """ 690 | Args: 691 | tensor (Tensor): Tensor image of size (C, H, W) to be whitened. 692 | 693 | Returns: 694 | Tensor: Transformed image. 695 | """ 696 | if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0): 697 | raise ValueError("tensor and transformation matrix have incompatible shape." + 698 | "[{} x {} x {}] != ".format(*tensor.size()) + 699 | "{}".format(self.transformation_matrix.size(0))) 700 | flat_tensor = tensor.view(1, -1) 701 | transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) 702 | tensor = transformed_tensor.view(tensor.size()) 703 | return tensor 704 | 705 | def __repr__(self): 706 | format_string = self.__class__.__name__ + '(' 707 | format_string += (str(self.transformation_matrix.numpy().tolist()) + ')') 708 | return format_string 709 | 710 | 711 | class ColorJitter(object): 712 | """Randomly change the brightness, contrast and saturation of an image. 713 | 714 | Args: 715 | brightness (float): How much to jitter brightness. brightness_factor 716 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 717 | contrast (float): How much to jitter contrast. contrast_factor 718 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 719 | saturation (float): How much to jitter saturation. saturation_factor 720 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 721 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 722 | [-hue, hue]. Should be >=0 and <= 0.5. 723 | """ 724 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 725 | self.brightness = brightness 726 | self.contrast = contrast 727 | self.saturation = saturation 728 | self.hue = hue 729 | 730 | @staticmethod 731 | def get_params(brightness, contrast, saturation, hue): 732 | """Get a randomized transform to be applied on image. 733 | 734 | Arguments are same as that of __init__. 735 | 736 | Returns: 737 | Transform which randomly adjusts brightness, contrast and 738 | saturation in a random order. 739 | """ 740 | transforms = [] 741 | if brightness > 0: 742 | brightness_factor = random.uniform(max(0, 1 - brightness), 1 + brightness) 743 | transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) 744 | 745 | if contrast > 0: 746 | contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast) 747 | transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) 748 | 749 | if saturation > 0: 750 | saturation_factor = random.uniform(max(0, 1 - saturation), 1 + saturation) 751 | transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) 752 | 753 | if hue > 0: 754 | hue_factor = random.uniform(-hue, hue) 755 | transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) 756 | 757 | random.shuffle(transforms) 758 | transform = Compose(transforms) 759 | 760 | return transform 761 | 762 | def __call__(self, img): 763 | """ 764 | Args: 765 | img (PIL Image): Input image. 766 | 767 | Returns: 768 | PIL Image: Color jittered image. 
769 | """ 770 | transform = self.get_params(self.brightness, self.contrast, 771 | self.saturation, self.hue) 772 | return transform(img) 773 | 774 | def __repr__(self): 775 | format_string = self.__class__.__name__ + '(' 776 | format_string += 'brightness={0}'.format(self.brightness) 777 | format_string += ', contrast={0}'.format(self.contrast) 778 | format_string += ', saturation={0}'.format(self.saturation) 779 | format_string += ', hue={0})'.format(self.hue) 780 | return format_string 781 | 782 | 783 | class RandomRotation(object): 784 | """Rotate the image by angle. 785 | 786 | Args: 787 | degrees (sequence or float or int): Range of degrees to select from. 788 | If degrees is a number instead of sequence like (min, max), the range of degrees 789 | will be (-degrees, +degrees). 790 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 791 | An optional resampling filter. 792 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 793 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 794 | expand (bool, optional): Optional expansion flag. 795 | If true, expands the output to make it large enough to hold the entire rotated image. 796 | If false or omitted, make the output image the same size as the input image. 797 | Note that the expand flag assumes rotation around the center and no translation. 798 | center (2-tuple, optional): Optional center of rotation. 799 | Origin is the upper left corner. 800 | Default is the center of the image. 801 | """ 802 | 803 | def __init__(self, degrees, resample=False, expand=False, center=None): 804 | if isinstance(degrees, numbers.Number): 805 | if degrees < 0: 806 | raise ValueError("If degrees is a single number, it must be positive.") 807 | self.degrees = (-degrees, degrees) 808 | else: 809 | if len(degrees) != 2: 810 | raise ValueError("If degrees is a sequence, it must be of len 2.") 811 | self.degrees = degrees 812 | 813 | self.resample = resample 814 | self.expand = expand 815 | self.center = center 816 | 817 | @staticmethod 818 | def get_params(degrees): 819 | """Get parameters for ``rotate`` for a random rotation. 820 | 821 | Returns: 822 | sequence: params to be passed to ``rotate`` for random rotation. 823 | """ 824 | angle = random.uniform(degrees[0], degrees[1]) 825 | 826 | return angle 827 | 828 | def __call__(self, img): 829 | """ 830 | img (PIL Image): Image to be rotated. 831 | 832 | Returns: 833 | PIL Image: Rotated image. 834 | """ 835 | 836 | angle = self.get_params(self.degrees) 837 | 838 | return F.rotate(img, angle, self.resample, self.expand, self.center) 839 | 840 | def __repr__(self): 841 | format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) 842 | format_string += ', resample={0}'.format(self.resample) 843 | format_string += ', expand={0}'.format(self.expand) 844 | if self.center is not None: 845 | format_string += ', center={0}'.format(self.center) 846 | format_string += ')' 847 | return format_string 848 | 849 | 850 | class RandomAffine(object): 851 | """Random affine transformation of the image keeping center invariant 852 | 853 | Args: 854 | degrees (sequence or float or int): Range of degrees to select from. 855 | If degrees is a number instead of sequence like (min, max), the range of degrees 856 | will be (-degrees, +degrees). Set to 0 to desactivate rotations. 857 | translate (tuple, optional): tuple of maximum absolute fraction for horizontal 858 | and vertical translations. 
For example translate=(a, b), then horizontal shift 859 | is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is 860 | randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. 861 | scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is 862 | randomly sampled from the range a <= scale <= b. Will keep original scale by default. 863 | shear (sequence or float or int, optional): Range of degrees to select from. 864 | If degrees is a number instead of sequence like (min, max), the range of degrees 865 | will be (-degrees, +degrees). Will not apply shear by default 866 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 867 | An optional resampling filter. 868 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 869 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 870 | fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) 871 | """ 872 | 873 | def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): 874 | if isinstance(degrees, numbers.Number): 875 | if degrees < 0: 876 | raise ValueError("If degrees is a single number, it must be positive.") 877 | self.degrees = (-degrees, degrees) 878 | else: 879 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ 880 | "degrees should be a list or tuple and it must be of length 2." 881 | self.degrees = degrees 882 | 883 | if translate is not None: 884 | assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ 885 | "translate should be a list or tuple and it must be of length 2." 886 | for t in translate: 887 | if not (0.0 <= t <= 1.0): 888 | raise ValueError("translation values should be between 0 and 1") 889 | self.translate = translate 890 | 891 | if scale is not None: 892 | assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ 893 | "scale should be a list or tuple and it must be of length 2." 894 | for s in scale: 895 | if s <= 0: 896 | raise ValueError("scale values should be positive") 897 | self.scale = scale 898 | 899 | if shear is not None: 900 | if isinstance(shear, numbers.Number): 901 | if shear < 0: 902 | raise ValueError("If shear is a single number, it must be positive.") 903 | self.shear = (-shear, shear) 904 | else: 905 | assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ 906 | "shear should be a list or tuple and it must be of length 2." 
907 | self.shear = shear 908 | else: 909 | self.shear = shear 910 | 911 | self.resample = resample 912 | self.fillcolor = fillcolor 913 | 914 | @staticmethod 915 | def get_params(degrees, translate, scale_ranges, shears, img_size): 916 | """Get parameters for affine transformation 917 | 918 | Returns: 919 | sequence: params to be passed to the affine transformation 920 | """ 921 | angle = random.uniform(degrees[0], degrees[1]) 922 | if translate is not None: 923 | max_dx = translate[0] * img_size[0] 924 | max_dy = translate[1] * img_size[1] 925 | translations = (np.round(random.uniform(-max_dx, max_dx)), 926 | np.round(random.uniform(-max_dy, max_dy))) 927 | else: 928 | translations = (0, 0) 929 | 930 | if scale_ranges is not None: 931 | scale = random.uniform(scale_ranges[0], scale_ranges[1]) 932 | else: 933 | scale = 1.0 934 | 935 | if shears is not None: 936 | shear = random.uniform(shears[0], shears[1]) 937 | else: 938 | shear = 0.0 939 | 940 | return angle, translations, scale, shear 941 | 942 | def __call__(self, img): 943 | """ 944 | img (PIL Image): Image to be transformed. 945 | 946 | Returns: 947 | PIL Image: Affine transformed image. 948 | """ 949 | ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) 950 | return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) 951 | 952 | def __repr__(self): 953 | s = '{name}(degrees={degrees}' 954 | if self.translate is not None: 955 | s += ', translate={translate}' 956 | if self.scale is not None: 957 | s += ', scale={scale}' 958 | if self.shear is not None: 959 | s += ', shear={shear}' 960 | if self.resample > 0: 961 | s += ', resample={resample}' 962 | if self.fillcolor != 0: 963 | s += ', fillcolor={fillcolor}' 964 | s += ')' 965 | d = dict(self.__dict__) 966 | d['resample'] = _pil_interpolation_to_str[d['resample']] 967 | return s.format(name=self.__class__.__name__, **d) 968 | 969 | 970 | class Grayscale(object): 971 | """Convert image to grayscale. 972 | 973 | Args: 974 | num_output_channels (int): (1 or 3) number of channels desired for output image 975 | 976 | Returns: 977 | PIL Image: Grayscale version of the input. 978 | - If num_output_channels == 1 : returned image is single channel 979 | - If num_output_channels == 3 : returned image is 3 channel with r == g == b 980 | 981 | """ 982 | 983 | def __init__(self, num_output_channels=1): 984 | self.num_output_channels = num_output_channels 985 | 986 | def __call__(self, img): 987 | """ 988 | Args: 989 | img (PIL Image): Image to be converted to grayscale. 990 | 991 | Returns: 992 | PIL Image: Randomly grayscaled image. 993 | """ 994 | return F.to_grayscale(img, num_output_channels=self.num_output_channels) 995 | 996 | def __repr__(self): 997 | return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) 998 | 999 | 1000 | class RandomGrayscale(object): 1001 | """Randomly convert image to grayscale with a probability of p (default 0.1). 1002 | 1003 | Args: 1004 | p (float): probability that image should be converted to grayscale. 1005 | 1006 | Returns: 1007 | PIL Image: Grayscale version of the input image with probability p and unchanged 1008 | with probability (1-p). 
1009 |         - If input image is 1 channel: grayscale version is 1 channel
1010 |         - If input image is 3 channel: grayscale version is 3 channel with r == g == b
1011 | 
1012 |     """
1013 | 
1014 |     def __init__(self, p=0.1):
1015 |         self.p = p
1016 | 
1017 |     def __call__(self, img):
1018 |         """
1019 |         Args:
1020 |             img (PIL Image): Image to be converted to grayscale.
1021 | 
1022 |         Returns:
1023 |             PIL Image: Randomly grayscaled image.
1024 |         """
1025 |         num_output_channels = 1 if img.mode == 'L' else 3
1026 |         if random.random() < self.p:
1027 |             return F.to_grayscale(img, num_output_channels=num_output_channels)
1028 |         return img
1029 | 
1030 |     def __repr__(self):
1031 |         return self.__class__.__name__ + '(p={0})'.format(self.p)
1032 | 
--------------------------------------------------------------------------------
/src/utils/read_write_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Utilities for reading and writing data (JSON, pickle and plain-text files).
4 | 
5 | @author: zifyloo
6 | """
7 | 
8 | import os
9 | import json
10 | import pickle
11 | 
12 | 
13 | def makedir(root):
14 |     if not os.path.exists(root):
15 |         os.makedirs(root)
16 | 
17 | 
18 | def write_json(data, root):
19 |     with open(root, 'w') as f:
20 |         json.dump(data, f)
21 | 
22 | 
23 | def read_json(root):
24 |     with open(root, 'r') as f:
25 |         data = json.load(f)
26 | 
27 |     return data
28 | 
29 | 
30 | def read_dict(root):
31 |     with open(root, 'rb') as f:
32 |         data = pickle.load(f)
33 | 
34 |     return data
35 | 
36 | 
37 | def save_dict(data, name):
38 |     with open(name + '.pkl', 'wb') as f:
39 |         pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
40 | 
41 | 
42 | def write_txt(data, name):
43 |     with open(name, 'a') as f:
44 |         f.write(data)
45 |         f.write('\n')
46 | 
47 | 
48 | 
49 | 
50 | 
--------------------------------------------------------------------------------
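The helpers in `src/utils/read_write_data.py` above are thin wrappers around `json` and `pickle`. The snippet below is a minimal, hypothetical usage sketch and is not part of the repository: the output directory, the contents of the sample dictionary, and the log file name are placeholders, and the import assumes the script is launched from a directory where the `utils` package is importable. Note that `save_dict` appends the `.pkl` suffix itself, while `read_dict` expects the full file path.

~~~
# Hypothetical usage sketch for the I/O helpers above (not part of the repository).
from utils.read_write_data import makedir, save_dict, read_dict, write_txt

out_dir = 'CUHK-PEDES/processed_data'   # placeholder output directory
makedir(out_dir)                        # creates the directory if it does not exist

# Placeholder dict used only for illustration.
train_data = {'id': [1, 2], 'img_path': ['imgs/cam_a/0001.png', 'imgs/cam_b/0002.png']}

save_dict(train_data, out_dir + '/train_data')       # save_dict adds the .pkl suffix
restored = read_dict(out_dir + '/train_data.pkl')    # read_dict needs the full path
assert restored == train_data

write_txt('pre-processing finished', out_dir + '/log.txt')   # appends one line per call
~~~

Because `write_txt` opens the file in append mode, repeated calls accumulate lines in the same log file, which is convenient for writing one entry per run or per epoch.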