├── .gitignore ├── LICENSE ├── Readme.md ├── data ├── aircrafts.py ├── cars.py ├── cub200.py ├── flowers.py ├── folder_download.py ├── folder_train.py ├── prepare_data.py └── stanford_dogs.py ├── main.py ├── models ├── DPP │ ├── Readme.md │ ├── _ext │ │ ├── __pycache__ │ │ │ └── __init__.cpython-36.pyc │ │ └── dpp │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ └── __init__.cpython-36.pyc │ │ │ └── _dpp.so │ ├── build.py │ ├── functions │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── dpp.cpython-36.pyc │ │ ├── dpp.py │ │ └── test_dpp.py │ ├── modules │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── dpp.cpython-36.pyc │ │ └── dpp.py │ └── src │ │ ├── dpp.c │ │ └── dpp.h ├── model_construct.py ├── partnet_vgg.py ├── roi_pooling │ ├── Readme.md │ ├── _ext │ │ ├── __pycache__ │ │ │ └── __init__.cpython-36.pyc │ │ └── roi_pooling │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ └── __init__.cpython-36.pyc │ │ │ └── _roi_pooling.so │ ├── build.py │ ├── functions │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ └── roi_pool.cpython-36.pyc │ │ └── roi_pool.py │ ├── modules │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── roi_pool.cpython-36.pyc │ │ │ └── roi_pool_py.cpython-36.pyc │ │ └── roi_pool.py │ └── src │ │ ├── roi_pooling.c │ │ ├── roi_pooling.cu.o │ │ ├── roi_pooling.h │ │ ├── roi_pooling_cuda.c │ │ ├── roi_pooling_cuda.h │ │ ├── roi_pooling_kernel.cu │ │ └── roi_pooling_kernel.h └── vgg.py ├── opts.py ├── process.py ├── run.sh └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yabin Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Part-Aware Fine-grained Object Categorization using Weakly Supervised Part Detection Network 2 | 3 | The **pytorch** implementation of the paper: Part-Aware Fine-grained Object Categorization using Weakly Supervised Part Detection Network. 4 | 5 | The paper is available at: https://arxiv.org/abs/1806.06198 6 | 7 | A **torch** version of PartNet (only the PartNet module) will be made available. 8 | 9 | To train PartNet from an ImageNet pre-trained model (e.g., VGGNet), first download the ImageNet pre-trained model and place it in the "vgg19-bn-modified" folder. 10 | We provide an ImageNet pre-trained model, which obtains 74.262% top-1 accuracy (almost the same as the 74.266% provided by: https://github.com/Cadene/pretrained-models.pytorch#torchvision). 11 | 12 | The commands to train the model from scratch and to verify the final results can be found in **run.sh**; a minimal setup sketch is given below.
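For orientation, here is a minimal sketch of the setup and launch steps; the authoritative commands live in **run.sh** and the available options are defined in **opts.py** (paths below assume the repository root):

```bash
# Build the native DPP and ROI-Pooling extensions once before training
# (see the Readme.md inside each module folder for the requirements).
cd models/DPP && python build.py && cd ../..
cd models/roi_pooling && python build.py && cd ../..

# Launch the five-stage pipeline driven by main.py:
# image classifier -> PartNet -> proposal download -> part classifiers -> fusion.
sh run.sh
```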
13 | 14 | We provide all the intermediate models for fast implementation and verification, which can be downloaded from: 15 | * [Baidu Cloud](https://pan.baidu.com/s/1h5oTI4POrSWBo_XEDkFZnw) 16 | * [Google Cloud](https://drive.google.com/drive/folders/1HNXGE2fI5BHSHCROw8aXyKc4R2aZJzDx?usp=sharing) 17 | 18 | Note that the 'DPP' and 'ROI-Pooling' modules need to be compiled. Details can be found in the Readme.md in each folder. 19 | 20 | If you have any questions about the paper or the code, feel free to send me an email: zhang.yabin@mail.scut.edu.cn 21 | 22 | -------------------------------------------------------------------------------- /data/aircrafts.py: -------------------------------------------------------------------------------- 1 | ################################################################# 2 | ## The data preparation for the aircraft dataset 3 | ################################################################# 4 | import os 5 | import shutil 6 | import torch 7 | import torchvision.transforms as transforms 8 | from data.folder_train import ImageFolder_train 9 | from data.folder_download import ImageFolder_download 10 | 11 | def split_train_test_images(data_dir): 12 | #data_dir = '/data/cars/' 13 | raise ValueError('the routine that generates the split training and test images is not implemented') 14 | 15 | def aircrafts(args, process_name, part_index): 16 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 17 | std=[0.229, 0.224, 0.225]) 18 | if process_name == 'image_classifier' or process_name == 'partnet' or process_name == 'download_proposals': 19 | traindir = os.path.join(args.data_path, 'splited_image/train') 20 | valdir = os.path.join(args.data_path, 'splited_image/val') 21 | if not os.path.isdir(traindir): 22 | print('the aircraft images are not yet split; splitting all images into train and val sets') 23 | split_train_test_images(args.data_path) 24 | if process_name == 'download_proposals': 25 | print('the image pre-process for process: download proposals is Resize 512 and Center Crop 448') 26 | train_dataset = ImageFolder_download( 27 | root=traindir, 28 | transform=transforms.Compose([ 29 | transforms.Resize(512), 30 | transforms.CenterCrop(448), 31 | transforms.ToTensor(), 32 | normalize, 33 | ]), 34 | transform_keep=transforms.Compose([ 35 | transforms.Resize(512), 36 | transforms.CenterCrop(448), 37 | transforms.ToTensor() 38 | ]), 39 | dataset_path=args.data_path 40 | ) 41 | train_loader = torch.utils.data.DataLoader( 42 | train_dataset, batch_size=args.batch_size_partnet, shuffle=False, num_workers=args.workers, 43 | pin_memory=True, sampler=None 44 | ) 45 | val_loader = torch.utils.data.DataLoader( 46 | ImageFolder_download( 47 | root=valdir, 48 | transform=transforms.Compose([ 49 | transforms.Resize(512), 50 | transforms.CenterCrop(448), 51 | transforms.ToTensor(), 52 | normalize, 53 | ]), 54 | transform_keep=transforms.Compose([ 55 | transforms.Resize(512), 56 | transforms.CenterCrop(448), 57 | transforms.ToTensor(), 58 | ]), 59 | dataset_path=args.data_path), 60 | batch_size=args.batch_size_partnet, shuffle=False, 61 | num_workers=args.workers, pin_memory=True 62 | ) 63 | return train_loader, val_loader 64 | elif process_name == 'image_classifier': 65 | print('the image pre-process for process: image_classifier is Resize 512 and Random Crop 448') 66 | train_dataset = ImageFolder_train( 67 | traindir, 68 | transforms.Compose([ 69 | transforms.Resize(512), 70 | transforms.RandomCrop(448), 71 | transforms.ToTensor(), 72 | normalize, 73 | ]) 74 | ) 75 | train_loader = 
torch.utils.data.DataLoader( 76 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 77 | pin_memory=True, sampler=None 78 | ) 79 | val_loader = torch.utils.data.DataLoader( 80 | ImageFolder_train(valdir, transforms.Compose([ 81 | transforms.Resize(512), 82 | transforms.CenterCrop(448), 83 | transforms.ToTensor(), 84 | normalize, 85 | ])), 86 | batch_size=args.batch_size, shuffle=False, 87 | num_workers=args.workers, pin_memory=True 88 | ) 89 | return train_loader, val_loader 90 | elif process_name == 'partnet': 91 | print('the image pre-process for process: partnet is Resize 512 and Random Crop 448') 92 | train_dataset = ImageFolder_train( 93 | traindir, 94 | transforms.Compose([ 95 | transforms.Resize(512), 96 | transforms.RandomCrop(448), 97 | transforms.ToTensor(), 98 | normalize, 99 | ]) 100 | ) 101 | train_loader = torch.utils.data.DataLoader( 102 | train_dataset, batch_size=args.batch_size_partnet, shuffle=True, num_workers=args.workers, 103 | pin_memory=True, sampler=None 104 | ) 105 | val_loader = torch.utils.data.DataLoader( 106 | ImageFolder_train(valdir, transforms.Compose([ 107 | transforms.Resize(512), 108 | transforms.CenterCrop(448), 109 | transforms.ToTensor(), 110 | normalize, 111 | ])), 112 | batch_size=args.batch_size_partnet, shuffle=False, 113 | num_workers=args.workers, pin_memory=True 114 | ) 115 | return train_loader, val_loader 116 | elif process_name == 'part_classifiers': 117 | traindir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/train/' 118 | valdir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/val/' 119 | print('the image pre-process for process: part_classifier is Resize (448, 448) directly') 120 | train_dataset = ImageFolder_train( 121 | traindir, 122 | transforms.Compose([ 123 | transforms.Resize((448, 448)), 124 | transforms.ToTensor(), 125 | normalize, 126 | ]) 127 | ) 128 | train_loader = torch.utils.data.DataLoader( 129 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 130 | pin_memory=True, sampler=None 131 | ) 132 | val_loader = torch.utils.data.DataLoader( 133 | ImageFolder_train(valdir, transforms.Compose([ 134 | transforms.Resize((448, 448)), 135 | transforms.ToTensor(), 136 | normalize, 137 | ])), 138 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers, 139 | pin_memory=True, sampler=None 140 | ) 141 | return train_loader, val_loader 142 | 143 | 144 | -------------------------------------------------------------------------------- /data/cars.py: -------------------------------------------------------------------------------- 1 | ################################################################# 2 | ## The data preparation for the aircraft dataset 3 | ############################################################ 4 | import os 5 | import shutil 6 | import torch 7 | import torchvision.transforms as transforms 8 | from data.folder_train import ImageFolder_train 9 | from data.folder_download import ImageFolder_download 10 | 11 | def split_train_test_images(data_dir): 12 | #data_dir = '/data/cars/' 13 | raise ValueError('the process of generate splited training and test images is empty') 14 | 15 | def cars(args, process_name, part_index): 16 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 17 | std=[0.229, 0.224, 0.225]) 18 | if process_name == 'image_classifier' or process_name == 'partnet' or process_name == 'download_proposals': 19 | traindir = 
os.path.join(args.data_path, 'splited_image/train') 20 | valdir = os.path.join(args.data_path, 'splited_image/val') 21 | if not os.path.isdir(traindir): 22 | print('the cub images are not well splited, split all images into train and val set') 23 | split_train_test_images(args.data_path) 24 | if process_name == 'download_proposals': 25 | print('the image pre-process for process: download proposals is Resize 512 and Center Crop 448') 26 | train_dataset = ImageFolder_download( 27 | root=traindir, 28 | transform=transforms.Compose([ 29 | transforms.Resize(512), 30 | transforms.CenterCrop(448), 31 | transforms.ToTensor(), 32 | normalize, 33 | ]), 34 | transform_keep=transforms.Compose([ 35 | transforms.Resize(512), 36 | transforms.CenterCrop(448), 37 | transforms.ToTensor() 38 | ]), 39 | dataset_path=args.data_path 40 | ) 41 | train_loader = torch.utils.data.DataLoader( 42 | train_dataset, batch_size=args.batch_size_partnet, shuffle=False, num_workers=args.workers, 43 | pin_memory=True, sampler=None 44 | ) 45 | val_loader = torch.utils.data.DataLoader( 46 | ImageFolder_download( 47 | root=valdir, 48 | transform=transforms.Compose([ 49 | transforms.Resize(512), 50 | transforms.CenterCrop(448), 51 | transforms.ToTensor(), 52 | normalize, 53 | ]), 54 | transform_keep=transforms.Compose([ 55 | transforms.Resize(512), 56 | transforms.CenterCrop(448), 57 | transforms.ToTensor(), 58 | ]), 59 | dataset_path=args.data_path), 60 | batch_size=args.batch_size_partnet, shuffle=False, 61 | num_workers=args.workers, pin_memory=True 62 | ) 63 | return train_loader, val_loader 64 | elif process_name == 'image_classifier': 65 | print('the image pre-process for process: image_classifier is Resize 512 and Random Crop 448') 66 | train_dataset = ImageFolder_train( 67 | traindir, 68 | transforms.Compose([ 69 | transforms.Resize(512), 70 | transforms.RandomCrop(448), 71 | transforms.ToTensor(), 72 | normalize, 73 | ]) 74 | ) 75 | train_loader = torch.utils.data.DataLoader( 76 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 77 | pin_memory=True, sampler=None 78 | ) 79 | val_loader = torch.utils.data.DataLoader( 80 | ImageFolder_train(valdir, transforms.Compose([ 81 | transforms.Resize(512), 82 | transforms.CenterCrop(448), 83 | transforms.ToTensor(), 84 | normalize, 85 | ])), 86 | batch_size=args.batch_size, shuffle=False, 87 | num_workers=args.workers, pin_memory=True 88 | ) 89 | return train_loader, val_loader 90 | elif process_name == 'partnet': 91 | print('the image pre-process for process: partnet is Resize 512 and Random Crop 448') 92 | train_dataset = ImageFolder_train( 93 | traindir, 94 | transforms.Compose([ 95 | transforms.Resize(512), 96 | transforms.RandomCrop(448), 97 | transforms.ToTensor(), 98 | normalize, 99 | ]) 100 | ) 101 | train_loader = torch.utils.data.DataLoader( 102 | train_dataset, batch_size=args.batch_size_partnet, shuffle=True, num_workers=args.workers, 103 | pin_memory=True, sampler=None 104 | ) 105 | val_loader = torch.utils.data.DataLoader( 106 | ImageFolder_train(valdir, transforms.Compose([ 107 | transforms.Resize(512), 108 | transforms.CenterCrop(448), 109 | transforms.ToTensor(), 110 | normalize, 111 | ])), 112 | batch_size=args.batch_size_partnet, shuffle=False, 113 | num_workers=args.workers, pin_memory=True 114 | ) 115 | return train_loader, val_loader 116 | elif process_name == 'part_classifiers': 117 | traindir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/train/' 118 | valdir = args.data_path + 
'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/val/' 119 | print('the image pre-process for process: part_classifier is Resize (448, 448) directly') 120 | train_dataset = ImageFolder_train( 121 | traindir, 122 | transforms.Compose([ 123 | transforms.Resize((448, 448)), 124 | transforms.ToTensor(), 125 | normalize, 126 | ]) 127 | ) 128 | train_loader = torch.utils.data.DataLoader( 129 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 130 | pin_memory=True, sampler=None 131 | ) 132 | val_loader = torch.utils.data.DataLoader( 133 | ImageFolder_train(valdir, transforms.Compose([ 134 | transforms.Resize((448, 448)), 135 | transforms.ToTensor(), 136 | normalize, 137 | ])), 138 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers, 139 | pin_memory=True, sampler=None 140 | ) 141 | return train_loader, val_loader 142 | 143 | 144 | -------------------------------------------------------------------------------- /data/cub200.py: -------------------------------------------------------------------------------- 1 | ################################################################# 2 | ## The data preparation for the cub-200-2011 dataset 3 | ## The cub-200-2011 dataset can be downloaded at: 4 | ## http://www.vision.caltech.edu/visipedia/CUB-200.html 5 | ## All the downloaded related file should be placed at the "args.data_path". 6 | ################################################################# 7 | import os 8 | import shutil 9 | import torch 10 | import torchvision.transforms as transforms 11 | from data.folder_train import ImageFolder_train 12 | from data.folder_download import ImageFolder_download 13 | 14 | def split_train_test_images(data_dir): 15 | #data_dir = '/home/lab-zhangyabin/project/fine-grained/CUB_200_2011/' 16 | src_dir = os.path.join(data_dir, 'images') 17 | target_dir = os.path.join(data_dir, 'splited_image') 18 | if not os.path.isdir(target_dir): 19 | os.makedirs(target_dir) 20 | print(src_dir) 21 | train_test_split = open(os.path.join(data_dir, 'train_test_split.txt')) 22 | line = train_test_split.readline() 23 | images = open(os.path.join(data_dir, 'images.txt')) 24 | images_line = images.readline() 25 | ########################## 26 | # print(images_line) 27 | image_list = str.split(images_line) 28 | # print(image_list[1]) 29 | subclass_name = image_list[1].split('/')[0] 30 | # print(subclass_name) 31 | 32 | # print(line) 33 | class_list = str.split(line)[1] 34 | # print(class_list) 35 | 36 | print('begin to prepare the dataset CUB') 37 | count = 0 38 | while images_line: 39 | print(count) 40 | count = count + 1 41 | image_list = str.split(images_line) 42 | subclass_name = image_list[1].split('/')[0] # get the name of the subclass 43 | # print(image_list[0]) 44 | class_label = str.split(line)[1] # get the label of the image 45 | # print(type(int(class_label))) 46 | test_or_train = 'train' 47 | if class_label == '0': # the class belong to the train dataset 48 | test_or_train = 'val' 49 | train_test_dir = os.path.join(target_dir, test_or_train) 50 | if not os.path.isdir(train_test_dir): 51 | os.makedirs(train_test_dir) 52 | subclass_dir = os.path.join(train_test_dir, subclass_name) 53 | if not os.path.isdir(subclass_dir): 54 | os.makedirs(subclass_dir) 55 | 56 | souce_pos = os.path.join(src_dir, image_list[1]) 57 | targer_pos = os.path.join(subclass_dir, image_list[1].split('/')[1]) 58 | shutil.copyfile(souce_pos, targer_pos) 59 | images_line = images.readline() 60 | line = train_test_split.readline() 61 | 62 | 
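# The split above yields an ImageFolder-style tree that the loaders below
# consume directly; illustrative CUB-200-2011 paths (class folder names are
# taken from images.txt, e.g. '001.Black_footed_Albatross'):
#
#   <args.data_path>/splited_image/train/001.Black_footed_Albatross/xxx.jpg
#   <args.data_path>/splited_image/val/001.Black_footed_Albatross/yyy.jpg
#
# Usage sketch, as wired up by data/prepare_data.py (part_index is only
# consulted for the 'part_classifiers' process):
#
#   train_loader, val_loader = cub200(args, 'partnet', part_index=-1)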
def cub200(args, process_name, part_index): 63 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 64 | std=[0.229, 0.224, 0.225]) 65 | if process_name == 'image_classifier' or process_name == 'partnet' or process_name == 'download_proposals': 66 | traindir = os.path.join(args.data_path, 'splited_image/train') 67 | valdir = os.path.join(args.data_path, 'splited_image/val') 68 | if not os.path.isdir(traindir): 69 | print('the cub images are not well splited, split all images into train and val set') 70 | split_train_test_images(args.data_path) 71 | if process_name == 'download_proposals': 72 | print('the image pre-process for process: download proposals is Resize 512 and Center Crop 448') 73 | train_dataset = ImageFolder_download( 74 | root=traindir, 75 | transform=transforms.Compose([ 76 | transforms.Resize(512), 77 | transforms.CenterCrop(448), 78 | transforms.ToTensor(), 79 | normalize, 80 | ]), 81 | transform_keep=transforms.Compose([ 82 | transforms.Resize(512), 83 | transforms.CenterCrop(448), 84 | transforms.ToTensor() 85 | ]), 86 | dataset_path=args.data_path 87 | ) 88 | train_loader = torch.utils.data.DataLoader( 89 | train_dataset, batch_size=args.batch_size_partnet, shuffle=False, num_workers=args.workers, 90 | pin_memory=True, sampler=None 91 | ) 92 | val_loader = torch.utils.data.DataLoader( 93 | ImageFolder_download( 94 | root=valdir, 95 | transform=transforms.Compose([ 96 | transforms.Resize(512), 97 | transforms.CenterCrop(448), 98 | transforms.ToTensor(), 99 | normalize, 100 | ]), 101 | transform_keep=transforms.Compose([ 102 | transforms.Resize(512), 103 | transforms.CenterCrop(448), 104 | transforms.ToTensor(), 105 | ]), 106 | dataset_path=args.data_path), 107 | batch_size=args.batch_size_partnet, shuffle=False, 108 | num_workers=args.workers, pin_memory=True 109 | ) 110 | return train_loader, val_loader 111 | elif process_name == 'image_classifier': 112 | print('the image pre-process for process: image_classifier is Resize 512 and Random Crop 448') 113 | train_dataset = ImageFolder_train( 114 | traindir, 115 | transforms.Compose([ 116 | transforms.Resize(512), 117 | transforms.RandomCrop(448), 118 | transforms.ToTensor(), 119 | normalize, 120 | ]) 121 | ) 122 | train_loader = torch.utils.data.DataLoader( 123 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 124 | pin_memory=True, sampler=None 125 | ) 126 | val_loader = torch.utils.data.DataLoader( 127 | ImageFolder_train(valdir, transforms.Compose([ 128 | transforms.Resize(512), 129 | transforms.CenterCrop(448), 130 | transforms.ToTensor(), 131 | normalize, 132 | ])), 133 | batch_size=args.batch_size, shuffle=False, 134 | num_workers=args.workers, pin_memory=True 135 | ) 136 | return train_loader, val_loader 137 | elif process_name == 'partnet': 138 | print('the image pre-process for process: partnet is Resize 512 and Random Crop 448') 139 | train_dataset = ImageFolder_train( 140 | traindir, 141 | transforms.Compose([ 142 | transforms.Resize(512), 143 | transforms.RandomCrop(448), 144 | transforms.ToTensor(), 145 | normalize, 146 | ]) 147 | ) 148 | train_loader = torch.utils.data.DataLoader( 149 | train_dataset, batch_size=args.batch_size_partnet, shuffle=True, num_workers=args.workers, 150 | pin_memory=True, sampler=None 151 | ) 152 | val_loader = torch.utils.data.DataLoader( 153 | ImageFolder_train(valdir, transforms.Compose([ 154 | transforms.Resize(512), 155 | transforms.CenterCrop(448), 156 | transforms.ToTensor(), 157 | normalize, 158 | ])), 159 | 
batch_size=args.batch_size_partnet, shuffle=False, 160 | num_workers=args.workers, pin_memory=True 161 | ) 162 | return train_loader, val_loader 163 | elif process_name == 'part_classifiers': 164 | traindir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/train/' 165 | valdir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/val/' 166 | print('the image pre-process for process: part_classifier is Resize (448, 448) directly') 167 | train_dataset = ImageFolder_train( 168 | traindir, 169 | transforms.Compose([ 170 | transforms.Resize((448, 448)), 171 | transforms.ToTensor(), 172 | normalize, 173 | ]) 174 | ) 175 | train_loader = torch.utils.data.DataLoader( 176 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 177 | pin_memory=True, sampler=None 178 | ) 179 | val_loader = torch.utils.data.DataLoader( 180 | ImageFolder_train(valdir, transforms.Compose([ 181 | transforms.Resize((448, 448)), 182 | transforms.ToTensor(), 183 | normalize, 184 | ])), 185 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers, 186 | pin_memory=True, sampler=None 187 | ) 188 | return train_loader, val_loader 189 | 190 | 191 | -------------------------------------------------------------------------------- /data/flowers.py: -------------------------------------------------------------------------------- 1 | ################################################################# 2 | ## The data preparation for flower-102 dataset 3 | ## The flower-102 dataset can be downloaded at: 4 | ## http://www.robots.ox.ac.uk/~vgg/data/flowers/102/ 5 | ## All the downloaded related file should be placed at the "args.data_path". 6 | ################################################################# 7 | import os 8 | import shutil 9 | import torch 10 | import torchvision.transforms as transforms 11 | from data.folder_train import ImageFolder_train 12 | from data.folder_download import ImageFolder_download 13 | import scipy.io as scio 14 | 15 | def split_train_test_images(data_dir): 16 | image_mat = data_dir + 'imagelabels.mat' 17 | split_mat = data_dir + 'setid.mat' 18 | labels = scio.loadmat(image_mat)['labels'][0] 19 | split = scio.loadmat(split_mat) 20 | train_split = split['trnid'][0] 21 | test_split = split['tstid'][0] 22 | val_split = split['valid'][0] 23 | 24 | for i in range(len(train_split)): 25 | label = labels[train_split[i] - 1] 26 | if train_split[i] > 999: 27 | dir_temp = 'image_0' + str(train_split[i]) + '.jpg' 28 | elif train_split[i] > 99: 29 | dir_temp = 'image_00' + str(train_split[i]) + '.jpg' 30 | elif train_split[i] > 9: 31 | dir_temp = 'image_000' + str(train_split[i]) + '.jpg' 32 | else: 33 | dir_temp = 'image_0000' + str(train_split[i]) + '.jpg' 34 | original_image = data_dir + 'jpg/' + dir_temp 35 | target_dir = data_dir + 'splited_image/train/' + str(label) + '/' 36 | target_image = target_dir + dir_temp 37 | if not os.path.isdir(target_dir): 38 | os.makedirs(target_dir) 39 | shutil.copyfile(original_image, target_image) 40 | 41 | for i in range(len(test_split)): 42 | label = labels[test_split[i] - 1] 43 | if test_split[i] > 999: 44 | dir_temp = 'image_0' + str(test_split[i]) + '.jpg' 45 | elif test_split[i] > 99: 46 | dir_temp = 'image_00' + str(test_split[i]) + '.jpg' 47 | elif test_split[i] > 9: 48 | dir_temp = 'image_000' + str(test_split[i]) + '.jpg' 49 | else: 50 | dir_temp = 'image_0000' + str(test_split[i]) + '.jpg' 51 | original_image = data_dir + 'jpg/' + dir_temp 52 | target_dir 
= data_dir + 'splited_image/val/' + str(label) + '/' 53 | target_image = target_dir + dir_temp 54 | if not os.path.isdir(target_dir): 55 | os.makedirs(target_dir) 56 | shutil.copyfile(original_image, target_image) 57 | 58 | for i in range(len(val_split)): 59 | label = labels[val_split[i] - 1] 60 | if val_split[i] > 999: 61 | dir_temp = 'image_0' + str(val_split[i]) + '.jpg' 62 | elif val_split[i] > 99: 63 | dir_temp = 'image_00' + str(val_split[i]) + '.jpg' 64 | elif val_split[i] > 9: 65 | dir_temp = 'image_000' + str(val_split[i]) + '.jpg' 66 | else: 67 | dir_temp = 'image_0000' + str(val_split[i]) + '.jpg' 68 | original_image = data_dir + 'jpg/' + dir_temp 69 | target_dir = data_dir + 'splited_image/train/' + str(label) + '/' 70 | target_image = target_dir + dir_temp 71 | if not os.path.isdir(target_dir): 72 | os.makedirs(target_dir) 73 | shutil.copyfile(original_image, target_image) 74 | 75 | def flowers(args, process_name, part_index): 76 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 77 | std=[0.229, 0.224, 0.225]) 78 | if process_name == 'image_classifier' or process_name == 'partnet' or process_name == 'download_proposals': 79 | traindir = os.path.join(args.data_path, 'splited_image/train') 80 | valdir = os.path.join(args.data_path, 'splited_image/val') 81 | if not os.path.isdir(traindir): 82 | print('the cub images are not well splited, split all images into train and val set') 83 | split_train_test_images(args.data_path) 84 | if process_name == 'download_proposals': 85 | print('the image pre-process for process: download proposals is Resize 512 and Center Crop 448') 86 | train_dataset = ImageFolder_download( 87 | root=traindir, 88 | transform=transforms.Compose([ 89 | transforms.Resize(512), 90 | transforms.CenterCrop(448), 91 | transforms.ToTensor(), 92 | normalize, 93 | ]), 94 | transform_keep=transforms.Compose([ 95 | transforms.Resize(512), 96 | transforms.CenterCrop(448), 97 | transforms.ToTensor() 98 | ]), 99 | dataset_path=args.data_path 100 | ) 101 | train_loader = torch.utils.data.DataLoader( 102 | train_dataset, batch_size=args.batch_size_partnet, shuffle=False, num_workers=args.workers, 103 | pin_memory=True, sampler=None 104 | ) 105 | val_loader = torch.utils.data.DataLoader( 106 | ImageFolder_download( 107 | root=valdir, 108 | transform=transforms.Compose([ 109 | transforms.Resize(512), 110 | transforms.CenterCrop(448), 111 | transforms.ToTensor(), 112 | normalize, 113 | ]), 114 | transform_keep=transforms.Compose([ 115 | transforms.Resize(512), 116 | transforms.CenterCrop(448), 117 | transforms.ToTensor(), 118 | ]), 119 | dataset_path=args.data_path), 120 | batch_size=args.batch_size_partnet, shuffle=False, 121 | num_workers=args.workers, pin_memory=True 122 | ) 123 | return train_loader, val_loader 124 | elif process_name == 'image_classifier': 125 | print('the image pre-process for process: image_classifier is Resize 512 and Random Crop 448') 126 | train_dataset = ImageFolder_train( 127 | traindir, 128 | transforms.Compose([ 129 | transforms.Resize(512), 130 | transforms.RandomCrop(448), 131 | transforms.ToTensor(), 132 | normalize, 133 | ]) 134 | ) 135 | train_loader = torch.utils.data.DataLoader( 136 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 137 | pin_memory=True, sampler=None 138 | ) 139 | val_loader = torch.utils.data.DataLoader( 140 | ImageFolder_train(valdir, transforms.Compose([ 141 | transforms.Resize(512), 142 | transforms.CenterCrop(448), 143 | transforms.ToTensor(), 144 | normalize, 145 | ])), 
146 | batch_size=args.batch_size, shuffle=False, 147 | num_workers=args.workers, pin_memory=True 148 | ) 149 | return train_loader, val_loader 150 | elif process_name == 'partnet': 151 | print('the image pre-process for process: partnet is Resize 512 and Random Crop 448') 152 | train_dataset = ImageFolder_train( 153 | traindir, 154 | transforms.Compose([ 155 | transforms.Resize(512), 156 | transforms.RandomCrop(448), 157 | transforms.ToTensor(), 158 | normalize, 159 | ]) 160 | ) 161 | train_loader = torch.utils.data.DataLoader( 162 | train_dataset, batch_size=args.batch_size_partnet, shuffle=True, num_workers=args.workers, 163 | pin_memory=True, sampler=None 164 | ) 165 | val_loader = torch.utils.data.DataLoader( 166 | ImageFolder_train(valdir, transforms.Compose([ 167 | transforms.Resize(512), 168 | transforms.CenterCrop(448), 169 | transforms.ToTensor(), 170 | normalize, 171 | ])), 172 | batch_size=args.batch_size_partnet, shuffle=False, 173 | num_workers=args.workers, pin_memory=True 174 | ) 175 | return train_loader, val_loader 176 | elif process_name == 'part_classifiers': 177 | traindir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/train/' 178 | valdir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/val/' 179 | print('the image pre-process for process: part_classifier is Resize (448, 448) directly') 180 | train_dataset = ImageFolder_train( 181 | traindir, 182 | transforms.Compose([ 183 | transforms.Resize((448, 448)), 184 | transforms.ToTensor(), 185 | normalize, 186 | ]) 187 | ) 188 | train_loader = torch.utils.data.DataLoader( 189 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 190 | pin_memory=True, sampler=None 191 | ) 192 | val_loader = torch.utils.data.DataLoader( 193 | ImageFolder_train(valdir, transforms.Compose([ 194 | transforms.Resize((448, 448)), 195 | transforms.ToTensor(), 196 | normalize, 197 | ])), 198 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers, 199 | pin_memory=True, sampler=None 200 | ) 201 | return train_loader, val_loader 202 | 203 | 204 | -------------------------------------------------------------------------------- /data/folder_download.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | ## This file is modified from the original folder.py for the process of detected parts downloading. The return values: 3 | ## 4 | ## img_input: same as the 'img' in the folder.py. It indicates the inputs for the model. 5 | ## target: same as the 'target' in the folder.py. It is a int number representing the label of the img_input. 6 | ## img_keep: Image with the same size as the img_input, but it is not normalized. Detected parts are cropped from it. 7 | ## target_loss: The one-hot label for the image. It is used in the Binary Cross-entropy loss. 8 | ## path_image: The path of the loaded images. 9 | ################################################################################### 10 | 11 | import torch.utils.data as data 12 | import torch 13 | from PIL import Image 14 | import os 15 | import os.path 16 | from PIL import ImageFile 17 | ImageFile.LOAD_TRUNCATED_IMAGES = True 18 | 19 | print('the folder.py has been changed') 20 | 21 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] 22 | 23 | 24 | def is_image_file(filename): 25 | """Checks if a file is an image. 
26 | 27 | Args: 28 | filename (string): path to a file 29 | 30 | Returns: 31 | bool: True if the filename ends with a known image extension 32 | """ 33 | filename_lower = filename.lower() 34 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) 35 | 36 | 37 | def find_classes(dir): 38 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 39 | classes.sort() 40 | class_to_idx = {classes[i]: i for i in range(len(classes))} 41 | return classes, class_to_idx 42 | 43 | 44 | def make_dataset(dir, class_to_idx): 45 | images = [] 46 | dir = os.path.expanduser(dir) 47 | for target in sorted(os.listdir(dir)): 48 | d = os.path.join(dir, target) 49 | if not os.path.isdir(d): 50 | continue 51 | 52 | for root, _, fnames in sorted(os.walk(d)): 53 | for fname in sorted(fnames): 54 | if is_image_file(fname): 55 | path = os.path.join(root, fname) 56 | item = (path, class_to_idx[target]) 57 | images.append(item) 58 | 59 | return images 60 | 61 | 62 | def pil_loader(path): 63 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 64 | with open(path, 'rb') as f: 65 | img = Image.open(f) 66 | return img.convert('RGB') 67 | 68 | 69 | def accimage_loader(path): 70 | import accimage 71 | try: 72 | return accimage.Image(path) 73 | except IOError: 74 | # Potentially a decoding problem, fall back to PIL.Image 75 | return pil_loader(path) 76 | 77 | 78 | def default_loader(path): 79 | from torchvision import get_image_backend 80 | if get_image_backend() == 'accimage': 81 | return accimage_loader(path) 82 | else: 83 | return pil_loader(path) 84 | 85 | 86 | class ImageFolder_download(data.Dataset): 87 | """A generic data loader where the images are arranged in this way: :: 88 | 89 | root/dog/xxx.png 90 | root/dog/xxy.png 91 | root/dog/xxz.png 92 | 93 | root/cat/123.png 94 | root/cat/nsdf3.png 95 | root/cat/asd932_.png 96 | 97 | Args: 98 | root (string): Root directory path. 99 | transform (callable, optional): A function/transform that takes in an PIL image 100 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 101 | target_transform (callable, optional): A function/transform that takes in the 102 | target and transforms it. 103 | loader (callable, optional): A function to load an image given its path. 104 | 105 | Attributes: 106 | classes (list): List of the class names. 107 | class_to_idx (dict): Dict with items (class_name, class_index). 108 | imgs (list): List of (image path, class_index) tuples 109 | """ 110 | 111 | def __init__(self, root, transform=None, transform_keep = None,target_transform=None, 112 | loader=default_loader, dataset_path = None): 113 | classes, class_to_idx = find_classes(root) 114 | imgs = make_dataset(root, class_to_idx) 115 | if len(imgs) == 0: 116 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 117 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 118 | 119 | self.root = root 120 | self.imgs = imgs 121 | self.classes = classes 122 | self.class_to_idx = class_to_idx 123 | self.transform = transform 124 | self.transform_keep = transform_keep 125 | self.target_transform = target_transform 126 | self.loader = loader 127 | self.dataset_path = dataset_path 128 | print('the classes number is', len(self.classes)) 129 | print('the images number is', len(self.imgs)) 130 | 131 | def __getitem__(self, index): 132 | """ 133 | Args: 134 | index (int): Index 135 | 136 | Returns: 137 | tuple: (image, target) where target is class_index of the target class. 
138 | """ 139 | path, target = self.imgs[index] 140 | img = self.loader(path) 141 | # print(target) 142 | if self.transform is not None: 143 | img_input = self.transform(img) 144 | if self.transform_keep is not None: 145 | img_keep = self.transform_keep(img) 146 | if self.target_transform is not None: 147 | target = self.target_transform(target) 148 | target_loss = torch.Tensor(len(self.classes)).fill_(0) 149 | target_loss[target] = 1 150 | if self.dataset_path is not None: 151 | path_image = path.replace(self.dataset_path, '') 152 | return img_input, img_keep, target, target_loss, path_image 153 | 154 | def __len__(self): 155 | return len(self.imgs) 156 | 157 | def __repr__(self): 158 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 159 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 160 | fmt_str += ' Root Location: {}\n'.format(self.root) 161 | tmp = ' Transforms (if any): ' 162 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 163 | tmp = ' Target Transforms (if any): ' 164 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 165 | return fmt_str 166 | -------------------------------------------------------------------------------- /data/folder_train.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | ## This file is modified from the original folder.py for the process of PartNet training. The return values: 3 | ## 4 | ## img_input: same as the 'img' in the folder.py. It indicates the inputs for the model. 5 | ## target: same as the 'target' in the folder.py. It is a int number representing the label of the img_input. 6 | ## target_loss: The one-hot label for the image. It is used in the Binary Cross-entropy loss. 7 | ## 8 | ################################################################################### 9 | import torch.utils.data as data 10 | import torch 11 | from PIL import Image 12 | import os 13 | import os.path 14 | from PIL import ImageFile 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | print('the folder.py has been changed') 18 | 19 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm'] 20 | 21 | 22 | def is_image_file(filename): 23 | """Checks if a file is an image. 
24 | 25 | Args: 26 | filename (string): path to a file 27 | 28 | Returns: 29 | bool: True if the filename ends with a known image extension 30 | """ 31 | filename_lower = filename.lower() 32 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS) 33 | 34 | 35 | def find_classes(dir): 36 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 37 | classes.sort() 38 | class_to_idx = {classes[i]: i for i in range(len(classes))} 39 | return classes, class_to_idx 40 | 41 | 42 | def make_dataset(dir, class_to_idx): 43 | images = [] 44 | dir = os.path.expanduser(dir) 45 | for target in sorted(os.listdir(dir)): 46 | d = os.path.join(dir, target) 47 | if not os.path.isdir(d): 48 | continue 49 | 50 | for root, _, fnames in sorted(os.walk(d)): 51 | for fname in sorted(fnames): 52 | if is_image_file(fname): 53 | path = os.path.join(root, fname) 54 | item = (path, class_to_idx[target]) 55 | images.append(item) 56 | 57 | return images 58 | 59 | 60 | def pil_loader(path): 61 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 62 | with open(path, 'rb') as f: 63 | img = Image.open(f) 64 | return img.convert('RGB') 65 | 66 | 67 | def accimage_loader(path): 68 | import accimage 69 | try: 70 | return accimage.Image(path) 71 | except IOError: 72 | # Potentially a decoding problem, fall back to PIL.Image 73 | return pil_loader(path) 74 | 75 | 76 | def default_loader(path): 77 | from torchvision import get_image_backend 78 | if get_image_backend() == 'accimage': 79 | return accimage_loader(path) 80 | else: 81 | return pil_loader(path) 82 | 83 | 84 | class ImageFolder_train(data.Dataset): 85 | """A generic data loader where the images are arranged in this way: :: 86 | 87 | root/dog/xxx.png 88 | root/dog/xxy.png 89 | root/dog/xxz.png 90 | 91 | root/cat/123.png 92 | root/cat/nsdf3.png 93 | root/cat/asd932_.png 94 | 95 | Args: 96 | root (string): Root directory path. 97 | transform (callable, optional): A function/transform that takes in an PIL image 98 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 99 | target_transform (callable, optional): A function/transform that takes in the 100 | target and transforms it. 101 | loader (callable, optional): A function to load an image given its path. 102 | 103 | Attributes: 104 | classes (list): List of the class names. 105 | class_to_idx (dict): Dict with items (class_name, class_index). 106 | imgs (list): List of (image path, class_index) tuples 107 | """ 108 | 109 | def __init__(self, root, transform=None, target_transform=None, 110 | loader=default_loader): 111 | classes, class_to_idx = find_classes(root) 112 | imgs = make_dataset(root, class_to_idx) 113 | if len(imgs) == 0: 114 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 115 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 116 | 117 | self.root = root 118 | self.imgs = imgs 119 | self.classes = classes 120 | self.class_to_idx = class_to_idx 121 | self.transform = transform 122 | self.target_transform = target_transform 123 | self.loader = loader 124 | print('the classes number is', len(self.classes)) 125 | print('the images number is', len(self.imgs)) 126 | 127 | def __getitem__(self, index): 128 | """ 129 | Args: 130 | index (int): Index 131 | 132 | Returns: 133 | tuple: (image, target) where target is class_index of the target class. 
134 | """ 135 | path, target = self.imgs[index] 136 | img = self.loader(path) 137 | # print(target) 138 | if self.transform is not None: 139 | img = self.transform(img) 140 | if self.target_transform is not None: 141 | target = self.target_transform(target) 142 | target_loss = torch.Tensor(len(self.classes)).fill_(0) 143 | target_loss[target] = 1 144 | return img, target, target_loss 145 | 146 | def __len__(self): 147 | return len(self.imgs) 148 | 149 | def __repr__(self): 150 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 151 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 152 | fmt_str += ' Root Location: {}\n'.format(self.root) 153 | tmp = ' Transforms (if any): ' 154 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 155 | tmp = ' Target Transforms (if any): ' 156 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 157 | return fmt_str 158 | -------------------------------------------------------------------------------- /data/prepare_data.py: -------------------------------------------------------------------------------- 1 | from data.cub200 import cub200 2 | from data.stanford_dogs import stanford_dogs 3 | from data.flowers import flowers 4 | from data.cars import cars 5 | from data.aircrafts import aircrafts 6 | 7 | def generate_dataloader(args, process_name, part_index=-1): 8 | print('the required dataset is', args.dataset) 9 | if args.dataset == 'cub200': 10 | train_loader, val_loader = cub200(args, process_name, part_index) 11 | elif args.dataset == 'stanford_dogs': 12 | raise ValueError('the required dataset is not prepared') 13 | # train_loader, val_loader = stanford_dogs(args, process_name, part_index) 14 | elif args.dataset == 'flowers': 15 | train_loader, val_loader = flowers(args, process_name, part_index) 16 | elif args.dataset == 'cars': 17 | raise ValueError('the required dataset is not prepared') 18 | # train_loader, val_loader = cars(args, process_name, part_index) 19 | elif args.dataset == 'aircrafts': 20 | raise ValueError('the required dataset is not prepared') 21 | # train_loader, val_loader = aircrafts(args, process_name, part_index) 22 | else: 23 | raise ValueError('the required dataset is not prepared') 24 | 25 | return train_loader, val_loader 26 | 27 | -------------------------------------------------------------------------------- /data/stanford_dogs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import torchvision.transforms as transforms 5 | from data.folder_train import ImageFolder_train 6 | from data.folder_download import ImageFolder_download 7 | 8 | def split_train_test_images(data_dir): 9 | #data_dir = '/data/Stanford_dog/' 10 | raise ValueError('the process of generate splited training and test images is empty') 11 | 12 | def stanford_dogs(args, process_name, part_index): 13 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 14 | std=[0.229, 0.224, 0.225]) 15 | if process_name == 'image_classifier' or process_name == 'partnet' or process_name == 'download_proposals': 16 | traindir = os.path.join(args.data_path, 'splited_image/train') 17 | valdir = os.path.join(args.data_path, 'splited_image/val') 18 | if not os.path.isdir(traindir): 19 | print('the cub images are not well splited, split all images into train and val set') 20 | split_train_test_images(args.data_path) 21 | if process_name == 'download_proposals': 22 | 
print('the image pre-process for process: download proposals is Resize 512 and Center Crop 448') 23 | train_dataset = ImageFolder_download( 24 | root=traindir, 25 | transform=transforms.Compose([ 26 | transforms.Resize(512), 27 | transforms.CenterCrop(448), 28 | transforms.ToTensor(), 29 | normalize, 30 | ]), 31 | transform_keep=transforms.Compose([ 32 | transforms.Resize(512), 33 | transforms.CenterCrop(448), 34 | transforms.ToTensor() 35 | ]), 36 | dataset_path=args.data_path 37 | ) 38 | train_loader = torch.utils.data.DataLoader( 39 | train_dataset, batch_size=args.batch_size_partnet, shuffle=False, num_workers=args.workers, 40 | pin_memory=True, sampler=None 41 | ) 42 | val_loader = torch.utils.data.DataLoader( 43 | ImageFolder_download( 44 | root=valdir, 45 | transform=transforms.Compose([ 46 | transforms.Resize(512), 47 | transforms.CenterCrop(448), 48 | transforms.ToTensor(), 49 | normalize, 50 | ]), 51 | transform_keep=transforms.Compose([ 52 | transforms.Resize(512), 53 | transforms.CenterCrop(448), 54 | transforms.ToTensor(), 55 | ]), 56 | dataset_path=args.data_path), 57 | batch_size=args.batch_size_partnet, shuffle=False, 58 | num_workers=args.workers, pin_memory=True 59 | ) 60 | return train_loader, val_loader 61 | elif process_name == 'image_classifier': 62 | print('the image pre-process for process: image_classifier is Resize 512 and Random Crop 448') 63 | train_dataset = ImageFolder_train( 64 | traindir, 65 | transforms.Compose([ 66 | transforms.Resize(512), 67 | transforms.RandomCrop(448), 68 | transforms.ToTensor(), 69 | normalize, 70 | ]) 71 | ) 72 | train_loader = torch.utils.data.DataLoader( 73 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 74 | pin_memory=True, sampler=None 75 | ) 76 | val_loader = torch.utils.data.DataLoader( 77 | ImageFolder_train(valdir, transforms.Compose([ 78 | transforms.Resize(512), 79 | transforms.CenterCrop(448), 80 | transforms.ToTensor(), 81 | normalize, 82 | ])), 83 | batch_size=args.batch_size, shuffle=False, 84 | num_workers=args.workers, pin_memory=True 85 | ) 86 | return train_loader, val_loader 87 | elif process_name == 'partnet': 88 | print('the image pre-process for process: partnet is Resize 512 and Random Crop 448') 89 | train_dataset = ImageFolder_train( 90 | traindir, 91 | transforms.Compose([ 92 | transforms.Resize(512), 93 | transforms.RandomCrop(448), 94 | transforms.ToTensor(), 95 | normalize, 96 | ]) 97 | ) 98 | train_loader = torch.utils.data.DataLoader( 99 | train_dataset, batch_size=args.batch_size_partnet, shuffle=True, num_workers=args.workers, 100 | pin_memory=True, sampler=None 101 | ) 102 | val_loader = torch.utils.data.DataLoader( 103 | ImageFolder_train(valdir, transforms.Compose([ 104 | transforms.Resize(512), 105 | transforms.CenterCrop(448), 106 | transforms.ToTensor(), 107 | normalize, 108 | ])), 109 | batch_size=args.batch_size_partnet, shuffle=False, 110 | num_workers=args.workers, pin_memory=True 111 | ) 112 | return train_loader, val_loader 113 | elif process_name == 'part_classifiers': 114 | traindir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/train/' 115 | valdir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/val/' 116 | print('the image pre-process for process: part_classifier is Resize (448, 448) directly') 117 | train_dataset = ImageFolder_train( 118 | traindir, 119 | transforms.Compose([ 120 | transforms.Resize((448, 448)), 121 | transforms.ToTensor(), 122 | normalize, 123 | ]) 
124 | ) 125 | train_loader = torch.utils.data.DataLoader( 126 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, 127 | pin_memory=True, sampler=None 128 | ) 129 | val_loader = torch.utils.data.DataLoader( 130 | ImageFolder_train(valdir, transforms.Compose([ 131 | transforms.Resize((448, 448)), 132 | transforms.ToTensor(), 133 | normalize, 134 | ])), 135 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers, 136 | pin_memory=True, sampler=None 137 | ) 138 | return train_loader, val_loader 139 | 140 | 141 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ############################################################################## 2 | # Pytorch PartNet 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Yabin Zhang 5 | ############################################################################## 6 | 7 | from opts import opts # The options for the project 8 | from process import Process1_Image_Classifier 9 | from process import Process2_PartNet 10 | from process import Process3_Download_Proposals 11 | from process import Process4_Part_Classifiers 12 | from process import Process5_Final_Result 13 | import os 14 | def main(): 15 | global args 16 | args = opts() 17 | if args.test_only: 18 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 19 | print('download the detected part images to disk') 20 | Process3_Download_Proposals(args) ### We use hooks in this process, so only one GPU should be used. 21 | os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3,4,5,6,7" 22 | print('only test the final results; please make sure all the needed files are prepared.') 23 | Process5_Final_Result(args) 24 | else: 25 | # Fine-tune an ImageNet pre-trained model on the fine-grained dataset 26 | os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3,4,5,6,7" 27 | Process1_Image_Classifier(args) 28 | # Train the PartNet on the fine-grained dataset, starting from the above fine-tuned model 29 | Process2_PartNet(args) 30 | # Download the proposals detected by the PartNet 31 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 32 | Process3_Download_Proposals(args) ### We use hooks in this process, so only one GPU should be used.
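# (A note on the single-GPU restriction above, our reading: the proposal
# download captures intermediate tensors via forward hooks, and under
# nn.DataParallel hooks fire independently on each replica without their
# outputs being gathered, so running on one device keeps the captured
# tensors aligned with the input batch.)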
33 | os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3,4,5,6,7" 34 | # Train an individual classifier for each part 35 | Process4_Part_Classifiers(args) 36 | # Average the probabilities of the above models (image level + parts 1-3 + PartNet) for the final prediction 37 | Process5_Final_Result(args) 38 | 39 | if __name__ == '__main__': 40 | main() 41 | 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /models/DPP/Readme.md: -------------------------------------------------------------------------------- 1 | # Discretized Part Proposals (DPP) 2 | A C implementation of the DPP module. 3 | ## Compile Process 4 | Requirements: the same as for https://github.com/jwyang/faster-rcnn.pytorch 5 | 6 | Command to compile: 7 | ```bash 8 | python build.py 9 | ``` 10 | 11 | -------------------------------------------------------------------------------- /models/DPP/_ext/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/_ext/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/DPP/_ext/dpp/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._dpp import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /models/DPP/_ext/dpp/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/_ext/dpp/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/DPP/_ext/dpp/_dpp.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/_ext/dpp/_dpp.so -------------------------------------------------------------------------------- /models/DPP/build.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | from torch.utils.ffi import create_extension 5 | 6 | 7 | sources = ['src/dpp.c'] 8 | headers = ['src/dpp.h'] 9 | defines = [] 10 | with_cuda = False 11 | 12 | # if torch.cuda.is_available(): 13 | # print('Including CUDA code.') 14 | # sources += ['src/roi_pooling_cuda.c'] 15 | # headers += ['src/roi_pooling_cuda.h'] 16 | # defines += [('WITH_CUDA', None)] 17 | # with_cuda = True 18 | 19 | this_file = os.path.dirname(os.path.realpath(__file__)) 20 | print(this_file) 21 | #extra_objects = ['src/dpp.cu.o'] 22 | #extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 23 | 24 | ffi = create_extension( 25 | '_ext.dpp', 26 | headers=headers, 27 | sources=sources, 28 | define_macros=defines, 29 | relative_to=__file__, 30 | with_cuda=with_cuda # , 31 | # extra_objects=extra_objects 32 | ) 33 | 34 | if 
__name__ == '__main__': 35 | ffi.build() 36 | -------------------------------------------------------------------------------- /models/DPP/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/DPP/functions/__pycache__/dpp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/functions/__pycache__/dpp.cpython-36.pyc -------------------------------------------------------------------------------- /models/DPP/functions/dpp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from .._ext import dpp 4 | import pdb 5 | import time 6 | 7 | class DPPFunction(Function): 8 | def __init__(ctx, square_size, proposals_per_square, proposals_per_image, spatial_scale): 9 | ctx.square_size = square_size 10 | ctx.proposals_per_square = proposals_per_square 11 | ctx.spatital_scale = spatial_scale 12 | ctx.output = torch.cuda.FloatTensor() 13 | ctx.proposals_per_image = proposals_per_image 14 | ctx.box_plan = torch.Tensor([[-1, -1, 1, 1], 15 | [-2, -2, 2, 2], 16 | [-1, -3, 1, 3], 17 | [-3, -1, 3, 1], 18 | [-3, -3, 3, 3], 19 | [-2, -4, 2, 4], 20 | [-4, -2, 4, 2], 21 | [-4, -4, 4, 4], 22 | [-3, -5, 3, 5], 23 | [-5, -3, 5, 3], # 10 24 | [-5, -5, 5, 5], 25 | [-4, -7, 4, 7], 26 | [-7, -4, 7, 4], 27 | [-6, -6, 6, 6], 28 | [-4, -8, 4, 8], # 15 29 | [-8, -4, 8, 4], 30 | [-7, -7, 7, 7], 31 | [-5, -10, 5, 10], 32 | [-10, -5, 10, 5], 33 | [-8, -8, 8, 8], # 20 34 | [-6, -11, 6, 11], 35 | [-11, -6, 11, 6], 36 | [-9, -9, 9, 9], 37 | [-7, -12, 7, 12], 38 | [-12, -7, 12, 12], 39 | [-10, -10, 10, 10], 40 | [-7, -14, 7, 14], 41 | [-14, -7, 14, 7]]) 42 | ctx.square_num = int(28 / ctx.square_size) ### it should be 7 43 | calculate_num = ctx.square_num * ctx.square_num * proposals_per_square 44 | if ctx.proposals_per_image != calculate_num: 45 | raise ValueError('the number generated by dpp should be', calculate_num) 46 | if ctx.square_size != 4 and ctx.square_size != 7: 47 | raise ValueError('the number of the square for one line should be 4 or 7, but you define:', 48 | ctx.square_size) 49 | if ctx.proposals_per_square > 28: 50 | raise ValueError('the proposals number for each image should below 28') 51 | if ctx.spatital_scale != 16: 52 | raise ValueError('the spatial scale should be 16, but you define is:', ctx.spatital_scale) 53 | 54 | def forward(ctx, features): 55 | #timer= time.time() 56 | batch_size, num_channels, data_height, data_width = features.size() 57 | num_rois = batch_size * ctx.proposals_per_image 58 | output = torch.FloatTensor(num_rois, 5).fill_(1) 59 | histogram = torch.FloatTensor(data_height, data_width).fill_(0) 60 | score_sum = torch.FloatTensor(data_height, data_width).fill_(0) 61 | 62 | if features.is_cuda: 63 | # print('the DPP is in the cpu form') 64 | dpp.dpp_forward(ctx.square_size, ctx.proposals_per_square, ctx.proposals_per_image, ctx.spatital_scale, 65 | ctx.box_plan, histogram, score_sum, output, features.cpu()) 66 | # print('the dpp forward time is:', time.time() - timer) 67 | # print('the output of dpp is:', output) 68 | else: 69 | 
69 | dpp.dpp_forward(ctx.square_size, ctx.proposals_per_square, ctx.proposals_per_image, ctx.spatial_scale, 70 | ctx.box_plan, histogram, score_sum, output, features) 71 | # print('the dpp forward time is:', time.time() - timer) 72 | # print('the output of dpp is:', output) 73 | 74 | return output.cuda() 75 | 76 | def backward(ctx, grad_output): 77 | 78 | return None 79 | -------------------------------------------------------------------------------- /models/DPP/functions/test_dpp.py: -------------------------------------------------------------------------------- 1 | from .._ext import dpp 2 | import torch 3 | output = torch.FloatTensor(2744, 5).fill_(0) 4 | histogram = torch.FloatTensor(28, 28).fill_(0) 5 | score_sum = torch.FloatTensor(28, 28).fill_(0) 6 | features = torch.randn(2, 512, 28, 28) 7 | box_plan = torch.Tensor([[-1, -1, 1, 1], 8 | [-2, -2, 2, 2], 9 | [-1, -3, 1, 3], 10 | [-3, -1, 3, 1], 11 | [-3, -3, 3, 3], 12 | [-2, -4, 2, 4], 13 | [-4, -2, 4, 2], 14 | [-4, -4, 4, 4], 15 | [-3, -5, 3, 5], 16 | [-5, -3, 5, 3], # 10 17 | [-5, -5, 5, 5], 18 | [-4, -7, 4, 7], 19 | [-7, -4, 7, 4], 20 | [-6, -6, 6, 6], 21 | [-4, -8, 4, 8], # 15 22 | [-8, -4, 8, 4], 23 | [-7, -7, 7, 7], 24 | [-5, -10, 5, 10], 25 | [-10, -5, 10, 5], 26 | [-8, -8, 8, 8], # 20 27 | [-6, -11, 6, 11], 28 | [-11, -6, 11, 6], 29 | [-9, -9, 9, 9], 30 | [-7, -12, 7, 12], 31 | [-12, -7, 12, 12], 32 | [-10, -10, 10, 10], 33 | [-7, -14, 7, 14], 34 | [-14, -7, 14, 7]]) 35 | dpp.dpp_forward(4, 28, 1372, 16, 36 | box_plan, histogram, score_sum, output, features) 37 | 38 | print(output) -------------------------------------------------------------------------------- /models/DPP/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/DPP/modules/__pycache__/dpp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/modules/__pycache__/dpp.cpython-36.pyc -------------------------------------------------------------------------------- /models/DPP/modules/dpp.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from ..functions.dpp import DPPFunction 3 | 4 | 5 | class _DPP(Module): 6 | def __init__(self, square_size, proposals_per_square, proposals_per_image, spatial_scale): 7 | super(_DPP, self).__init__() 8 | 9 | self.square_size = int(square_size) 10 | self.proposals_per_square = int(proposals_per_square) 11 | self.proposals_per_image = int(proposals_per_image) 12 | self.spatial_scale = int(spatial_scale) 13 | 14 | def forward(self, features): 15 | return DPPFunction(self.square_size, self.proposals_per_square, self.proposals_per_image, self.spatial_scale)(features) 16 | -------------------------------------------------------------------------------- /models/DPP/src/dpp.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | int dpp_forward(int square_size, int proposals_per_square, int proposals_per_image, int spatital_scale, THFloatTensor * box_plan, 6 | THFloatTensor * histogram, THFloatTensor * score_sum,
THFloatTensor * output, THFloatTensor * features) 7 | { 8 | // Grab the input tensor 9 | float * box_plan_flat = THFloatTensor_data(box_plan); 10 | float * histogram_flat = THFloatTensor_data(histogram); 11 | float * score_sum_flat = THFloatTensor_data(score_sum); 12 | float * output_flat = THFloatTensor_data(output); 13 | float * features_flat = THFloatTensor_data(features); 14 | 15 | int batch_size = THFloatTensor_size(features, 0); 16 | 17 | int data_height = THFloatTensor_size(features, 2); 18 | // data width 19 | int data_width = THFloatTensor_size(features, 3); 20 | // Number of channels 21 | int num_channels = THFloatTensor_size(features, 1); 22 | 23 | int b, c, h, w, max_h, max_w, max_value, p; 24 | int index_proposals = 0; 25 | 26 | for (b = 0; b < batch_size; b++) 27 | { 28 | // printf("here is the batch of %d", b); 29 | THFloatStorage_fill(THFloatTensor_storage(histogram), 0); 30 | // printf("here is the line after fill lalalla "); 31 | THFloatStorage_fill(THFloatTensor_storage(score_sum), 0); 32 | 33 | for (c =0; c < num_channels; c++) 34 | { 35 | // printf("histogram of %d", c); 36 | //find the max position for each channel, and add 1 to the histogram 37 | const int index_features = b * data_height * data_width * num_channels + c * data_height * data_width; 38 | max_w = 0; 39 | max_h = 0; 40 | max_value = 0; 41 | 42 | for (h=0; h < data_height; h++) 43 | { 44 | for (w=0; w< data_width; w++) 45 | { 46 | const int index_histogram = h * data_width + w; 47 | if(features_flat[index_features + index_histogram] > max_value) 48 | { 49 | max_value = features_flat[index_features + index_histogram]; 50 | max_w = w; 51 | max_h = h; 52 | } 53 | } 54 | } 55 | histogram_flat[max_h * data_width + max_w] += 1; 56 | } 57 | 58 | /// calculate the score sum 59 | for (c =0; c < num_channels; c++) 60 | { 61 | // printf("score sum of %d", c); 62 | //add values at each position of the feature to the score sum 63 | const int index_features = b * data_height * data_width * num_channels + c * data_height * data_width; 64 | 65 | for (h=0; h < data_height; h++) 66 | { 67 | for (w=0; w< data_width; w++) 68 | { 69 | const int index_score_sum = h *data_width + w; 70 | score_sum_flat[index_score_sum] += features_flat[index_features + index_score_sum]; 71 | } 72 | } 73 | } 74 | 75 | // 76 | int sub_h, sub_w; 77 | for (h=0; h 28: 203 | # raise ValueError('the proposals number for each image should below 28') 204 | # 205 | # 206 | # def forward(self, features): 207 | # timer = time.time() 208 | # batch_size, num_channels, data_height, data_width = features.size() 209 | # features_float = features.float() 210 | # 211 | # num_rois = batch_size * self.proposals_per_image 212 | # output = torch.Tensor(num_rois, 5).fill_(0) 213 | # 214 | # if self.spatital_scale != 16: 215 | # raise ValueError('the spatial scale should be 16, but you define is:', self.spatital_scale) 216 | # proposals_index = 0 217 | # 218 | # for i in range(batch_size): 219 | # roi_memory = torch.Tensor(28,28).zero_() 220 | # roi_score = torch.sum(features_float[i], 0) 221 | # output[i*self.proposals_per_image: (i+1)*self.proposals_per_image, 0] = i 222 | # for j in range(num_channels): 223 | # one_channel = features_float[i][j] 224 | # max_per_row, max_column_per_row = torch.max(one_channel, 1) 225 | # max_p, max_row = torch.max(max_per_row, 0) 226 | # x_center = max_row[0] 227 | # max_col = max_column_per_row[x_center] 228 | # y_center = max_col 229 | # roi_memory[x_center][y_center] = roi_memory[x_center][y_center] + 1 230 | # for x in 
range(self.square_num): 231 | # for y in range(self.square_num): 232 | # temp = roi_memory[x * self.square_size:(x+1)*self.square_size, y*self.square_size: (y+1)*self.square_size] 233 | # max_per_row, max_column_per_row = torch.max(temp, 1) 234 | # max_p, max_row = torch.max(max_per_row, 0) 235 | # x_center = max_row[0] 236 | # max_col = max_column_per_row[x_center] 237 | # y_center = max_col 238 | # find_repeat = torch.eq(temp, temp[x_center][y_center]) 239 | # if torch.sum(find_repeat) > 1: 240 | # score_temp = roi_score[x * self.square_size:(x+1)*self.square_size, y*self.square_size:(y+1)*self.square_size] 241 | # max_per_row, max_column_per_row = torch.max(score_temp, 1) 242 | # max_p, max_row = torch.max(max_per_row, 0) 243 | # x_center = max_row[0] 244 | # max_col = max_column_per_row[x_center] 245 | # y_center = max_col 246 | # x_center = x_center + x * self.square_size 247 | # y_center = y_center + y * self.square_size 248 | # for k in range(self.proposals_per_square): 249 | # output[proposals_index][1:5] = torch.Tensor([x_center +self.box_plan[k][0], y_center+self.box_plan[k][1], x_center+self.box_plan[k][2], y_center+self.box_plan[k][3]]) 250 | # proposals_index = proposals_index + 1 251 | # output[:, 1:5].mul_(self.spatital_scale) 252 | # output[:, 1:5].clamp_(0, 448) 253 | # #self.output.resize_(output.size()).copy_(output, non_blocking=False) 254 | # print('the dpp time is:', time.time() - timer) 255 | # return output.cuda() 256 | # 257 | # 258 | # def backward(self, grad_output): 259 | # 260 | # return None 261 | 262 | # class DPP(Module): 263 | # def __init__(self, square_size, proposals_per_square, proposals_per_image, spatial_scale): 264 | # super(DPP, self).__init__() 265 | # 266 | # self.square_size = int(square_size) 267 | # self.proposals_per_square = int(proposals_per_square) 268 | # self.proposals_per_image = int(proposals_per_image) 269 | # self.spatial_scale = int(spatial_scale) 270 | # 271 | # def forward(self, features): 272 | # return DPPFunction(self.square_size, self.proposals_per_square, self.proposals_per_image, self.spatial_scale)(features) 273 | 274 | class Classification_stream(Module): 275 | def __init__(self, proposal_num, num_classes): 276 | super(Classification_stream, self).__init__() 277 | self.proposals_num = proposal_num 278 | self.num_classes = num_classes 279 | self.classifier = nn.Sequential( 280 | nn.Linear(512 * 3 * 3, 4096), 281 | nn.ReLU(True), 282 | nn.Dropout(), 283 | nn.Linear(4096, self.num_classes+1), 284 | ) 285 | self.softmax = nn.Softmax(2) 286 | def forward(self, x): 287 | x = self.classifier(x) 288 | x = x.view(-1, self.proposals_num, self.num_classes+1) 289 | x = self.softmax(x) 290 | x = x.narrow(2, 0, self.num_classes) 291 | return x 292 | 293 | class Detection_stream(Module): 294 | def __init__(self, proposal_num, part_num): 295 | super(Detection_stream, self).__init__() 296 | self.proposals_num = proposal_num 297 | self.part_num = part_num 298 | self.detector = nn.Sequential( 299 | nn.Linear(512*3*3, 4096), 300 | nn.ReLU(True), 301 | nn.Dropout(), 302 | nn.Linear(4096, self.part_num+1), 303 | ) 304 | self.softmax_cls = nn.Softmax(2) 305 | self.softmax_nor = nn.Softmax(1) 306 | def forward(self, x): 307 | x = self.detector(x) 308 | x = x.view(-1, self.proposals_num, self.part_num+1) 309 | x = self.softmax_cls(x) 310 | x = x.narrow(2, 0, self.part_num) 311 | x = self.softmax_nor(x) 312 | return x 313 | 314 | class construct_partnet(nn.Module): 315 | def __init__(self, conv_model, args): 316 | super(construct_partnet, 
self).__init__() 317 | self.conv_model = conv_model 318 | self.roi_pool = _RoIPooling(3, 3, 1.0/16) 319 | self.args = args 320 | self.DPP = _DPP(args.square_size, args.proposals_per_square, args.proposals_num, args.stride) 321 | self.classification_stream = Classification_stream(args.proposals_num, args.num_classes) 322 | self.detection_stream = Detection_stream(args.proposals_num, args.num_part) 323 | 324 | def forward(self, x): 325 | x = self.conv_model(x) 326 | rois = self.DPP(x) 327 | x = self.roi_pool(x, rois) 328 | x = x.view(x.size(0), -1) 329 | x_c = self.classification_stream(x) 330 | x_d = self.detection_stream(x) 331 | x_c = x_c.transpose(1, 2) 332 | mix = torch.matmul(x_c, x_d) 333 | x = mix.mean(2) 334 | 335 | return x 336 | 337 | def vgg16_bn(args, **kwargs): 338 | """VGG 16-layer model (configuration "D") with batch normalization 339 | 340 | Args: 341 | pretrained (bool): If True, returns a model pre-trained on ImageNet 342 | """ 343 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 344 | if args.pretrain: 345 | print('load the imageNet pretrained model') 346 | pretrained_dict = model_zoo.load_url(model_urls['vgg16_bn']) 347 | model_dict = model.state_dict() 348 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 349 | model_dict.update(pretrained_dict_temp) 350 | model.load_state_dict(model_dict) 351 | print(args.finetuned_model) 352 | if args.finetuned_model != '': 353 | print('load the model that has been finetuned on cub', args.finetuned_model) 354 | pretrained_dict = torch.load(args.finetuned_model)['state_dict'] 355 | pretrained_dict_temp = copy.deepcopy(pretrained_dict) 356 | model_state_dict = model.state_dict() 357 | 358 | for k_tmp in pretrained_dict_temp.keys(): 359 | if k_tmp.find('module.base_conv') != -1: 360 | k = k_tmp.replace('module.base_conv.', '') 361 | pretrained_dict[k] = pretrained_dict.pop(k_tmp) 362 | # ipdb.set_trace() 363 | # print(pretrained_dict) 364 | pretrained_dict_temp2 = {k: v for k, v in pretrained_dict.items() if k in model_state_dict} 365 | # print(pretrained_dict_temp2) 366 | model_state_dict.update(pretrained_dict_temp2) 367 | model.load_state_dict(model_state_dict) 368 | 369 | print('here load the fine tuned conv layer to the model') 370 | 371 | ### change the model to PartNet by ADD: DPP + ROIPooling + two stream. 
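# Dataflow of construct_partnet (see forward above): conv features -> DPP part proposals
# -> 3x3 RoI pooling -> classification stream (batch x proposals x classes) and detection
# stream (batch x proposals x parts); the two are fused by a matrix product over the
# proposals, and the result is averaged over the parts to give the class scores.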
372 | PartNet = construct_partnet(model, args) 373 | 374 | return PartNet 375 | 376 | 377 | def vgg19(pretrained=False, **kwargs): 378 | """VGG 19-layer model (configuration "E") 379 | 380 | Args: 381 | pretrained (bool): If True, returns a model pre-trained on ImageNet 382 | """ 383 | model = VGG(make_layers(cfg['E']), **kwargs) 384 | if pretrained: 385 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 386 | return model 387 | 388 | 389 | 390 | 391 | def vgg19_bn(args, **kwargs): 392 | """VGG 19-layer model (configuration 'E') with batch normalization 393 | 394 | Args: 395 | pretrained (bool): If True, returns a model pre-trained on ImageNet 396 | """ 397 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 398 | finetuned_model = args.dataset + '/Image_Classifier/model_best.pth.tar' 399 | if args.pretrain: 400 | print('load the imageNet pretrained model') 401 | pretrained_dict = model_zoo.load_url(model_urls['vgg19_bn']) 402 | model_dict = model.state_dict() 403 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 404 | model_dict.update(pretrained_dict_temp) 405 | model.load_state_dict(model_dict) 406 | if finetuned_model != '': 407 | print('load the model that has been finetuned on cub', finetuned_model) 408 | pretrained_dict = torch.load(finetuned_model)['state_dict'] 409 | pretrained_dict_temp = copy.deepcopy(pretrained_dict) 410 | model_state_dict = model.state_dict() 411 | 412 | for k_tmp in pretrained_dict_temp.keys(): 413 | if k_tmp.find('module.base_conv') != -1: 414 | k = k_tmp.replace('module.base_conv.', '') 415 | pretrained_dict[k] = pretrained_dict.pop(k_tmp) 416 | pretrained_dict_temp2 = {k: v for k, v in pretrained_dict.items() if k in model_state_dict} 417 | model_state_dict.update(pretrained_dict_temp2) 418 | model.load_state_dict(model_state_dict) 419 | PartNet = construct_partnet(model, args) 420 | return PartNet 421 | 422 | 423 | def partnet_vgg(args, **kwargs): 424 | print("==> creating model '{}' ".format(args.arch)) 425 | if args.arch == 'vgg19_bn': 426 | return vgg19_bn(args) 427 | elif args.arch == 'vgg16_bn': 428 | return vgg16_bn(args) 429 | else: 430 | raise ValueError('Unrecognized model architecture', args.arch) 431 | -------------------------------------------------------------------------------- /models/roi_pooling/Readme.md: -------------------------------------------------------------------------------- 1 | # ROI Pooling 2 | We use the released code from: https://github.com/jwyang/faster-rcnn.pytorch/tree/master/lib/model/roi_pooling 3 | 4 | -------------------------------------------------------------------------------- /models/roi_pooling/_ext/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/_ext/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/roi_pooling/_ext/roi_pooling/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._roi_pooling import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | 
_import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /models/roi_pooling/_ext/roi_pooling/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/_ext/roi_pooling/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/roi_pooling/_ext/roi_pooling/_roi_pooling.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/_ext/roi_pooling/_roi_pooling.so -------------------------------------------------------------------------------- /models/roi_pooling/build.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | from torch.utils.ffi import create_extension 5 | 6 | 7 | sources = ['src/roi_pooling.c'] 8 | headers = ['src/roi_pooling.h'] 9 | defines = [] 10 | with_cuda = False 11 | 12 | if torch.cuda.is_available(): 13 | print('Including CUDA code.') 14 | sources += ['src/roi_pooling_cuda.c'] 15 | headers += ['src/roi_pooling_cuda.h'] 16 | defines += [('WITH_CUDA', None)] 17 | with_cuda = True 18 | 19 | this_file = os.path.dirname(os.path.realpath(__file__)) 20 | print(this_file) 21 | extra_objects = ['src/roi_pooling.cu.o'] 22 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 23 | 24 | ffi = create_extension( 25 | '_ext.roi_pooling', 26 | headers=headers, 27 | sources=sources, 28 | define_macros=defines, 29 | relative_to=__file__, 30 | with_cuda=with_cuda, 31 | extra_objects=extra_objects 32 | ) 33 | 34 | if __name__ == '__main__': 35 | ffi.build() 36 | -------------------------------------------------------------------------------- /models/roi_pooling/functions/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/functions/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/roi_pooling/functions/__pycache__/roi_pool.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/functions/__pycache__/roi_pool.cpython-36.pyc -------------------------------------------------------------------------------- /models/roi_pooling/functions/roi_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from .._ext import roi_pooling 4 | # import time 5 | 6 | class RoIPoolFunction(Function): 7 | def __init__(ctx, pooled_height, pooled_width, spatial_scale): 8 | ctx.pooled_width = pooled_width 9 | ctx.pooled_height = pooled_height 10 | ctx.spatial_scale = spatial_scale 11 | ctx.feature_size = None 12 | 13 | def forward(ctx, features, rois): 14 | # timer = time.time() 15 | ctx.feature_size = features.size() 16 | batch_size, num_channels, data_height, data_width = ctx.feature_size 17 | num_rois = rois.size(0) 18 | output = 
features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_() 19 | ctx.argmax = features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_().int() 20 | ctx.rois = rois 21 | if not features.is_cuda: 22 | _features = features.permute(0, 2, 3, 1) 23 | roi_pooling.roi_pooling_forward(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale, 24 | _features, rois, output) 25 | else: 26 | roi_pooling.roi_pooling_forward_cuda(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale, 27 | features, rois, output, ctx.argmax) 28 | # print('the forward time of roipooling is:', time.time() - timer) 29 | return output 30 | 31 | def backward(ctx, grad_output): 32 | assert(ctx.feature_size is not None and grad_output.is_cuda) 33 | batch_size, num_channels, data_height, data_width = ctx.feature_size 34 | grad_input = grad_output.new(batch_size, num_channels, data_height, data_width).zero_() 35 | 36 | roi_pooling.roi_pooling_backward_cuda(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale, 37 | grad_output, ctx.rois, grad_input, ctx.argmax) 38 | 39 | return grad_input, None 40 | -------------------------------------------------------------------------------- /models/roi_pooling/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/roi_pooling/modules/__pycache__/roi_pool.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/modules/__pycache__/roi_pool.cpython-36.pyc -------------------------------------------------------------------------------- /models/roi_pooling/modules/__pycache__/roi_pool_py.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/modules/__pycache__/roi_pool_py.cpython-36.pyc -------------------------------------------------------------------------------- /models/roi_pooling/modules/roi_pool.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from ..functions.roi_pool import RoIPoolFunction 3 | 4 | 5 | class _RoIPooling(Module): 6 | def __init__(self, pooled_height, pooled_width, spatial_scale): 7 | super(_RoIPooling, self).__init__() 8 | 9 | self.pooled_width = int(pooled_width) 10 | self.pooled_height = int(pooled_height) 11 | self.spatial_scale = float(spatial_scale) 12 | 13 | def forward(self, features, rois): 14 | return RoIPoolFunction(self.pooled_height, self.pooled_width, self.spatial_scale)(features, rois) 15 | -------------------------------------------------------------------------------- /models/roi_pooling/src/roi_pooling.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | int roi_pooling_forward(int pooled_height, int pooled_width, float spatial_scale, 5 | THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output) 6 | { 7 | // Grab the input tensor 8 | float * data_flat = THFloatTensor_data(features); 9 | float * rois_flat = 
THFloatTensor_data(rois); 10 | 11 | float * output_flat = THFloatTensor_data(output); 12 | 13 | // Number of ROIs 14 | int num_rois = THFloatTensor_size(rois, 0); 15 | int size_rois = THFloatTensor_size(rois, 1); 16 | // batch size 17 | int batch_size = THFloatTensor_size(features, 0); 18 | if(batch_size != 1) 19 | { 20 | return 0; 21 | } 22 | // data height 23 | int data_height = THFloatTensor_size(features, 1); 24 | // data width 25 | int data_width = THFloatTensor_size(features, 2); 26 | // Number of channels 27 | int num_channels = THFloatTensor_size(features, 3); 28 | 29 | // Set all element of the output tensor to -inf. 30 | THFloatStorage_fill(THFloatTensor_storage(output), -1); 31 | 32 | // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R 33 | int index_roi = 0; 34 | int index_output = 0; 35 | int n; 36 | for (n = 0; n < num_rois; ++n) 37 | { 38 | int roi_batch_ind = rois_flat[index_roi + 0]; 39 | int roi_start_w = round(rois_flat[index_roi + 1] * spatial_scale); 40 | int roi_start_h = round(rois_flat[index_roi + 2] * spatial_scale); 41 | int roi_end_w = round(rois_flat[index_roi + 3] * spatial_scale); 42 | int roi_end_h = round(rois_flat[index_roi + 4] * spatial_scale); 43 | // CHECK_GE(roi_batch_ind, 0); 44 | // CHECK_LT(roi_batch_ind, batch_size); 45 | 46 | int roi_height = fmaxf(roi_end_h - roi_start_h + 1, 1); 47 | int roi_width = fmaxf(roi_end_w - roi_start_w + 1, 1); 48 | float bin_size_h = (float)(roi_height) / (float)(pooled_height); 49 | float bin_size_w = (float)(roi_width) / (float)(pooled_width); 50 | 51 | int index_data = roi_batch_ind * data_height * data_width * num_channels; 52 | const int output_area = pooled_width * pooled_height; 53 | 54 | int c, ph, pw; 55 | for (ph = 0; ph < pooled_height; ++ph) 56 | { 57 | for (pw = 0; pw < pooled_width; ++pw) 58 | { 59 | int hstart = (floor((float)(ph) * bin_size_h)); 60 | int wstart = (floor((float)(pw) * bin_size_w)); 61 | int hend = (ceil((float)(ph + 1) * bin_size_h)); 62 | int wend = (ceil((float)(pw + 1) * bin_size_w)); 63 | 64 | hstart = fminf(fmaxf(hstart + roi_start_h, 0), data_height); 65 | hend = fminf(fmaxf(hend + roi_start_h, 0), data_height); 66 | wstart = fminf(fmaxf(wstart + roi_start_w, 0), data_width); 67 | wend = fminf(fmaxf(wend + roi_start_w, 0), data_width); 68 | 69 | const int pool_index = index_output + (ph * pooled_width + pw); 70 | int is_empty = (hend <= hstart) || (wend <= wstart); 71 | if (is_empty) 72 | { 73 | for (c = 0; c < num_channels * output_area; c += output_area) 74 | { 75 | output_flat[pool_index + c] = 0; 76 | } 77 | } 78 | else 79 | { 80 | int h, w, c; 81 | for (h = hstart; h < hend; ++h) 82 | { 83 | for (w = wstart; w < wend; ++w) 84 | { 85 | for (c = 0; c < num_channels; ++c) 86 | { 87 | const int index = (h * data_width + w) * num_channels + c; 88 | if (data_flat[index_data + index] > output_flat[pool_index + c * output_area]) 89 | { 90 | output_flat[pool_index + c * output_area] = data_flat[index_data + index]; 91 | } 92 | } 93 | } 94 | } 95 | } 96 | } 97 | } 98 | 99 | // Increment ROI index 100 | index_roi += size_rois; 101 | index_output += pooled_height * pooled_width * num_channels; 102 | } 103 | return 1; 104 | } -------------------------------------------------------------------------------- /models/roi_pooling/src/roi_pooling.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/src/roi_pooling.cu.o 
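For reference, a minimal sketch of driving the `_RoIPooling` wrapper above (illustrative only, not a file from this repository; it assumes the extension has been compiled, a CUDA device is available, and the repository root is on `PYTHONPATH`; the shapes follow `construct_partnet`, which pools 3x3 bins from stride-16 features of 448x448 inputs, and the ROI layout follows the `[batch_index, x1, y1, x2, y2]` convention documented in `roi_pooling.c`):

```python
import torch
from torch.autograd import Variable  # this code base targets the pre-0.4 Variable API
from models.roi_pooling.modules.roi_pool import _RoIPooling

# pooled_height=3, pooled_width=3, spatial_scale=1/16, as in construct_partnet
roi_pool = _RoIPooling(3, 3, 1.0 / 16)

features = Variable(torch.randn(1, 512, 28, 28).cuda())   # conv features of one 448x448 image
rois = Variable(torch.Tensor([[0, 64, 64, 192, 192],      # [batch_index, x1, y1, x2, y2] in pixels
                              [0, 128, 96, 320, 256]]).cuda())

pooled = roi_pool(features, rois)                         # -> (num_rois, 512, 3, 3)
print(pooled.size())
```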
-------------------------------------------------------------------------------- /models/roi_pooling/src/roi_pooling.h: -------------------------------------------------------------------------------- 1 | int roi_pooling_forward(int pooled_height, int pooled_width, float spatial_scale, 2 | THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output); -------------------------------------------------------------------------------- /models/roi_pooling/src/roi_pooling_cuda.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "roi_pooling_kernel.h" 4 | 5 | extern THCState *state; 6 | 7 | int roi_pooling_forward_cuda(int pooled_height, int pooled_width, float spatial_scale, 8 | THCudaTensor * features, THCudaTensor * rois, THCudaTensor * output, THCudaIntTensor * argmax) 9 | { 10 | // Grab the input tensor 11 | float * data_flat = THCudaTensor_data(state, features); 12 | float * rois_flat = THCudaTensor_data(state, rois); 13 | 14 | float * output_flat = THCudaTensor_data(state, output); 15 | int * argmax_flat = THCudaIntTensor_data(state, argmax); 16 | 17 | // Number of ROIs 18 | int num_rois = THCudaTensor_size(state, rois, 0); 19 | int size_rois = THCudaTensor_size(state, rois, 1); 20 | if (size_rois != 5) 21 | { 22 | return 0; 23 | } 24 | 25 | // batch size 26 | // int batch_size = THCudaTensor_size(state, features, 0); 27 | // if (batch_size != 1) 28 | // { 29 | // return 0; 30 | // } 31 | // data height 32 | int data_height = THCudaTensor_size(state, features, 2); 33 | // data width 34 | int data_width = THCudaTensor_size(state, features, 3); 35 | // Number of channels 36 | int num_channels = THCudaTensor_size(state, features, 1); 37 | 38 | cudaStream_t stream = THCState_getCurrentStream(state); 39 | 40 | ROIPoolForwardLaucher( 41 | data_flat, spatial_scale, num_rois, data_height, 42 | data_width, num_channels, pooled_height, 43 | pooled_width, rois_flat, 44 | output_flat, argmax_flat, stream); 45 | 46 | return 1; 47 | } 48 | 49 | int roi_pooling_backward_cuda(int pooled_height, int pooled_width, float spatial_scale, 50 | THCudaTensor * top_grad, THCudaTensor * rois, THCudaTensor * bottom_grad, THCudaIntTensor * argmax) 51 | { 52 | // Grab the input tensor 53 | float * top_grad_flat = THCudaTensor_data(state, top_grad); 54 | float * rois_flat = THCudaTensor_data(state, rois); 55 | 56 | float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad); 57 | int * argmax_flat = THCudaIntTensor_data(state, argmax); 58 | 59 | // Number of ROIs 60 | int num_rois = THCudaTensor_size(state, rois, 0); 61 | int size_rois = THCudaTensor_size(state, rois, 1); 62 | if (size_rois != 5) 63 | { 64 | return 0; 65 | } 66 | 67 | // batch size 68 | int batch_size = THCudaTensor_size(state, bottom_grad, 0); 69 | // if (batch_size != 1) 70 | // { 71 | // return 0; 72 | // } 73 | // data height 74 | int data_height = THCudaTensor_size(state, bottom_grad, 2); 75 | // data width 76 | int data_width = THCudaTensor_size(state, bottom_grad, 3); 77 | // Number of channels 78 | int num_channels = THCudaTensor_size(state, bottom_grad, 1); 79 | 80 | cudaStream_t stream = THCState_getCurrentStream(state); 81 | ROIPoolBackwardLaucher( 82 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height, 83 | data_width, num_channels, pooled_height, 84 | pooled_width, rois_flat, 85 | bottom_grad_flat, argmax_flat, stream); 86 | 87 | return 1; 88 | } 89 | -------------------------------------------------------------------------------- 
/models/roi_pooling/src/roi_pooling_cuda.h: -------------------------------------------------------------------------------- 1 | int roi_pooling_forward_cuda(int pooled_height, int pooled_width, float spatial_scale, 2 | THCudaTensor * features, THCudaTensor * rois, THCudaTensor * output, THCudaIntTensor * argmax); 3 | 4 | int roi_pooling_backward_cuda(int pooled_height, int pooled_width, float spatial_scale, 5 | THCudaTensor * top_grad, THCudaTensor * rois, THCudaTensor * bottom_grad, THCudaIntTensor * argmax); -------------------------------------------------------------------------------- /models/roi_pooling/src/roi_pooling_kernel.cu: -------------------------------------------------------------------------------- 1 | // #ifdef __cplusplus 2 | // extern "C" { 3 | // #endif 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "roi_pooling_kernel.h" 10 | 11 | 12 | #define DIVUP(m, n) ((m) / (m) + ((m) % (n) > 0)) 13 | 14 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 15 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 16 | i += blockDim.x * gridDim.x) 17 | 18 | // CUDA: grid stride looping 19 | #define CUDA_KERNEL_LOOP(i, n) \ 20 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 21 | i < (n); \ 22 | i += blockDim.x * gridDim.x) 23 | 24 | __global__ void ROIPoolForward(const int nthreads, const float* bottom_data, 25 | const float spatial_scale, const int height, const int width, 26 | const int channels, const int pooled_height, const int pooled_width, 27 | const float* bottom_rois, float* top_data, int* argmax_data) 28 | { 29 | CUDA_KERNEL_LOOP(index, nthreads) 30 | { 31 | // (n, c, ph, pw) is an element in the pooled output 32 | // int n = index; 33 | // int pw = n % pooled_width; 34 | // n /= pooled_width; 35 | // int ph = n % pooled_height; 36 | // n /= pooled_height; 37 | // int c = n % channels; 38 | // n /= channels; 39 | int pw = index % pooled_width; 40 | int ph = (index / pooled_width) % pooled_height; 41 | int c = (index / pooled_width / pooled_height) % channels; 42 | int n = index / pooled_width / pooled_height / channels; 43 | 44 | // bottom_rois += n * 5; 45 | int roi_batch_ind = bottom_rois[n * 5 + 0]; 46 | int roi_start_w = round(bottom_rois[n * 5 + 1] * spatial_scale); 47 | int roi_start_h = round(bottom_rois[n * 5 + 2] * spatial_scale); 48 | int roi_end_w = round(bottom_rois[n * 5 + 3] * spatial_scale); 49 | int roi_end_h = round(bottom_rois[n * 5 + 4] * spatial_scale); 50 | 51 | // Force malformed ROIs to be 1x1 52 | int roi_width = fmaxf(roi_end_w - roi_start_w + 1, 1); 53 | int roi_height = fmaxf(roi_end_h - roi_start_h + 1, 1); 54 | float bin_size_h = (float)(roi_height) / (float)(pooled_height); 55 | float bin_size_w = (float)(roi_width) / (float)(pooled_width); 56 | 57 | int hstart = (int)(floor((float)(ph) * bin_size_h)); 58 | int wstart = (int)(floor((float)(pw) * bin_size_w)); 59 | int hend = (int)(ceil((float)(ph + 1) * bin_size_h)); 60 | int wend = (int)(ceil((float)(pw + 1) * bin_size_w)); 61 | 62 | // Add roi offsets and clip to input boundaries 63 | hstart = fminf(fmaxf(hstart + roi_start_h, 0), height); 64 | hend = fminf(fmaxf(hend + roi_start_h, 0), height); 65 | wstart = fminf(fmaxf(wstart + roi_start_w, 0), width); 66 | wend = fminf(fmaxf(wend + roi_start_w, 0), width); 67 | bool is_empty = (hend <= hstart) || (wend <= wstart); 68 | 69 | // Define an empty pooling region to be zero 70 | float maxval = is_empty ? 
0 : -FLT_MAX; 71 | // If nothing is pooled, argmax = -1 causes nothing to be backprop'd 72 | int maxidx = -1; 73 | // bottom_data += roi_batch_ind * channels * height * width; 74 | 75 | int bottom_data_batch_offset = roi_batch_ind * channels * height * width; 76 | int bottom_data_offset = bottom_data_batch_offset + c * height * width; 77 | 78 | for (int h = hstart; h < hend; ++h) { 79 | for (int w = wstart; w < wend; ++w) { 80 | // int bottom_index = (h * width + w) * channels + c; 81 | // int bottom_index = (c * height + h) * width + w; 82 | int bottom_index = h * width + w; 83 | if (bottom_data[bottom_data_offset + bottom_index] > maxval) { 84 | maxval = bottom_data[bottom_data_offset + bottom_index]; 85 | maxidx = bottom_data_offset + bottom_index; 86 | } 87 | } 88 | } 89 | top_data[index] = maxval; 90 | if (argmax_data != NULL) 91 | argmax_data[index] = maxidx; 92 | } 93 | } 94 | 95 | int ROIPoolForwardLaucher( 96 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height, 97 | const int width, const int channels, const int pooled_height, 98 | const int pooled_width, const float* bottom_rois, 99 | float* top_data, int* argmax_data, cudaStream_t stream) 100 | { 101 | const int kThreadsPerBlock = 1024; 102 | int output_size = num_rois * pooled_height * pooled_width * channels; 103 | cudaError_t err; 104 | 105 | ROIPoolForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 106 | output_size, bottom_data, spatial_scale, height, width, channels, pooled_height, 107 | pooled_width, bottom_rois, top_data, argmax_data); 108 | 109 | // dim3 blocks(DIVUP(output_size, kThreadsPerBlock), 110 | // DIVUP(output_size, kThreadsPerBlock)); 111 | // dim3 threads(kThreadsPerBlock); 112 | // 113 | // ROIPoolForward<<>>( 114 | // output_size, bottom_data, spatial_scale, height, width, channels, pooled_height, 115 | // pooled_width, bottom_rois, top_data, argmax_data); 116 | 117 | err = cudaGetLastError(); 118 | if(cudaSuccess != err) 119 | { 120 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 121 | exit( -1 ); 122 | } 123 | 124 | return 1; 125 | } 126 | 127 | 128 | __global__ void ROIPoolBackward(const int nthreads, const float* top_diff, 129 | const int* argmax_data, const int num_rois, const float spatial_scale, 130 | const int height, const int width, const int channels, 131 | const int pooled_height, const int pooled_width, float* bottom_diff, 132 | const float* bottom_rois) { 133 | CUDA_1D_KERNEL_LOOP(index, nthreads) 134 | { 135 | 136 | // (n, c, ph, pw) is an element in the pooled output 137 | int n = index; 138 | int w = n % width; 139 | n /= width; 140 | int h = n % height; 141 | n /= height; 142 | int c = n % channels; 143 | n /= channels; 144 | 145 | float gradient = 0; 146 | // Accumulate gradient over all ROIs that pooled this element 147 | for (int roi_n = 0; roi_n < num_rois; ++roi_n) 148 | { 149 | const float* offset_bottom_rois = bottom_rois + roi_n * 5; 150 | int roi_batch_ind = offset_bottom_rois[0]; 151 | // Skip if ROI's batch index doesn't match n 152 | if (n != roi_batch_ind) { 153 | continue; 154 | } 155 | 156 | int roi_start_w = round(offset_bottom_rois[1] * spatial_scale); 157 | int roi_start_h = round(offset_bottom_rois[2] * spatial_scale); 158 | int roi_end_w = round(offset_bottom_rois[3] * spatial_scale); 159 | int roi_end_h = round(offset_bottom_rois[4] * spatial_scale); 160 | 161 | // Skip if ROI doesn't include (h, w) 162 | const bool in_roi = (w >= roi_start_w && w 
<= roi_end_w && 163 | h >= roi_start_h && h <= roi_end_h); 164 | if (!in_roi) { 165 | continue; 166 | } 167 | 168 | int offset = roi_n * pooled_height * pooled_width * channels; 169 | const float* offset_top_diff = top_diff + offset; 170 | const int* offset_argmax_data = argmax_data + offset; 171 | 172 | // Compute feasible set of pooled units that could have pooled 173 | // this bottom unit 174 | 175 | // Force malformed ROIs to be 1x1 176 | int roi_width = fmaxf(roi_end_w - roi_start_w + 1, 1); 177 | int roi_height = fmaxf(roi_end_h - roi_start_h + 1, 1); 178 | 179 | float bin_size_h = (float)(roi_height) / (float)(pooled_height); 180 | float bin_size_w = (float)(roi_width) / (float)(pooled_width); 181 | 182 | int phstart = floor((float)(h - roi_start_h) / bin_size_h); 183 | int phend = ceil((float)(h - roi_start_h + 1) / bin_size_h); 184 | int pwstart = floor((float)(w - roi_start_w) / bin_size_w); 185 | int pwend = ceil((float)(w - roi_start_w + 1) / bin_size_w); 186 | 187 | phstart = fminf(fmaxf(phstart, 0), pooled_height); 188 | phend = fminf(fmaxf(phend, 0), pooled_height); 189 | pwstart = fminf(fmaxf(pwstart, 0), pooled_width); 190 | pwend = fminf(fmaxf(pwend, 0), pooled_width); 191 | 192 | for (int ph = phstart; ph < phend; ++ph) { 193 | for (int pw = pwstart; pw < pwend; ++pw) { 194 | if (offset_argmax_data[(c * pooled_height + ph) * pooled_width + pw] == index) 195 | { 196 | gradient += offset_top_diff[(c * pooled_height + ph) * pooled_width + pw]; 197 | } 198 | } 199 | } 200 | } 201 | bottom_diff[index] = gradient; 202 | } 203 | } 204 | 205 | int ROIPoolBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, 206 | const int height, const int width, const int channels, const int pooled_height, 207 | const int pooled_width, const float* bottom_rois, 208 | float* bottom_diff, const int* argmax_data, cudaStream_t stream) 209 | { 210 | const int kThreadsPerBlock = 1024; 211 | int output_size = batch_size * height * width * channels; 212 | cudaError_t err; 213 | 214 | ROIPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 215 | output_size, top_diff, argmax_data, num_rois, spatial_scale, height, width, channels, pooled_height, 216 | pooled_width, bottom_diff, bottom_rois); 217 | 218 | // dim3 blocks(DIVUP(output_size, kThreadsPerBlock), 219 | // DIVUP(output_size, kThreadsPerBlock)); 220 | // dim3 threads(kThreadsPerBlock); 221 | // 222 | // ROIPoolBackward<<>>( 223 | // output_size, top_diff, argmax_data, num_rois, spatial_scale, height, width, channels, pooled_height, 224 | // pooled_width, bottom_diff, bottom_rois); 225 | 226 | err = cudaGetLastError(); 227 | if(cudaSuccess != err) 228 | { 229 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 230 | exit( -1 ); 231 | } 232 | 233 | return 1; 234 | } 235 | 236 | 237 | // #ifdef __cplusplus 238 | // } 239 | // #endif 240 | -------------------------------------------------------------------------------- /models/roi_pooling/src/roi_pooling_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _ROI_POOLING_KERNEL 2 | #define _ROI_POOLING_KERNEL 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | int ROIPoolForwardLaucher( 9 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height, 10 | const int width, const int channels, const int pooled_height, 11 | const int pooled_width, const float* bottom_rois, 12 | 
float* top_data, int* argmax_data, cudaStream_t stream); 13 | 14 | 15 | int ROIPoolBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois, 16 | const int height, const int width, const int channels, const int pooled_height, 17 | const int pooled_width, const float* bottom_rois, 18 | float* bottom_diff, const int* argmax_data, cudaStream_t stream); 19 | 20 | #ifdef __cplusplus 21 | } 22 | #endif 23 | 24 | #endif 25 | 26 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | #################################################### 2 | ## Prepare the vgg model for the image level classifier and the part level classifier. 3 | ################################################### 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | import math 7 | import copy 8 | import torch 9 | 10 | __all__ = [ 11 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 12 | 'vgg19_bn', 'vgg19', 13 | ] 14 | 15 | 16 | model_urls = { 17 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 18 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 19 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 20 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 21 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 22 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 23 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 24 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 25 | } 26 | 27 | 28 | class VGG(nn.Module): 29 | 30 | def __init__(self, features, num_classes=1000): 31 | super(VGG, self).__init__() 32 | self.features = features 33 | # self.classifier = nn.Sequential( 34 | # nn.Linear(512 * 7 * 7, 4096), 35 | # nn.ReLU(True), 36 | # nn.Dropout(), 37 | # nn.Linear(4096, 4096), 38 | # nn.ReLU(True), 39 | # nn.Dropout(), 40 | # nn.Linear(4096, num_classes), 41 | # ) 42 | self.avgpool = nn.AvgPool2d(28) 43 | self._initialize_weights() 44 | 45 | def forward(self, x): 46 | x = self.features(x) 47 | x = self.avgpool(x) 48 | x = x.view(x.size(0), -1) 49 | # x = self.classifier(x) 50 | return x 51 | 52 | def _initialize_weights(self): 53 | for m in self.modules(): 54 | if isinstance(m, nn.Conv2d): 55 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 56 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 57 | if m.bias is not None: 58 | m.bias.data.zero_() 59 | elif isinstance(m, nn.BatchNorm2d): 60 | m.weight.data.fill_(1) 61 | m.bias.data.zero_() 62 | elif isinstance(m, nn.Linear): 63 | m.weight.data.normal_(0, 0.01) 64 | m.bias.data.zero_() 65 | 66 | 67 | def make_layers(cfg, batch_norm=False): 68 | layers = [] 69 | in_channels = 3 70 | print('the vgg network structure has been changed') 71 | for v in cfg: 72 | if v == 'M': 73 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 74 | else: 75 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 76 | if batch_norm: 77 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 78 | else: 79 | layers += [conv2d, nn.ReLU(inplace=True)] 80 | in_channels = v 81 | return nn.Sequential(*layers) 82 | 83 | 84 | cfg = { 85 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 86 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512], 87 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512], 88 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512], 89 | } 90 | 91 | 92 | def vgg11(pretrained=False, **kwargs): 93 | """VGG 11-layer model (configuration "A") 94 | 95 | Args: 96 | pretrained (bool): If True, returns a model pre-trained on ImageNet 97 | """ 98 | model = VGG(make_layers(cfg['A']), **kwargs) 99 | if pretrained: 100 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 101 | return model 102 | 103 | 104 | def vgg11_bn(pretrained=False, **kwargs): 105 | """VGG 11-layer model (configuration "A") with batch normalization 106 | 107 | Args: 108 | pretrained (bool): If True, returns a model pre-trained on ImageNet 109 | """ 110 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 111 | if pretrained: 112 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 113 | return model 114 | 115 | 116 | def vgg13(pretrained=False, **kwargs): 117 | """VGG 13-layer model (configuration "B") 118 | 119 | Args: 120 | pretrained (bool): If True, returns a model pre-trained on ImageNet 121 | """ 122 | model = VGG(make_layers(cfg['B']), **kwargs) 123 | if pretrained: 124 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 125 | return model 126 | 127 | 128 | def vgg13_bn(pretrained=False, **kwargs): 129 | """VGG 13-layer model (configuration "B") with batch normalization 130 | 131 | Args: 132 | pretrained (bool): If True, returns a model pre-trained on ImageNet 133 | """ 134 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 135 | if pretrained: 136 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 137 | return model 138 | 139 | 140 | def vgg16(pretrained=False, **kwargs): 141 | """VGG 16-layer model (configuration "D") 142 | 143 | Args: 144 | pretrained (bool): If True, returns a model pre-trained on ImageNet 145 | """ 146 | model = VGG(make_layers(cfg['D']), **kwargs) 147 | if pretrained: 148 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 149 | return model 150 | 151 | 152 | def vgg16_bn(args, **kwargs): 153 | """VGG 16-layer model (configuration "D") with batch normalization 154 | 155 | Args: 156 | pretrained (bool): If True, returns a model pre-trained on ImageNet 157 | """ 158 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 159 | 160 | if args.pretrain: 161 | print('load the imageNet pretrained model') 162 | pretrained_dict = model_zoo.load_url(model_urls['vgg16_bn']) 163 | model_dict = 
model.state_dict() 164 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 165 | model_dict.update(pretrained_dict_temp) 166 | model.load_state_dict(model_dict) 167 | if args.pretrained_model != '': 168 | print('load the pretrained model from:', args.pretrained_model) 169 | pretrained_dict = torch.load(args.pretrained_model)['state_dict'] 170 | pretrained_dict_temp = copy.deepcopy(pretrained_dict) 171 | model_state_dict = model.state_dict() 172 | 173 | for k_tmp in pretrained_dict_temp.keys(): 174 | if k_tmp.find('module.') != -1: 175 | k = k_tmp.replace('module.', '') 176 | pretrained_dict[k] = pretrained_dict.pop(k_tmp) 177 | pretrained_dict_temp2 = {k: v for k, v in pretrained_dict.items() if k in model_state_dict} 178 | print(pretrained_dict_temp2) 179 | model_state_dict.update(pretrained_dict_temp2) 180 | model.load_state_dict(model_state_dict) 181 | 182 | source_model = Share_convs(model, 512, args.num_classes_t) 183 | return source_model 184 | 185 | 186 | def vgg19(pretrained=False, **kwargs): 187 | """VGG 19-layer model (configuration "E") 188 | 189 | Args: 190 | pretrained (bool): If True, returns a model pre-trained on ImageNet 191 | """ 192 | model = VGG(make_layers(cfg['E']), **kwargs) 193 | if pretrained: 194 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 195 | return model 196 | 197 | 198 | class Share_convs(nn.Module): 199 | def __init__(self, base_conv, convout_dimension, num_class): 200 | super(Share_convs, self).__init__() 201 | self.base_conv = base_conv 202 | self.fc = nn.Linear(convout_dimension, num_class) 203 | def forward(self, x): 204 | x = self.base_conv(x) 205 | x = self.fc(x) 206 | return x 207 | 208 | def vgg19_bn(args, **kwargs): 209 | """VGG 19-layer model (configuration 'E') with batch normalization 210 | 211 | Args: 212 | pretrained (bool): If True, returns a model pre-trained on ImageNet 213 | """ 214 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 215 | 216 | if args.pretrain: 217 | print('load the imageNet pretrained model') 218 | pretrained_dict = model_zoo.load_url(model_urls['vgg19_bn']) 219 | model_dict = model.state_dict() 220 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict} 221 | model_dict.update(pretrained_dict_temp) 222 | model.load_state_dict(model_dict) 223 | if args.pretrained_model != '': 224 | print('load the pretrained model from:', args.pretrained_model) 225 | pretrained_dict = torch.load(args.pretrained_model)['state_dict'] 226 | pretrained_dict_temp = copy.deepcopy(pretrained_dict) 227 | model_state_dict = model.state_dict() 228 | 229 | for k_tmp in pretrained_dict_temp.keys(): 230 | if k_tmp.find('module.') != -1: 231 | k = k_tmp.replace('module.', '') 232 | pretrained_dict[k] = pretrained_dict.pop(k_tmp) 233 | 234 | pretrained_dict_temp2 = {k: v for k, v in pretrained_dict.items() if k in model_state_dict} 235 | 236 | model_state_dict.update(pretrained_dict_temp2) 237 | model.load_state_dict(model_state_dict) 238 | 239 | source_model = Share_convs(model, 512, args.num_classes) 240 | return source_model 241 | 242 | 243 | def vgg(args, **kwargs): 244 | print("==> creating model '{}' ".format(args.arch)) 245 | if args.arch == 'vgg19_bn': 246 | return vgg19_bn(args) 247 | elif args.arch == 'vgg16_bn': 248 | return vgg16_bn(args) 249 | else: 250 | raise ValueError('Unrecognized model architecture', args.arch) 251 | -------------------------------------------------------------------------------- /opts.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def opts(): 5 | parser = argparse.ArgumentParser(description='Train PartNet on fine-grained datasets', 6 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 7 | parser.add_argument('--data_path', type=str, default='/home/lab-zhangyabin/project/fine-grained/CUB_200_2011/', 8 | help='Root of the data set') 9 | parser.add_argument('--dataset', type=str, default='cub200', 10 | help='choose between flowers/cub200') 11 | # Optimization options 12 | parser.add_argument('--epochs', type=int, default=160, help='Number of epochs to train') 13 | parser.add_argument('--schedule', type=int, nargs='+', default=[80, 120], 14 | help='Decrease learning rate at these epochs.') 15 | parser.add_argument('--epochs_part', type=int, default=30, help='Number of epochs to train the part classifier') 16 | parser.add_argument('--schedule_part', type=int, nargs='+', default=[10, 20], 17 | help='Decrease learning rate at these epochs in the training of the part classifier.') 18 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size.') 19 | parser.add_argument('--batch_size_partnet', type=int, default=64, help='Batch size for PartNet, set to 64 due to the GPU memory constraint') 20 | parser.add_argument('--lr', '--learning_rate', type=float, default=0.1, help='The initial learning rate.') 21 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.') 22 | parser.add_argument('--weight_decay', type=float, default=0.0001, help='Weight decay (L2 penalty).') 23 | 24 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 25 | # checkpoints 26 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 27 | help='manual epoch number (useful on restarts)') 28 | parser.add_argument('--resume', type=str, default='', help='Checkpoint path to resume from (default: none)') 29 | parser.add_argument('--pretrained_model', type=str, default='', help='Directory of the modified ImageNet pre-trained models') 30 | parser.add_argument('--pretrain', action='store_true', help='whether to use the ImageNet pre-trained model') 31 | parser.add_argument('--test_only', action='store_true', help='Test only flag') 32 | # Architecture 33 | parser.add_argument('--arch', type=str, default='', help='Model name') 34 | parser.add_argument('--proposals_num', type=int, default=1372, help='the number of proposals for one image') 35 | parser.add_argument('--square_size', type=int, default=4, help='the side length of each cell in the DPP module') 36 | parser.add_argument('--proposals_per_square', type=int, default=28, help='the number of proposals per square') 37 | parser.add_argument('--stride', type=int, default=16, help='Stride of the used model') 38 | parser.add_argument('--num_part', type=int, default=3, help='the number of parts to be detected by PartNet') 39 | parser.add_argument('--num_classes', type=int, default=200, help='the number of fine-grained classes') 40 | parser.add_argument('--num_select_proposals', type=int, default=50, help='the number of part proposals selected per image') 41 | 42 | parser.add_argument('--svb', action='store_true', help='whether to apply SVB to the classifier') 43 | parser.add_argument('--svb_factor', type=float, default=1.5, help='the factor used in the SVB method') 44 | # i/o 45 | parser.add_argument('--log', type=str, default='./checkpoints', help='Log folder') 46 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 47 |
help='number of data loading workers (default: 4)') 48 | parser.add_argument('--print_freq', '-p', default=10, type=int, 49 | metavar='N', help='print frequency (default: 10)') 50 | args = parser.parse_args() 51 | return args 52 | -------------------------------------------------------------------------------- /process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import torch.optim 5 | import torch.nn as nn 6 | import torch.backends.cudnn as cudnn 7 | from models.model_construct import Model_Construct 8 | from data.prepare_data import generate_dataloader 9 | from trainer import train # For the training process 10 | from trainer import validate # For the validate (test) process 11 | from trainer import download_part_proposals 12 | from trainer import download_scores 13 | from trainer import svb 14 | from trainer import svb_det 15 | import time 16 | 17 | def Process1_Image_Classifier(args): 18 | 19 | log_now = args.dataset + '/Image_Classifier' 20 | process_name = 'image_classifier' 21 | if os.path.isfile(log_now + '/final.txt'): 22 | print('the Process1_Image_Classifier is finished') 23 | return 24 | best_prec1 = 0 25 | model = Model_Construct(args, process_name) 26 | model = torch.nn.DataParallel(model).cuda() 27 | criterion = nn.CrossEntropyLoss().cuda() 28 | optimizer = torch.optim.SGD([ 29 | {'params': model.module.base_conv.parameters(), 'name': 'pre-trained'}, 30 | {'params': model.module.fc.parameters(), 'lr': args.lr, 'name': 'new-added'} 31 | ], 32 | lr=args.lr, 33 | momentum=args.momentum, 34 | weight_decay=args.weight_decay) 35 | start_epoch = args.start_epoch 36 | if args.resume: 37 | if os.path.isfile(args.resume): 38 | print("==> loading checkpoints '{}'".format(args.resume)) 39 | checkpoint = torch.load(args.resume) 40 | start_epoch = checkpoint['epoch'] 41 | best_prec1 = checkpoint['best_prec1'] 42 | model.load_state_dict(checkpoint['state_dict']) 43 | optimizer.load_state_dict(checkpoint['optimizer']) 44 | print("==> loaded checkpoint '{}'(epoch {})" 45 | .format(args.resume, checkpoint['epoch'])) 46 | args.resume = '' 47 | else: 48 | raise ValueError('The file to resume from does not exist:', args.resume) 49 | else: 50 | if not os.path.isdir(log_now): 51 | os.makedirs(log_now) 52 | log = open(os.path.join(log_now, 'log.txt'), 'w') 53 | state = {k: v for k, v in args._get_kwargs()} 54 | log.write(json.dumps(state) + '\n') 55 | log.close() 56 | cudnn.benchmark = True 57 | train_loader, val_loader = generate_dataloader(args, process_name, -1) 58 | if args.test_only: 59 | validate(val_loader, model, criterion, 2000, log_now, process_name, args) 60 | 61 | for epoch in range(start_epoch, args.epochs): 62 | # train for one epoch 63 | train(train_loader, model, criterion, optimizer, epoch, log_now, process_name, args) 64 | # evaluate on the val data 65 | prec1 = validate(val_loader, model, criterion, epoch, log_now, process_name, args) 66 | # record the best prec1 and save checkpoint 67 | is_best = prec1 > best_prec1 68 | best_prec1 = max(prec1, best_prec1) 69 | if is_best: 70 | log = open(os.path.join(log_now, 'log.txt'), 'a') 71 | log.write( 72 | "best acc %3f" % (best_prec1)) 73 | log.close() 74 | save_checkpoint({ 75 | 'epoch': epoch + 1, 76 | 'arch': args.arch, 77 | 'state_dict': model.state_dict(), 78 | 'best_prec1': best_prec1, 79 | 'optimizer' : optimizer.state_dict(), 80 | }, is_best, log_now) 81 | #download_scores(val_loader, model, log_now, process_name, args)
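# 'final.txt' is the completion marker checked at the top of each ProcessN function, so a rerun skips finished stages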
def Process2_PartNet(args):
    log_now = args.dataset + '/PartNet'
    process_name = 'partnet'
    if os.path.isfile(log_now + '/final.txt'):
        print('Process2_PartNet is already finished')
        return
    best_prec1 = 0
    model = Model_Construct(args, process_name)
    model = torch.nn.DataParallel(model).cuda()
    criterion = nn.BCELoss().cuda()
    # print(model)
    # print('the learning rate for the new added layer is set to 1e-3 to slow down the speed of learning.')
    optimizer = torch.optim.SGD([
        {'params': model.module.conv_model.parameters(), 'name': 'pre-trained'},
        {'params': model.module.classification_stream.parameters(), 'name': 'new-added'},
        {'params': model.module.detection_stream.parameters(), 'name': 'new-added'}
    ],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    start_epoch = args.start_epoch
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("==> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            args.resume = ''
        else:
            raise ValueError('The checkpoint to resume from does not exist:', args.resume)
    else:
        if not os.path.isdir(log_now):
            os.makedirs(log_now)
        log = open(os.path.join(log_now, 'log.txt'), 'w')
        state = {k: v for k, v in args._get_kwargs()}
        log.write(json.dumps(state) + '\n')
        log.close()
    cudnn.benchmark = True
    train_loader, val_loader = generate_dataloader(args, process_name, -1)
    if args.test_only:
        validate(val_loader, model, criterion, 2000, log_now, process_name, args)
        return
    for epoch in range(start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_now, process_name, args)
        # evaluate on the val data
        prec1 = validate(val_loader, model, criterion, epoch, log_now, process_name, args)
        # record the best prec1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if is_best:
            log = open(os.path.join(log_now, 'log.txt'), 'a')
            log.write("best acc %.3f" % (best_prec1))
            log.close()
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best, log_now)
        svb_timer = time.time()
        if args.svb and epoch != (args.epochs - 1):
            svb(model, args)
            print('note: the svb constraint is only applied to the classification stream.')
            svb_det(model, args)
            print('the svb time is: ', time.time() - svb_timer)
    # download_scores(val_loader, model, log_now, process_name, args)
    log = open(os.path.join(log_now, 'final.txt'), 'w')
    log.write("best acc %.3f" % (best_prec1))
    log.close()
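
# Note: Process3_Download_Proposals below does not train anything. It restores
# the best PartNet checkpoint, runs one pass over the train and val sets, and
# (via download_part_proposals in trainer.py) crops the top-scoring proposal
# regions for each part to disk, where Process4 picks them up.
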
def Process3_Download_Proposals(args):
    log_now = args.dataset + '/Download_Proposals'
    process_name = 'download_proposals'
    if os.path.isfile(log_now + '/final.txt'):
        print('Process3_Download_Proposals is already finished')
        return

    model = Model_Construct(args, process_name)
    model = torch.nn.DataParallel(model).cuda()

    # the optimizer is only constructed so that a full checkpoint
    # (including its optimizer state) can be restored below
    optimizer = torch.optim.SGD([
        {'params': model.module.conv_model.parameters(), 'name': 'pre-trained'},
        {'params': model.module.classification_stream.parameters(), 'name': 'new-added'},
        {'params': model.module.detection_stream.parameters(), 'name': 'new-added'}
    ],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    log_partnet_model = args.dataset + '/PartNet/model_best.pth.tar'
    checkpoint = torch.load(log_partnet_model)
    model.load_state_dict(checkpoint['state_dict'])
    print('load the pre-trained partnet model from:', log_partnet_model)

    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("==> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            args.resume = ''
        else:
            raise ValueError('The checkpoint to resume from does not exist:', args.resume)
    else:
        if not os.path.isdir(log_now):
            os.makedirs(log_now)
        log = open(os.path.join(log_now, 'log.txt'), 'w')
        state = {k: v for k, v in args._get_kwargs()}
        log.write(json.dumps(state) + '\n')
        log.close()

    cudnn.benchmark = True
    train_loader, val_loader = generate_dataloader(args, process_name)

    for epoch in range(1):
        download_part_proposals(train_loader, model, epoch, log_now, process_name, 'train', args)
        best_prec1 = download_part_proposals(val_loader, model, epoch, log_now, process_name, 'val', args)

    log = open(os.path.join(log_now, 'final.txt'), 'w')
    log.write("best acc %.3f" % (best_prec1))
    log.close()
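
# Note: Process4_Part_Classifiers below trains one classifier per detected part
# (args.num_part of them) on the proposal crops saved by Process3. Each part
# classifier is warm-started from the best Image_Classifier checkpoint and runs
# for args.epochs_part epochs under the args.schedule_part learning-rate schedule.
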
def Process4_Part_Classifiers(args):
    # if the process is interrupted inside this loop, more manual
    # modification is needed before it can be resumed cleanly
    for i in range(args.num_part):
        log_now = args.dataset + '/Part_Classifiers_' + str(i)
        process_name = 'part_classifiers'
        if os.path.isfile(log_now + '/final.txt'):
            print('Process4_Part_Classifiers is already finished for part', i)
            continue
        best_prec1 = 0
        model = Model_Construct(args, process_name)
        model = torch.nn.DataParallel(model).cuda()
        criterion = nn.CrossEntropyLoss().cuda()
        optimizer = torch.optim.SGD([
            {'params': model.module.base_conv.parameters(), 'name': 'pre-trained'},
            {'params': model.module.fc.parameters(), 'lr': args.lr, 'name': 'new-added'}
        ],
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
        log_image_model = args.dataset + '/Image_Classifier/model_best.pth.tar'
        checkpoint = torch.load(log_image_model)
        model.load_state_dict(checkpoint['state_dict'])
        print('load the fine-tuned image-level model from:', log_image_model)
        start_epoch = args.start_epoch
        if args.resume:
            if os.path.isfile(args.resume):
                print("==> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("==> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
                args.resume = ''
            else:
                raise ValueError('The checkpoint to resume from does not exist:', args.resume)
        else:
            if not os.path.isdir(log_now):
                os.makedirs(log_now)
            log = open(os.path.join(log_now, 'log.txt'), 'w')
            state = {k: v for k, v in args._get_kwargs()}
            log.write(json.dumps(state) + '\n')
            log.close()
        cudnn.benchmark = True
        train_loader, val_loader = generate_dataloader(args, process_name, i)
        if args.test_only:
            validate(val_loader, model, criterion, 2000, log_now, process_name, args)
            continue
        for epoch in range(start_epoch, args.epochs_part):
            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch, log_now, process_name, args)
            # evaluate on the val data
            prec1 = validate(val_loader, model, criterion, epoch, log_now, process_name, args)
            # record the best prec1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            if is_best:
                log = open(os.path.join(log_now, 'log.txt'), 'a')
                log.write("best acc %.3f" % (best_prec1))
                log.close()
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, log_now)
        # download_scores(val_loader, model, log_now, process_name, args)
        log = open(os.path.join(log_now, 'final.txt'), 'w')
        log.write("best acc %.3f" % (best_prec1))
        log.close()
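
# Note: Process5_Final_Result below performs no training either. It reloads the
# best checkpoint of each of the five models (image classifier, PartNet, and the
# three part classifiers), dumps their per-image class probabilities on the val
# set via download_scores, and averages the five probability vectors per image
# to produce the final top-1 accuracy.
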
def Process5_Final_Result(args):
    ############################# Image Level Classifier #############################
    log_now = args.dataset + '/Image_Classifier'
    process_name = 'image_classifier'
    model = Model_Construct(args, process_name)
    model = torch.nn.DataParallel(model).cuda()
    pre_trained_model = log_now + '/model_best.pth.tar'
    checkpoint = torch.load(pre_trained_model)
    model.load_state_dict(checkpoint['state_dict'])
    train_loader, val_loader = generate_dataloader(args, process_name, -1)
    download_scores(val_loader, model, log_now, process_name, args)
    ############################# PartNet ############################################
    log_now = args.dataset + '/PartNet'
    process_name = 'partnet'
    model = Model_Construct(args, process_name)
    model = torch.nn.DataParallel(model).cuda()
    pre_trained_model = log_now + '/model_best.pth.tar'
    checkpoint = torch.load(pre_trained_model)
    model.load_state_dict(checkpoint['state_dict'])
    train_loader, val_loader = generate_dataloader(args, process_name)
    download_scores(val_loader, model, log_now, process_name, args)
    ############################# Three Part Level Classifiers #######################
    for i in range(args.num_part):
        log_now = args.dataset + '/Part_Classifiers_' + str(i)
        process_name = 'part_classifiers'
        model = Model_Construct(args, process_name)
        model = torch.nn.DataParallel(model).cuda()
        pre_trained_model = log_now + '/model_best.pth.tar'
        checkpoint = torch.load(pre_trained_model)
        model.load_state_dict(checkpoint['state_dict'])
        train_loader, val_loader = generate_dataloader(args, process_name, i)
        download_scores(val_loader, model, log_now, process_name, args)

    log_image = args.dataset + '/Image_Classifier'
    process_image = 'image_classifier'

    log_partnet = args.dataset + '/PartNet'
    process_partnet = 'partnet'

    log_part0 = args.dataset + '/Part_Classifiers_' + str(0)
    process_part0 = 'part_classifiers'

    log_part1 = args.dataset + '/Part_Classifiers_' + str(1)
    process_part1 = 'part_classifiers'

    log_part2 = args.dataset + '/Part_Classifiers_' + str(2)
    process_part2 = 'part_classifiers'

    image_table = torch.load(log_image + '/' + process_image + '.pth.tar')
    image_probability = image_table['scores']
    labels = image_table['labels']
    partnet_table = torch.load(log_partnet + '/' + process_partnet + '.pth.tar')
    partnet_probability = partnet_table['scores']
    #######################
    part0_table = torch.load(log_part0 + '/' + process_part0 + '.pth.tar')
    part0_probability = part0_table['scores']
    ##########################
    part1_table = torch.load(log_part1 + '/' + process_part1 + '.pth.tar')
    part1_probability = part1_table['scores']
    ##########################
    part2_table = torch.load(log_part2 + '/' + process_part2 + '.pth.tar')
    part2_probability = part2_table['scores']
    ##########################

    probabilities_group = []
    probabilities_group.append(image_probability)
    probabilities_group.append(part0_probability)
    probabilities_group.append(part1_probability)
    probabilities_group.append(part2_probability)
    probabilities_group.append(partnet_probability)
    count = 0
    for i in range(len(labels)):
        # sum the five probability vectors of image i (an equal-weight ensemble;
        # the argmax of the sum equals the argmax of the average)
        probability = sum(group[i] for group in probabilities_group)
        label = labels[i]
        value, index = probability.sort(0, descending=True)
        if index[0] == label:
            count = count + 1
    top1 = count / len(labels)
    print('the final result obtained by averaging part0-1-2, image and partnet is', top1)
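
# A minimal worked example of the averaging step above, with hypothetical
# numbers: for a 3-class problem and two models whose score vectors for one
# image are [0.2, 0.5, 0.3] and [0.6, 0.2, 0.2], the element-wise sum is
# [0.8, 0.7, 0.5], so the ensemble predicts class 0 even though the first
# model alone would have predicted class 1.
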
def save_checkpoint(state, is_best, log_now):
    filename = 'checkpoint.pth.tar'
    dir_save_file = os.path.join(log_now, filename)
    torch.save(state, dir_save_file)
    if is_best:
        shutil.copyfile(dir_save_file, os.path.join(log_now, 'model_best.pth.tar'))

--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/bin/bash
########################################## Use the following command to test the final results ####################################
python main.py --batch_size 128 --batch_size_partnet 64 --momentum 0.9 --weight_decay 1e-4 --data_path /data/flower102/ --test_only \
       --proposals_num 1372 --square_size 4 --proposals_per_square 28 --workers 8 --lr 0.1 --svb --svb_factor 1.5 \
       --dataset flowers --print_freq 1 --arch vgg19_bn --num_part 3 --epochs_part 30 --num_classes 102 --num_select_proposals 50 \
       --epochs 160 --pretrained_model /home/lab-zhang.yabin/PAFGN/pytorch-PartNet/pretrain_vgg_imagenet/vgg19-bn-modified/model_best.pth.tar

#python main.py --batch_size 128 --batch_size_partnet 64 --momentum 0.9 --weight_decay 1e-4 --data_path /data/CUB_200_2011/ --test_only \
#       --proposals_num 1372 --square_size 4 --proposals_per_square 28 --workers 8 --lr 0.1 --svb --svb_factor 1.5 --schedule 80 120 --schedule_part 10 20 \
#       --dataset cub200 --print_freq 1 --arch vgg19_bn --num_part 3 --epochs_part 30 --num_classes 200 --num_select_proposals 50 \
#       --epochs 160 --pretrained_model /home/lab-zhang.yabin/PAFGN/pytorch-PartNet/pretrain_vgg_imagenet/vgg19-bn-modified/model_best.pth.tar


########################################### use the following command to train the PartNet #######################################
#python main.py --batch_size 128 --batch_size_partnet 64 --momentum 0.9 --weight_decay 1e-4 --data_path /data/flower102/ \
#       --proposals_num 1372 --square_size 4 --proposals_per_square 28 --workers 8 --lr 0.1 --svb --svb_factor 1.5 \
#       --dataset flowers --print_freq 1 --arch vgg19_bn --num_part 3 --epochs_part 30 --num_classes 102 --num_select_proposals 50 \
#       --epochs 160 --pretrained_model /home/lab-zhang.yabin/PAFGN/pytorch-PartNet/pretrain_vgg_imagenet/vgg19-bn-modified/model_best.pth.tar

#python main.py --batch_size 128 --batch_size_partnet 64 --momentum 0.9 --weight_decay 1e-4 --data_path /data/CUB_200_2011/ \
#       --proposals_num 1372 --square_size 4 --proposals_per_square 28 --workers 8 --lr 0.1 --svb --svb_factor 1.5 --schedule 80 120 --schedule_part 10 20 \
#       --dataset cub200 --print_freq 1 --arch vgg19_bn --num_part 3 --epochs_part 30 --num_classes 200 --num_select_proposals 50 \
#       --epochs 160 --pretrained_model /home/lab-zhang.yabin/PAFGN/pytorch-PartNet/pretrain_vgg_imagenet/vgg19-bn-modified/model_best.pth.tar

--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import time
import torch
import os
from PIL import Image
import torch.nn as nn
import torchvision.transforms as transforms
import numpy as np

def train(train_loader, model, criterion, optimizer, epoch, log_now, process_name, args):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    model.train()
    adjust_learning_rate(optimizer, epoch, process_name, args)
    end = time.time()
    for i, (input, target, target_loss) in enumerate(train_loader):
        data_time.update(time.time() - end)
        if process_name == 'partnet':
            # PartNet is trained against the multi-label BCE target (target_loss),
            # while the plain class label (target) is kept for the accuracy metric
            target = target.cuda(non_blocking=True)
            target_loss = target_loss.cuda(non_blocking=True)
            target_var = torch.autograd.Variable(target_loss)
        elif process_name == 'image_classifier' or process_name == 'part_classifiers':
            target = target.cuda(non_blocking=True)
            target_var = torch.autograd.Variable(target)
        else:
            raise ValueError('the required process type is not supported')
        input_var = torch.autograd.Variable(input)
        output = model(input_var)
        loss = criterion(output, target_var)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1, top5=top5))
    log = open(os.path.join(log_now, 'log.txt'), 'a')
    log.write("\n")
    log.write("Train:epoch: %d, loss: %.4f, Top1 acc: %.3f, Top5 acc: %.3f" % (epoch, losses.avg, top1.avg, top5.avg))
    log.close()
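
# Note: validate below mirrors the bookkeeping of train (AverageMeter objects,
# the same per-process-name target handling) but runs the model in eval mode
# under torch.no_grad() and returns the top-1 average, which the callers in
# process.py use to track the best checkpoint.
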
def validate(val_loader, model, criterion, epoch, log_now, process_name, args):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target, target_loss) in enumerate(val_loader):
        target = target.cuda(non_blocking=True)
        input_var = torch.autograd.Variable(input)
        if process_name == 'partnet':
            target_loss = target_loss.cuda(non_blocking=True)
            target_var = torch.autograd.Variable(target_loss)
        elif process_name == 'image_classifier' or process_name == 'part_classifiers':
            target_var = torch.autograd.Variable(target)
        # compute output
        with torch.no_grad():
            output = model(input_var)
            loss = criterion(output, target_var)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(val_loader), batch_time=batch_time, loss=losses,
                   top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))
    log = open(os.path.join(log_now, 'log.txt'), 'a')
    log.write("\n")
    log.write(" Val:epoch: %d, loss: %.4f, Top1 acc: %.3f, Top5 acc: %.3f" %
              (epoch, losses.avg, top1.avg, top5.avg))
    log.close()
    return top1.avg

def download_part_proposals(data_loader, model, epoch, log_now, process_name, train_or_val, args):
    batch_time = AverageMeter()
    data_time = AverageMeter()

    top1 = AverageMeter()
    top5 = AverageMeter()
    model.eval()
    end = time.time()
    tensor2image = transforms.ToPILImage()
    if train_or_val == 'train':
        num_selected = args.num_select_proposals
    elif train_or_val == 'val':
        num_selected = 1
    else:
        raise ValueError('only the "train" or "val" mode is accepted')
    for i, (input, input_keep, target, target_loss, path_image) in enumerate(data_loader):
        data_time.update(time.time() - end)
        target = target.cuda(non_blocking=True)
        target_loss = target_loss.cuda(non_blocking=True)
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target_loss)

        # buffers filled by the forward hooks below during the forward pass
        proposals_scores_h = torch.Tensor(input.size(0), args.proposals_num, args.num_part + 1).fill_(0)  # all the proposal scores
        all_proposals = torch.Tensor(input.size(0) * args.proposals_num, 5).fill_(0)  # all the proposals; columns 1:5 hold the box coordinates

        def hook_scores(module, inputdata, output):
            proposals_scores_h.copy_(output.data)

        def hook_proposals(module, inputdata, output):
            all_proposals.copy_(output.data)

        handle_scores = model.module.detection_stream.softmax_cls.register_forward_hook(hook_scores)
        handle_proposals = model.module.DPP.register_forward_hook(hook_proposals)
        with torch.no_grad():
            output = model(input_var)

        handle_proposals.remove()  # remove the hooks once they have fired
        handle_scores.remove()
        for j in range(input.size(0)):
            real_image = tensor2image(input_keep[j])
            scores_for_image = proposals_scores_h[j]
            value, sort = torch.sort(scores_for_image, 0, descending=True)  # scores from large to small
            proposals_one = all_proposals[j*args.proposals_num:(j+1)*args.proposals_num, 1:5]
            # check whether the target directory exists; if not, create it
            img_dir = path_image[j]

            for num_p in range(0, args.num_part+1):
                last_dash = img_dir.rfind('/')
                image_folder = args.data_path + 'PartNet' + args.arch + '/part_' + str(num_p) + '/' + img_dir[0:last_dash]
                if not os.path.isdir(image_folder):
                    os.makedirs(image_folder)
                image_name = img_dir[last_dash:]
                last_point = image_name.find('.')
                name = image_name[0:last_point]
                one_part_sort = sort[:, num_p]
                for k in range(num_selected):  # crop and save the top-k proposals of this part
                    select_proposals = np.array(proposals_one[one_part_sort[k]])
                    cropped_image = real_image.crop((select_proposals[0], select_proposals[1], select_proposals[2], select_proposals[3]))
                    dir_to_save = image_folder + name + '_' + str(k) + image_name[last_point:]
                    cropped_image.save(dir_to_save)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print(train_or_val + ': [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                   epoch, i, len(data_loader), batch_time=batch_time,
                   data_time=data_time, top1=top1, top5=top5))
    log = open(os.path.join(log_now, 'log.txt'), 'a')
    log.write("\n")
    log.write(train_or_val + ":epoch: %d, Top1 acc: %.3f, Top5 acc: %.3f" % (epoch, top1.avg, top5.avg))
    log.close()
    return top1.avg
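
# A minimal sketch of the forward-hook pattern used above, with hypothetical
# module and buffer names: a hook registered on a submodule receives
# (module, input, output) on every forward pass, and the returned handle
# should be removed afterwards, e.g.
#
#     buf = torch.empty(8, 10)
#     def hook(module, inputdata, output):
#         buf.copy_(output.data)
#     handle = model.fc.register_forward_hook(hook)
#     model(x)         # fills buf as a side effect of the forward pass
#     handle.remove()  # detach the hook so later passes are unaffected
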
def download_scores(val_loader, model, log_now, process_name, args):
    if not os.path.isdir(log_now):
        raise ValueError('the requested log dir does not exist')
    file_to_save_or_load = log_now + '/' + process_name + '.pth.tar'
    probabilities = []
    labels = []
    batch_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    softmax = nn.Softmax(dim=1)
    for i, (input, target, target_loss) in enumerate(val_loader):
        target = target.cuda(non_blocking=True)
        input_var = torch.autograd.Variable(input)

        # compute output
        with torch.no_grad():
            output = model(input_var)
        # PartNet scores are L1-normalized into a probability vector;
        # the plain classifiers go through a softmax
        if process_name == 'partnet':
            output = torch.nn.functional.normalize(output, p=1, dim=1)
        else:
            output = softmax(output)

        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        for j in range(input.size(0)):
            # storing sub-tensors might be needed here to save memory
            probabilities.append(output.data[j].cpu().clone())
            labels.append(target[j])

        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        print('Test: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
              'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
               1, i, len(val_loader), batch_time=batch_time,
               top1=top1, top5=top5))

    log = open(os.path.join(log_now, 'log.txt'), 'a')
    log.write("\n")
    log.write(process_name)
    log.write(" Val:epoch: %d, Top1 acc: %.3f, Top5 acc: %.3f" %
              (1, top1.avg, top5.avg))
    log.close()
    torch.save({'scores': probabilities, 'labels': labels}, file_to_save_or_load)

    return probabilities, labels
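
# Note on svb/svb_det below: singular value bounding decomposes the weight
# matrix of one fully connected layer as W^T = U S V^T and clamps every
# singular value into [1/svb_factor, svb_factor] before writing the matrix
# back, which keeps that layer approximately well-conditioned between epochs.
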
def svb(model, args):
    print('the layer used for svb is', model.module.classification_stream.classifier[3])
    svb_model = model.module.classification_stream.classifier[3]
    tmpbatchM = svb_model.weight.data.t().clone()
    tmpU, tmpS, tmpV = torch.svd(tmpbatchM)
    # clamp every singular value into [1/svb_factor, svb_factor]
    for idx in range(0, tmpS.size(0)):
        if tmpS[idx] > args.svb_factor:
            tmpS[idx] = args.svb_factor
        elif tmpS[idx] < 1 / args.svb_factor:
            tmpS[idx] = 1 / args.svb_factor
    tmpbatchM = torch.mm(torch.mm(tmpU, torch.diag(tmpS.cuda())), tmpV.t()).t().contiguous()
    svb_model.weight.data.copy_(tmpbatchM.view_as(svb_model.weight.data))

def svb_det(model, args):  # it is not used in our experiments
    print('the layer used for svb is', model.module.detection_stream.detector[3])
    svb_model = model.module.detection_stream.detector[3]
    tmpbatchM = svb_model.weight.data.t().clone()
    tmpU, tmpS, tmpV = torch.svd(tmpbatchM)
    for idx in range(0, tmpS.size(0)):
        if tmpS[idx] > args.svb_factor:
            tmpS[idx] = args.svb_factor
        elif tmpS[idx] < 1 / args.svb_factor:
            tmpS[idx] = 1 / args.svb_factor
    tmpbatchM = torch.mm(torch.mm(tmpU, torch.diag(tmpS.cuda())), tmpV.t()).t().contiguous()
    svb_model.weight.data.copy_(tmpbatchM.view_as(svb_model.weight.data))

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
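
# Note: adjust_learning_rate below implements a two-milestone step schedule.
# Decoded: for the new-added layers, exp is 2 once epoch reaches the second
# milestone, 1 after the first milestone, and 0 before it, so their lr is
# args.lr * args.gamma ** exp; the pre-trained layers always use exp_pre = 2,
# i.e. a constant lr of args.lr * args.gamma ** 2.
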
def adjust_learning_rate(optimizer, epoch, process_name, args):
    """Adjust the learning rate according to the epoch."""
    if process_name == 'part_classifiers':
        schedule = args.schedule_part
    else:  # 'partnet' and 'image_classifier' share the same schedule
        schedule = args.schedule
    exp = 2 if epoch >= schedule[1] else 1 if epoch >= schedule[0] else 0
    exp_pre = 2  # the pre-trained layers always stay at the smallest lr

    lr = args.lr * (args.gamma ** exp)
    lr_pre = args.lr * (args.gamma ** exp_pre)
    print('LR for new-added', lr)
    print('LR for old', lr_pre)
    for param_group in optimizer.param_groups:
        if param_group['name'] == 'pre-trained':
            param_group['lr'] = lr_pre
        else:
            param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
--------------------------------------------------------------------------------
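
A minimal standalone sketch of the singular value bounding step implemented by svb in trainer.py, assuming only torch; the weight matrix and the factor below are hypothetical stand-ins for a linear layer's weight and args.svb_factor:

import torch

def bound_singular_values(weight, factor):
    # decompose W^T, clamp the spectrum into [1/factor, factor], recompose
    u, s, v = torch.svd(weight.t())
    s = s.clamp(1.0 / factor, factor)
    return torch.mm(torch.mm(u, torch.diag(s)), v.t()).t().contiguous()

weight = torch.randn(10, 20)
bounded = bound_singular_values(weight, 1.5)
print(torch.svd(bounded.t())[1])  # all singular values now lie in [1/1.5, 1.5]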