├── .gitignore
├── LICENSE
├── Readme.md
├── data
│   ├── aircrafts.py
│   ├── cars.py
│   ├── cub200.py
│   ├── flowers.py
│   ├── folder_download.py
│   ├── folder_train.py
│   ├── prepare_data.py
│   └── stanford_dogs.py
├── main.py
├── models
│   ├── DPP
│   │   ├── Readme.md
│   │   ├── _ext
│   │   │   ├── __pycache__
│   │   │   │   └── __init__.cpython-36.pyc
│   │   │   └── dpp
│   │   │       ├── __init__.py
│   │   │       ├── __pycache__
│   │   │       │   └── __init__.cpython-36.pyc
│   │   │       └── _dpp.so
│   │   ├── build.py
│   │   ├── functions
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   └── dpp.cpython-36.pyc
│   │   │   ├── dpp.py
│   │   │   └── test_dpp.py
│   │   ├── modules
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   └── dpp.cpython-36.pyc
│   │   │   └── dpp.py
│   │   └── src
│   │       ├── dpp.c
│   │       └── dpp.h
│   ├── model_construct.py
│   ├── partnet_vgg.py
│   ├── roi_pooling
│   │   ├── Readme.md
│   │   ├── _ext
│   │   │   ├── __pycache__
│   │   │   │   └── __init__.cpython-36.pyc
│   │   │   └── roi_pooling
│   │   │       ├── __init__.py
│   │   │       ├── __pycache__
│   │   │       │   └── __init__.cpython-36.pyc
│   │   │       └── _roi_pooling.so
│   │   ├── build.py
│   │   ├── functions
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   └── roi_pool.cpython-36.pyc
│   │   │   └── roi_pool.py
│   │   ├── modules
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── roi_pool.cpython-36.pyc
│   │   │   │   └── roi_pool_py.cpython-36.pyc
│   │   │   └── roi_pool.py
│   │   └── src
│   │       ├── roi_pooling.c
│   │       ├── roi_pooling.cu.o
│   │       ├── roi_pooling.h
│   │       ├── roi_pooling_cuda.c
│   │       ├── roi_pooling_cuda.h
│   │       ├── roi_pooling_kernel.cu
│   │       └── roi_pooling_kernel.h
│   └── vgg.py
├── opts.py
├── process.py
├── run.sh
└── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Yabin Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # Part-Aware Fine-grained Object Categorization using Weakly Supervised Part Detection Network
2 |
3 | The **pytorch** implementation of the paper: Part-Aware Fine-grained Object Categorization using Weakly Supervised Part Detection Network.
4 |
5 | The paper is available at: https://arxiv.org/abs/1806.06198
6 |
7 | A **torch** version of PartNet (the PartNet module only) will be made available.
8 |
9 | To train PartNet from an ImageNet pre-trained model (e.g., VGGNet), first download the pre-trained model and place it in the "vgg19-bn-modified" folder.
10 | We provide an ImageNet pre-trained model that obtains 74.262% top-1 accuracy (almost the same as the 74.266% reported by: https://github.com/Cadene/pretrained-models.pytorch#torchvision).
11 |
12 | The commands to train the model from scratch and to verify the final results can be found in **run.sh**.
13 |
14 | We provide all the intermediate models for fast implementation and verification, which can be downloaded from:
15 | * [Baidu Cloud](https://pan.baidu.com/s/1h5oTI4POrSWBo_XEDkFZnw)
16 | * [Google Cloud](https://drive.google.com/drive/folders/1HNXGE2fI5BHSHCROw8aXyKc4R2aZJzDx?usp=sharing)
17 |
18 | Note that the 'DPP' and 'ROI-Pooling' modules need to be compiled. Details can be found in the Readme.md of each folder.
19 |
20 | If you have any questions about the paper or the code, feel free to send me an email: zhang.yabin@mail.scut.edu.cn
21 |
22 |
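23 | A minimal sketch of the overall workflow (the flag names below are illustrative assumptions; the authoritative commands and options live in **run.sh** and **opts.py**):
24 | 
25 | ```bash
26 | # hypothetical invocation -- check run.sh / opts.py for the real flag names
27 | python main.py --dataset cub200 --data_path /path/to/CUB_200_2011/
28 | # verify the final results only, reusing the provided intermediate models
29 | python main.py --dataset cub200 --data_path /path/to/CUB_200_2011/ --test_only
30 | ```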
--------------------------------------------------------------------------------
/data/aircrafts.py:
--------------------------------------------------------------------------------
1 | #################################################################
2 | ## The data preparation for the aircraft dataset
3 | #################################################################
4 | import os
5 | import shutil
6 | import torch
7 | import torchvision.transforms as transforms
8 | from data.folder_train import ImageFolder_train
9 | from data.folder_download import ImageFolder_download
10 |
11 | def split_train_test_images(data_dir):
12 | #data_dir = '/data/cars/'
13 | raise ValueError('the procedure to generate split training and test images is not implemented')
14 |
15 | def aircrafts(args, process_name, part_index):
16 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
17 | std=[0.229, 0.224, 0.225])
18 | if process_name == 'image_classifier' or process_name == 'partnet' or process_name == 'download_proposals':
19 | traindir = os.path.join(args.data_path, 'splited_image/train')
20 | valdir = os.path.join(args.data_path, 'splited_image/val')
21 | if not os.path.isdir(traindir):
22 | print('the aircraft images are not split yet, splitting all images into train and val sets')
23 | split_train_test_images(args.data_path)
24 | if process_name == 'download_proposals':
25 | print('the image pre-process for process: download proposals is Resize 512 and Center Crop 448')
26 | train_dataset = ImageFolder_download(
27 | root=traindir,
28 | transform=transforms.Compose([
29 | transforms.Resize(512),
30 | transforms.CenterCrop(448),
31 | transforms.ToTensor(),
32 | normalize,
33 | ]),
34 | transform_keep=transforms.Compose([
35 | transforms.Resize(512),
36 | transforms.CenterCrop(448),
37 | transforms.ToTensor()
38 | ]),
39 | dataset_path=args.data_path
40 | )
41 | train_loader = torch.utils.data.DataLoader(
42 | train_dataset, batch_size=args.batch_size_partnet, shuffle=False, num_workers=args.workers,
43 | pin_memory=True, sampler=None
44 | )
45 | val_loader = torch.utils.data.DataLoader(
46 | ImageFolder_download(
47 | root=valdir,
48 | transform=transforms.Compose([
49 | transforms.Resize(512),
50 | transforms.CenterCrop(448),
51 | transforms.ToTensor(),
52 | normalize,
53 | ]),
54 | transform_keep=transforms.Compose([
55 | transforms.Resize(512),
56 | transforms.CenterCrop(448),
57 | transforms.ToTensor(),
58 | ]),
59 | dataset_path=args.data_path),
60 | batch_size=args.batch_size_partnet, shuffle=False,
61 | num_workers=args.workers, pin_memory=True
62 | )
63 | return train_loader, val_loader
64 | elif process_name == 'image_classifier':
65 | print('the image pre-process for process: image_classifier is Resize 512 and Random Crop 448')
66 | train_dataset = ImageFolder_train(
67 | traindir,
68 | transforms.Compose([
69 | transforms.Resize(512),
70 | transforms.RandomCrop(448),
71 | transforms.ToTensor(),
72 | normalize,
73 | ])
74 | )
75 | train_loader = torch.utils.data.DataLoader(
76 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
77 | pin_memory=True, sampler=None
78 | )
79 | val_loader = torch.utils.data.DataLoader(
80 | ImageFolder_train(valdir, transforms.Compose([
81 | transforms.Resize(512),
82 | transforms.CenterCrop(448),
83 | transforms.ToTensor(),
84 | normalize,
85 | ])),
86 | batch_size=args.batch_size, shuffle=False,
87 | num_workers=args.workers, pin_memory=True
88 | )
89 | return train_loader, val_loader
90 | elif process_name == 'partnet':
91 | print('the image pre-process for process: partnet is Resize 512 and Random Crop 448')
92 | train_dataset = ImageFolder_train(
93 | traindir,
94 | transforms.Compose([
95 | transforms.Resize(512),
96 | transforms.RandomCrop(448),
97 | transforms.ToTensor(),
98 | normalize,
99 | ])
100 | )
101 | train_loader = torch.utils.data.DataLoader(
102 | train_dataset, batch_size=args.batch_size_partnet, shuffle=True, num_workers=args.workers,
103 | pin_memory=True, sampler=None
104 | )
105 | val_loader = torch.utils.data.DataLoader(
106 | ImageFolder_train(valdir, transforms.Compose([
107 | transforms.Resize(512),
108 | transforms.CenterCrop(448),
109 | transforms.ToTensor(),
110 | normalize,
111 | ])),
112 | batch_size=args.batch_size_partnet, shuffle=False,
113 | num_workers=args.workers, pin_memory=True
114 | )
115 | return train_loader, val_loader
116 | elif process_name == 'part_classifiers':
117 | traindir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/train/'
118 | valdir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/val/'
119 | print('the image pre-process for process: part_classifier is Resize (448, 448) directly')
120 | train_dataset = ImageFolder_train(
121 | traindir,
122 | transforms.Compose([
123 | transforms.Resize((448, 448)),
124 | transforms.ToTensor(),
125 | normalize,
126 | ])
127 | )
128 | train_loader = torch.utils.data.DataLoader(
129 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
130 | pin_memory=True, sampler=None
131 | )
132 | val_loader = torch.utils.data.DataLoader(
133 | ImageFolder_train(valdir, transforms.Compose([
134 | transforms.Resize((448, 448)),
135 | transforms.ToTensor(),
136 | normalize,
137 | ])),
138 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
139 | pin_memory=True, sampler=None
140 | )
141 | return train_loader, val_loader
142 |
143 |
144 |
--------------------------------------------------------------------------------
/data/cars.py:
--------------------------------------------------------------------------------
1 | #################################################################
2 | ## The data preparation for the cars dataset
3 | #################################################################
4 | import os
5 | import shutil
6 | import torch
7 | import torchvision.transforms as transforms
8 | from data.folder_train import ImageFolder_train
9 | from data.folder_download import ImageFolder_download
10 |
11 | def split_train_test_images(data_dir):
12 | #data_dir = '/data/cars/'
13 | raise ValueError('the procedure to generate split training and test images is not implemented')
14 |
15 | def cars(args, process_name, part_index):
16 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
17 | std=[0.229, 0.224, 0.225])
18 | if process_name == 'image_classifier' or process_name == 'partnet' or process_name == 'download_proposals':
19 | traindir = os.path.join(args.data_path, 'splited_image/train')
20 | valdir = os.path.join(args.data_path, 'splited_image/val')
21 | if not os.path.isdir(traindir):
22 | print('the car images are not split yet, splitting all images into train and val sets')
23 | split_train_test_images(args.data_path)
24 | if process_name == 'download_proposals':
25 | print('the image pre-process for process: download proposals is Resize 512 and Center Crop 448')
26 | train_dataset = ImageFolder_download(
27 | root=traindir,
28 | transform=transforms.Compose([
29 | transforms.Resize(512),
30 | transforms.CenterCrop(448),
31 | transforms.ToTensor(),
32 | normalize,
33 | ]),
34 | transform_keep=transforms.Compose([
35 | transforms.Resize(512),
36 | transforms.CenterCrop(448),
37 | transforms.ToTensor()
38 | ]),
39 | dataset_path=args.data_path
40 | )
41 | train_loader = torch.utils.data.DataLoader(
42 | train_dataset, batch_size=args.batch_size_partnet, shuffle=False, num_workers=args.workers,
43 | pin_memory=True, sampler=None
44 | )
45 | val_loader = torch.utils.data.DataLoader(
46 | ImageFolder_download(
47 | root=valdir,
48 | transform=transforms.Compose([
49 | transforms.Resize(512),
50 | transforms.CenterCrop(448),
51 | transforms.ToTensor(),
52 | normalize,
53 | ]),
54 | transform_keep=transforms.Compose([
55 | transforms.Resize(512),
56 | transforms.CenterCrop(448),
57 | transforms.ToTensor(),
58 | ]),
59 | dataset_path=args.data_path),
60 | batch_size=args.batch_size_partnet, shuffle=False,
61 | num_workers=args.workers, pin_memory=True
62 | )
63 | return train_loader, val_loader
64 | elif process_name == 'image_classifier':
65 | print('the image pre-process for process: image_classifier is Resize 512 and Random Crop 448')
66 | train_dataset = ImageFolder_train(
67 | traindir,
68 | transforms.Compose([
69 | transforms.Resize(512),
70 | transforms.RandomCrop(448),
71 | transforms.ToTensor(),
72 | normalize,
73 | ])
74 | )
75 | train_loader = torch.utils.data.DataLoader(
76 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
77 | pin_memory=True, sampler=None
78 | )
79 | val_loader = torch.utils.data.DataLoader(
80 | ImageFolder_train(valdir, transforms.Compose([
81 | transforms.Resize(512),
82 | transforms.CenterCrop(448),
83 | transforms.ToTensor(),
84 | normalize,
85 | ])),
86 | batch_size=args.batch_size, shuffle=False,
87 | num_workers=args.workers, pin_memory=True
88 | )
89 | return train_loader, val_loader
90 | elif process_name == 'partnet':
91 | print('the image pre-process for process: partnet is Resize 512 and Random Crop 448')
92 | train_dataset = ImageFolder_train(
93 | traindir,
94 | transforms.Compose([
95 | transforms.Resize(512),
96 | transforms.RandomCrop(448),
97 | transforms.ToTensor(),
98 | normalize,
99 | ])
100 | )
101 | train_loader = torch.utils.data.DataLoader(
102 | train_dataset, batch_size=args.batch_size_partnet, shuffle=True, num_workers=args.workers,
103 | pin_memory=True, sampler=None
104 | )
105 | val_loader = torch.utils.data.DataLoader(
106 | ImageFolder_train(valdir, transforms.Compose([
107 | transforms.Resize(512),
108 | transforms.CenterCrop(448),
109 | transforms.ToTensor(),
110 | normalize,
111 | ])),
112 | batch_size=args.batch_size_partnet, shuffle=False,
113 | num_workers=args.workers, pin_memory=True
114 | )
115 | return train_loader, val_loader
116 | elif process_name == 'part_classifiers':
117 | traindir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/train/'
118 | valdir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/val/'
119 | print('the image pre-process for process: part_classifier is Resize (448, 448) directly')
120 | train_dataset = ImageFolder_train(
121 | traindir,
122 | transforms.Compose([
123 | transforms.Resize((448, 448)),
124 | transforms.ToTensor(),
125 | normalize,
126 | ])
127 | )
128 | train_loader = torch.utils.data.DataLoader(
129 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
130 | pin_memory=True, sampler=None
131 | )
132 | val_loader = torch.utils.data.DataLoader(
133 | ImageFolder_train(valdir, transforms.Compose([
134 | transforms.Resize((448, 448)),
135 | transforms.ToTensor(),
136 | normalize,
137 | ])),
138 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
139 | pin_memory=True, sampler=None
140 | )
141 | return train_loader, val_loader
142 |
143 |
144 |
--------------------------------------------------------------------------------
/data/cub200.py:
--------------------------------------------------------------------------------
1 | #################################################################
2 | ## The data preparation for the cub-200-2011 dataset
3 | ## The cub-200-2011 dataset can be downloaded at:
4 | ## http://www.vision.caltech.edu/visipedia/CUB-200.html
5 | ## All the downloaded files should be placed under the "args.data_path".
6 | #################################################################
7 | import os
8 | import shutil
9 | import torch
10 | import torchvision.transforms as transforms
11 | from data.folder_train import ImageFolder_train
12 | from data.folder_download import ImageFolder_download
13 |
14 | def split_train_test_images(data_dir):
15 | #data_dir = '/home/lab-zhangyabin/project/fine-grained/CUB_200_2011/'
16 | src_dir = os.path.join(data_dir, 'images')
17 | target_dir = os.path.join(data_dir, 'splited_image')
18 | if not os.path.isdir(target_dir):
19 | os.makedirs(target_dir)
20 | print(src_dir)
21 | train_test_split = open(os.path.join(data_dir, 'train_test_split.txt'))
22 | line = train_test_split.readline()
23 | images = open(os.path.join(data_dir, 'images.txt'))
24 | images_line = images.readline()
25 | ##########################
26 | # print(images_line)
27 | image_list = str.split(images_line)
28 | # print(image_list[1])
29 | subclass_name = image_list[1].split('/')[0]
30 | # print(subclass_name)
31 |
32 | # print(line)
33 | class_list = str.split(line)[1]
34 | # print(class_list)
35 |
36 | print('begin to prepare the dataset CUB')
37 | count = 0
38 | while images_line:
39 | print(count)
40 | count = count + 1
41 | image_list = str.split(images_line)
42 | subclass_name = image_list[1].split('/')[0] # get the name of the subclass
43 | # print(image_list[0])
44 | class_label = str.split(line)[1] # get the train/test flag of the image
45 | # print(type(int(class_label)))
46 | test_or_train = 'train'
47 | if class_label == '0': # the image belongs to the test set
48 | test_or_train = 'val'
49 | train_test_dir = os.path.join(target_dir, test_or_train)
50 | if not os.path.isdir(train_test_dir):
51 | os.makedirs(train_test_dir)
52 | subclass_dir = os.path.join(train_test_dir, subclass_name)
53 | if not os.path.isdir(subclass_dir):
54 | os.makedirs(subclass_dir)
55 |
56 | source_pos = os.path.join(src_dir, image_list[1])
57 | target_pos = os.path.join(subclass_dir, image_list[1].split('/')[1])
58 | shutil.copyfile(source_pos, target_pos)
59 | images_line = images.readline()
60 | line = train_test_split.readline()
61 |
62 | def cub200(args, process_name, part_index):
63 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
64 | std=[0.229, 0.224, 0.225])
65 | if process_name == 'image_classifier' or process_name == 'partnet' or process_name == 'download_proposals':
66 | traindir = os.path.join(args.data_path, 'splited_image/train')
67 | valdir = os.path.join(args.data_path, 'splited_image/val')
68 | if not os.path.isdir(traindir):
69 | print('the CUB images are not split yet, splitting all images into train and val sets')
70 | split_train_test_images(args.data_path)
71 | if process_name == 'download_proposals':
72 | print('the image pre-process for process: download proposals is Resize 512 and Center Crop 448')
73 | train_dataset = ImageFolder_download(
74 | root=traindir,
75 | transform=transforms.Compose([
76 | transforms.Resize(512),
77 | transforms.CenterCrop(448),
78 | transforms.ToTensor(),
79 | normalize,
80 | ]),
81 | transform_keep=transforms.Compose([
82 | transforms.Resize(512),
83 | transforms.CenterCrop(448),
84 | transforms.ToTensor()
85 | ]),
86 | dataset_path=args.data_path
87 | )
88 | train_loader = torch.utils.data.DataLoader(
89 | train_dataset, batch_size=args.batch_size_partnet, shuffle=False, num_workers=args.workers,
90 | pin_memory=True, sampler=None
91 | )
92 | val_loader = torch.utils.data.DataLoader(
93 | ImageFolder_download(
94 | root=valdir,
95 | transform=transforms.Compose([
96 | transforms.Resize(512),
97 | transforms.CenterCrop(448),
98 | transforms.ToTensor(),
99 | normalize,
100 | ]),
101 | transform_keep=transforms.Compose([
102 | transforms.Resize(512),
103 | transforms.CenterCrop(448),
104 | transforms.ToTensor(),
105 | ]),
106 | dataset_path=args.data_path),
107 | batch_size=args.batch_size_partnet, shuffle=False,
108 | num_workers=args.workers, pin_memory=True
109 | )
110 | return train_loader, val_loader
111 | elif process_name == 'image_classifier':
112 | print('the image pre-process for process: image_classifier is Resize 512 and Random Crop 448')
113 | train_dataset = ImageFolder_train(
114 | traindir,
115 | transforms.Compose([
116 | transforms.Resize(512),
117 | transforms.RandomCrop(448),
118 | transforms.ToTensor(),
119 | normalize,
120 | ])
121 | )
122 | train_loader = torch.utils.data.DataLoader(
123 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
124 | pin_memory=True, sampler=None
125 | )
126 | val_loader = torch.utils.data.DataLoader(
127 | ImageFolder_train(valdir, transforms.Compose([
128 | transforms.Resize(512),
129 | transforms.CenterCrop(448),
130 | transforms.ToTensor(),
131 | normalize,
132 | ])),
133 | batch_size=args.batch_size, shuffle=False,
134 | num_workers=args.workers, pin_memory=True
135 | )
136 | return train_loader, val_loader
137 | elif process_name == 'partnet':
138 | print('the image pre-process for process: partnet is Resize 512 and Random Crop 448')
139 | train_dataset = ImageFolder_train(
140 | traindir,
141 | transforms.Compose([
142 | transforms.Resize(512),
143 | transforms.RandomCrop(448),
144 | transforms.ToTensor(),
145 | normalize,
146 | ])
147 | )
148 | train_loader = torch.utils.data.DataLoader(
149 | train_dataset, batch_size=args.batch_size_partnet, shuffle=True, num_workers=args.workers,
150 | pin_memory=True, sampler=None
151 | )
152 | val_loader = torch.utils.data.DataLoader(
153 | ImageFolder_train(valdir, transforms.Compose([
154 | transforms.Resize(512),
155 | transforms.CenterCrop(448),
156 | transforms.ToTensor(),
157 | normalize,
158 | ])),
159 | batch_size=args.batch_size_partnet, shuffle=False,
160 | num_workers=args.workers, pin_memory=True
161 | )
162 | return train_loader, val_loader
163 | elif process_name == 'part_classifiers':
164 | traindir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/train/'
165 | valdir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/val/'
166 | print('the image pre-process for process: part_classifier is Resize (448, 448) directly')
167 | train_dataset = ImageFolder_train(
168 | traindir,
169 | transforms.Compose([
170 | transforms.Resize((448, 448)),
171 | transforms.ToTensor(),
172 | normalize,
173 | ])
174 | )
175 | train_loader = torch.utils.data.DataLoader(
176 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
177 | pin_memory=True, sampler=None
178 | )
179 | val_loader = torch.utils.data.DataLoader(
180 | ImageFolder_train(valdir, transforms.Compose([
181 | transforms.Resize((448, 448)),
182 | transforms.ToTensor(),
183 | normalize,
184 | ])),
185 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
186 | pin_memory=True, sampler=None
187 | )
188 | return train_loader, val_loader
189 |
190 |
191 |
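192 | # For reference, the metadata files parsed in split_train_test_images follow
193 | # the standard CUB-200-2011 layout (one entry per line):
194 | #   images.txt:           "<image_id> <class_dir>/<image_file>.jpg"
195 | #   train_test_split.txt: "<image_id> <is_training_image>"  (1 = train, 0 = test)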
--------------------------------------------------------------------------------
/data/flowers.py:
--------------------------------------------------------------------------------
1 | #################################################################
2 | ## The data preparation for the flower-102 dataset
3 | ## The flower-102 dataset can be downloaded at:
4 | ## http://www.robots.ox.ac.uk/~vgg/data/flowers/102/
5 | ## All the downloaded related file should be placed at the "args.data_path".
6 | #################################################################
7 | import os
8 | import shutil
9 | import torch
10 | import torchvision.transforms as transforms
11 | from data.folder_train import ImageFolder_train
12 | from data.folder_download import ImageFolder_download
13 | import scipy.io as scio
14 |
15 | def split_train_test_images(data_dir):
16 | image_mat = data_dir + 'imagelabels.mat'
17 | split_mat = data_dir + 'setid.mat'
18 | labels = scio.loadmat(image_mat)['labels'][0]
19 | split = scio.loadmat(split_mat)
20 | train_split = split['trnid'][0]
21 | test_split = split['tstid'][0]
22 | val_split = split['valid'][0]
23 |
24 | for i in range(len(train_split)):
25 | label = labels[train_split[i] - 1]
26 | if train_split[i] > 999:
27 | dir_temp = 'image_0' + str(train_split[i]) + '.jpg'
28 | elif train_split[i] > 99:
29 | dir_temp = 'image_00' + str(train_split[i]) + '.jpg'
30 | elif train_split[i] > 9:
31 | dir_temp = 'image_000' + str(train_split[i]) + '.jpg'
32 | else:
33 | dir_temp = 'image_0000' + str(train_split[i]) + '.jpg'
34 | original_image = data_dir + 'jpg/' + dir_temp
35 | target_dir = data_dir + 'splited_image/train/' + str(label) + '/'
36 | target_image = target_dir + dir_temp
37 | if not os.path.isdir(target_dir):
38 | os.makedirs(target_dir)
39 | shutil.copyfile(original_image, target_image)
40 |
41 | for i in range(len(test_split)):
42 | label = labels[test_split[i] - 1]
43 | if test_split[i] > 999:
44 | dir_temp = 'image_0' + str(test_split[i]) + '.jpg'
45 | elif test_split[i] > 99:
46 | dir_temp = 'image_00' + str(test_split[i]) + '.jpg'
47 | elif test_split[i] > 9:
48 | dir_temp = 'image_000' + str(test_split[i]) + '.jpg'
49 | else:
50 | dir_temp = 'image_0000' + str(test_split[i]) + '.jpg'
51 | original_image = data_dir + 'jpg/' + dir_temp
52 | target_dir = data_dir + 'splited_image/val/' + str(label) + '/'
53 | target_image = target_dir + dir_temp
54 | if not os.path.isdir(target_dir):
55 | os.makedirs(target_dir)
56 | shutil.copyfile(original_image, target_image)
57 |
58 | for i in range(len(val_split)): # note: the val images are merged into the training split
59 | label = labels[val_split[i] - 1]
60 | if val_split[i] > 999:
61 | dir_temp = 'image_0' + str(val_split[i]) + '.jpg'
62 | elif val_split[i] > 99:
63 | dir_temp = 'image_00' + str(val_split[i]) + '.jpg'
64 | elif val_split[i] > 9:
65 | dir_temp = 'image_000' + str(val_split[i]) + '.jpg'
66 | else:
67 | dir_temp = 'image_0000' + str(val_split[i]) + '.jpg'
68 | original_image = data_dir + 'jpg/' + dir_temp
69 | target_dir = data_dir + 'splited_image/train/' + str(label) + '/'
70 | target_image = target_dir + dir_temp
71 | if not os.path.isdir(target_dir):
72 | os.makedirs(target_dir)
73 | shutil.copyfile(original_image, target_image)
74 |
75 | def flowers(args, process_name, part_index):
76 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
77 | std=[0.229, 0.224, 0.225])
78 | if process_name == 'image_classifier' or process_name == 'partnet' or process_name == 'download_proposals':
79 | traindir = os.path.join(args.data_path, 'splited_image/train')
80 | valdir = os.path.join(args.data_path, 'splited_image/val')
81 | if not os.path.isdir(traindir):
82 | print('the flower images are not split yet, splitting all images into train and val sets')
83 | split_train_test_images(args.data_path)
84 | if process_name == 'download_proposals':
85 | print('the image pre-process for process: download proposals is Resize 512 and Center Crop 448')
86 | train_dataset = ImageFolder_download(
87 | root=traindir,
88 | transform=transforms.Compose([
89 | transforms.Resize(512),
90 | transforms.CenterCrop(448),
91 | transforms.ToTensor(),
92 | normalize,
93 | ]),
94 | transform_keep=transforms.Compose([
95 | transforms.Resize(512),
96 | transforms.CenterCrop(448),
97 | transforms.ToTensor()
98 | ]),
99 | dataset_path=args.data_path
100 | )
101 | train_loader = torch.utils.data.DataLoader(
102 | train_dataset, batch_size=args.batch_size_partnet, shuffle=False, num_workers=args.workers,
103 | pin_memory=True, sampler=None
104 | )
105 | val_loader = torch.utils.data.DataLoader(
106 | ImageFolder_download(
107 | root=valdir,
108 | transform=transforms.Compose([
109 | transforms.Resize(512),
110 | transforms.CenterCrop(448),
111 | transforms.ToTensor(),
112 | normalize,
113 | ]),
114 | transform_keep=transforms.Compose([
115 | transforms.Resize(512),
116 | transforms.CenterCrop(448),
117 | transforms.ToTensor(),
118 | ]),
119 | dataset_path=args.data_path),
120 | batch_size=args.batch_size_partnet, shuffle=False,
121 | num_workers=args.workers, pin_memory=True
122 | )
123 | return train_loader, val_loader
124 | elif process_name == 'image_classifier':
125 | print('the image pre-process for process: image_classifier is Resize 512 and Random Crop 448')
126 | train_dataset = ImageFolder_train(
127 | traindir,
128 | transforms.Compose([
129 | transforms.Resize(512),
130 | transforms.RandomCrop(448),
131 | transforms.ToTensor(),
132 | normalize,
133 | ])
134 | )
135 | train_loader = torch.utils.data.DataLoader(
136 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
137 | pin_memory=True, sampler=None
138 | )
139 | val_loader = torch.utils.data.DataLoader(
140 | ImageFolder_train(valdir, transforms.Compose([
141 | transforms.Resize(512),
142 | transforms.CenterCrop(448),
143 | transforms.ToTensor(),
144 | normalize,
145 | ])),
146 | batch_size=args.batch_size, shuffle=False,
147 | num_workers=args.workers, pin_memory=True
148 | )
149 | return train_loader, val_loader
150 | elif process_name == 'partnet':
151 | print('the image pre-process for process: partnet is Resize 512 and Random Crop 448')
152 | train_dataset = ImageFolder_train(
153 | traindir,
154 | transforms.Compose([
155 | transforms.Resize(512),
156 | transforms.RandomCrop(448),
157 | transforms.ToTensor(),
158 | normalize,
159 | ])
160 | )
161 | train_loader = torch.utils.data.DataLoader(
162 | train_dataset, batch_size=args.batch_size_partnet, shuffle=True, num_workers=args.workers,
163 | pin_memory=True, sampler=None
164 | )
165 | val_loader = torch.utils.data.DataLoader(
166 | ImageFolder_train(valdir, transforms.Compose([
167 | transforms.Resize(512),
168 | transforms.CenterCrop(448),
169 | transforms.ToTensor(),
170 | normalize,
171 | ])),
172 | batch_size=args.batch_size_partnet, shuffle=False,
173 | num_workers=args.workers, pin_memory=True
174 | )
175 | return train_loader, val_loader
176 | elif process_name == 'part_classifiers':
177 | traindir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/train/'
178 | valdir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/val/'
179 | print('the image pre-process for process: part_classifier is Resize (448, 448) directly')
180 | train_dataset = ImageFolder_train(
181 | traindir,
182 | transforms.Compose([
183 | transforms.Resize((448, 448)),
184 | transforms.ToTensor(),
185 | normalize,
186 | ])
187 | )
188 | train_loader = torch.utils.data.DataLoader(
189 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
190 | pin_memory=True, sampler=None
191 | )
192 | val_loader = torch.utils.data.DataLoader(
193 | ImageFolder_train(valdir, transforms.Compose([
194 | transforms.Resize((448, 448)),
195 | transforms.ToTensor(),
196 | normalize,
197 | ])),
198 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
199 | pin_memory=True, sampler=None
200 | )
201 | return train_loader, val_loader
202 |
203 |
204 |
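205 | # Note: the three zero-padding branches repeated in split_train_test_images
206 | # could be collapsed into one behaviour-preserving helper, e.g.:
207 | #   def image_name(image_id):
208 | #       return 'image_' + str(image_id).zfill(5) + '.jpg'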
--------------------------------------------------------------------------------
/data/folder_download.py:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | ## This file is modified from the original folder.py for the process of downloading the detected parts. The return values are:
3 | ##
4 | ## img_input: same as the 'img' in the folder.py. It indicates the inputs for the model.
5 | ## target: same as the 'target' in the folder.py. It is an int representing the label of the img_input.
6 | ## img_keep: Image with the same size as the img_input, but it is not normalized. Detected parts are cropped from it.
7 | ## target_loss: The one-hot label for the image. It is used in the Binary Cross-entropy loss.
8 | ## path_image: The path of the loaded images.
9 | ###################################################################################
10 |
11 | import torch.utils.data as data
12 | import torch
13 | from PIL import Image
14 | import os
15 | import os.path
16 | from PIL import ImageFile
17 | ImageFile.LOAD_TRUNCATED_IMAGES = True
18 |
19 | print('the folder.py has been changed')
20 |
21 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
22 |
23 |
24 | def is_image_file(filename):
25 | """Checks if a file is an image.
26 |
27 | Args:
28 | filename (string): path to a file
29 |
30 | Returns:
31 | bool: True if the filename ends with a known image extension
32 | """
33 | filename_lower = filename.lower()
34 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)
35 |
36 |
37 | def find_classes(dir):
38 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
39 | classes.sort()
40 | class_to_idx = {classes[i]: i for i in range(len(classes))}
41 | return classes, class_to_idx
42 |
43 |
44 | def make_dataset(dir, class_to_idx):
45 | images = []
46 | dir = os.path.expanduser(dir)
47 | for target in sorted(os.listdir(dir)):
48 | d = os.path.join(dir, target)
49 | if not os.path.isdir(d):
50 | continue
51 |
52 | for root, _, fnames in sorted(os.walk(d)):
53 | for fname in sorted(fnames):
54 | if is_image_file(fname):
55 | path = os.path.join(root, fname)
56 | item = (path, class_to_idx[target])
57 | images.append(item)
58 |
59 | return images
60 |
61 |
62 | def pil_loader(path):
63 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
64 | with open(path, 'rb') as f:
65 | img = Image.open(f)
66 | return img.convert('RGB')
67 |
68 |
69 | def accimage_loader(path):
70 | import accimage
71 | try:
72 | return accimage.Image(path)
73 | except IOError:
74 | # Potentially a decoding problem, fall back to PIL.Image
75 | return pil_loader(path)
76 |
77 |
78 | def default_loader(path):
79 | from torchvision import get_image_backend
80 | if get_image_backend() == 'accimage':
81 | return accimage_loader(path)
82 | else:
83 | return pil_loader(path)
84 |
85 |
86 | class ImageFolder_download(data.Dataset):
87 | """A generic data loader where the images are arranged in this way: ::
88 |
89 | root/dog/xxx.png
90 | root/dog/xxy.png
91 | root/dog/xxz.png
92 |
93 | root/cat/123.png
94 | root/cat/nsdf3.png
95 | root/cat/asd932_.png
96 |
97 | Args:
98 | root (string): Root directory path.
99 | transform (callable, optional): A function/transform that takes in an PIL image
100 | and returns a transformed version. E.g, ``transforms.RandomCrop``
101 | target_transform (callable, optional): A function/transform that takes in the
102 | target and transforms it.
103 | loader (callable, optional): A function to load an image given its path.
104 |
105 | Attributes:
106 | classes (list): List of the class names.
107 | class_to_idx (dict): Dict with items (class_name, class_index).
108 | imgs (list): List of (image path, class_index) tuples
109 | """
110 |
111 | def __init__(self, root, transform=None, transform_keep = None,target_transform=None,
112 | loader=default_loader, dataset_path = None):
113 | classes, class_to_idx = find_classes(root)
114 | imgs = make_dataset(root, class_to_idx)
115 | if len(imgs) == 0:
116 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
117 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
118 |
119 | self.root = root
120 | self.imgs = imgs
121 | self.classes = classes
122 | self.class_to_idx = class_to_idx
123 | self.transform = transform
124 | self.transform_keep = transform_keep
125 | self.target_transform = target_transform
126 | self.loader = loader
127 | self.dataset_path = dataset_path
128 | print('the classes number is', len(self.classes))
129 | print('the images number is', len(self.imgs))
130 |
131 | def __getitem__(self, index):
132 | """
133 | Args:
134 | index (int): Index
135 |
136 | Returns:
137 | tuple: (img_input, img_keep, target, target_loss, path_image) as described in the module header.
138 | """
139 | path, target = self.imgs[index]
140 | img = self.loader(path)
141 | # print(target)
142 | if self.transform is not None:
143 | img_input = self.transform(img)
144 | if self.transform_keep is not None:
145 | img_keep = self.transform_keep(img)
146 | if self.target_transform is not None:
147 | target = self.target_transform(target)
148 | target_loss = torch.Tensor(len(self.classes)).fill_(0)
149 | target_loss[target] = 1
150 | if self.dataset_path is not None:
151 | path_image = path.replace(self.dataset_path, '')
152 | return img_input, img_keep, target, target_loss, path_image
153 |
154 | def __len__(self):
155 | return len(self.imgs)
156 |
157 | def __repr__(self):
158 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
159 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
160 | fmt_str += ' Root Location: {}\n'.format(self.root)
161 | tmp = ' Transforms (if any): '
162 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
163 | tmp = ' Target Transforms (if any): '
164 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
165 | return fmt_str
166 |
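167 | # Usage sketch (the paths are illustrative assumptions): each item yields the
168 | # five values documented in the module header.
169 | #
170 | #   dataset = ImageFolder_download(root='CUB_200_2011/splited_image/train',
171 | #                                  transform=normalized_transform,
172 | #                                  transform_keep=unnormalized_transform,
173 | #                                  dataset_path='CUB_200_2011/')
174 | #   img_input, img_keep, target, target_loss, path_image = dataset[0]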
--------------------------------------------------------------------------------
/data/folder_train.py:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | ## This file is modified from the original folder.py for the process of PartNet training. The return values are:
3 | ##
4 | ## img_input: same as the 'img' in the folder.py. It indicates the inputs for the model.
5 | ## target: same as the 'target' in the folder.py. It is an int representing the label of the img_input.
6 | ## target_loss: The one-hot label for the image. It is used in the Binary Cross-entropy loss.
7 | ##
8 | ###################################################################################
9 | import torch.utils.data as data
10 | import torch
11 | from PIL import Image
12 | import os
13 | import os.path
14 | from PIL import ImageFile
15 | ImageFile.LOAD_TRUNCATED_IMAGES = True
16 |
17 | print('the folder.py has been changed')
18 |
19 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
20 |
21 |
22 | def is_image_file(filename):
23 | """Checks if a file is an image.
24 |
25 | Args:
26 | filename (string): path to a file
27 |
28 | Returns:
29 | bool: True if the filename ends with a known image extension
30 | """
31 | filename_lower = filename.lower()
32 | return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)
33 |
34 |
35 | def find_classes(dir):
36 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
37 | classes.sort()
38 | class_to_idx = {classes[i]: i for i in range(len(classes))}
39 | return classes, class_to_idx
40 |
41 |
42 | def make_dataset(dir, class_to_idx):
43 | images = []
44 | dir = os.path.expanduser(dir)
45 | for target in sorted(os.listdir(dir)):
46 | d = os.path.join(dir, target)
47 | if not os.path.isdir(d):
48 | continue
49 |
50 | for root, _, fnames in sorted(os.walk(d)):
51 | for fname in sorted(fnames):
52 | if is_image_file(fname):
53 | path = os.path.join(root, fname)
54 | item = (path, class_to_idx[target])
55 | images.append(item)
56 |
57 | return images
58 |
59 |
60 | def pil_loader(path):
61 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
62 | with open(path, 'rb') as f:
63 | img = Image.open(f)
64 | return img.convert('RGB')
65 |
66 |
67 | def accimage_loader(path):
68 | import accimage
69 | try:
70 | return accimage.Image(path)
71 | except IOError:
72 | # Potentially a decoding problem, fall back to PIL.Image
73 | return pil_loader(path)
74 |
75 |
76 | def default_loader(path):
77 | from torchvision import get_image_backend
78 | if get_image_backend() == 'accimage':
79 | return accimage_loader(path)
80 | else:
81 | return pil_loader(path)
82 |
83 |
84 | class ImageFolder_train(data.Dataset):
85 | """A generic data loader where the images are arranged in this way: ::
86 |
87 | root/dog/xxx.png
88 | root/dog/xxy.png
89 | root/dog/xxz.png
90 |
91 | root/cat/123.png
92 | root/cat/nsdf3.png
93 | root/cat/asd932_.png
94 |
95 | Args:
96 | root (string): Root directory path.
97 | transform (callable, optional): A function/transform that takes in an PIL image
98 | and returns a transformed version. E.g, ``transforms.RandomCrop``
99 | target_transform (callable, optional): A function/transform that takes in the
100 | target and transforms it.
101 | loader (callable, optional): A function to load an image given its path.
102 |
103 | Attributes:
104 | classes (list): List of the class names.
105 | class_to_idx (dict): Dict with items (class_name, class_index).
106 | imgs (list): List of (image path, class_index) tuples
107 | """
108 |
109 | def __init__(self, root, transform=None, target_transform=None,
110 | loader=default_loader):
111 | classes, class_to_idx = find_classes(root)
112 | imgs = make_dataset(root, class_to_idx)
113 | if len(imgs) == 0:
114 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
115 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
116 |
117 | self.root = root
118 | self.imgs = imgs
119 | self.classes = classes
120 | self.class_to_idx = class_to_idx
121 | self.transform = transform
122 | self.target_transform = target_transform
123 | self.loader = loader
124 | print('the classes number is', len(self.classes))
125 | print('the images number is', len(self.imgs))
126 |
127 | def __getitem__(self, index):
128 | """
129 | Args:
130 | index (int): Index
131 |
132 | Returns:
133 | tuple: (img, target, target_loss) where target is the class index and target_loss its one-hot encoding.
134 | """
135 | path, target = self.imgs[index]
136 | img = self.loader(path)
137 | # print(target)
138 | if self.transform is not None:
139 | img = self.transform(img)
140 | if self.target_transform is not None:
141 | target = self.target_transform(target)
142 | target_loss = torch.Tensor(len(self.classes)).fill_(0)
143 | target_loss[target] = 1
144 | return img, target, target_loss
145 |
146 | def __len__(self):
147 | return len(self.imgs)
148 |
149 | def __repr__(self):
150 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
151 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
152 | fmt_str += ' Root Location: {}\n'.format(self.root)
153 | tmp = ' Transforms (if any): '
154 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
155 | tmp = ' Target Transforms (if any): '
156 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
157 | return fmt_str
158 |
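159 | # Usage sketch (the path is an illustrative assumption): each item is an
160 | # (img, target, target_loss) triple, where target_loss is the one-hot vector
161 | # consumed by the binary cross-entropy loss; e.g. with 5 classes and
162 | # target == 2, target_loss is tensor([0., 0., 1., 0., 0.]).
163 | #
164 | #   dataset = ImageFolder_train('splited_image/train', some_transform)
165 | #   img, target, target_loss = dataset[0]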
--------------------------------------------------------------------------------
/data/prepare_data.py:
--------------------------------------------------------------------------------
1 | from data.cub200 import cub200
2 | from data.stanford_dogs import stanford_dogs
3 | from data.flowers import flowers
4 | from data.cars import cars
5 | from data.aircrafts import aircrafts
6 |
7 | def generate_dataloader(args, process_name, part_index=-1):
8 | print('the required dataset is', args.dataset)
9 | if args.dataset == 'cub200':
10 | train_loader, val_loader = cub200(args, process_name, part_index)
11 | elif args.dataset == 'stanford_dogs':
12 | raise ValueError('the required dataset is not prepared')
13 | # train_loader, val_loader = stanford_dogs(args, process_name, part_index)
14 | elif args.dataset == 'flowers':
15 | train_loader, val_loader = flowers(args, process_name, part_index)
16 | elif args.dataset == 'cars':
17 | raise ValueError('the required dataset is not prepared')
18 | # train_loader, val_loader = cars(args, process_name, part_index)
19 | elif args.dataset == 'aircrafts':
20 | raise ValueError('the required dataset is not prepared')
21 | # train_loader, val_loader = aircrafts(args, process_name, part_index)
22 | else:
23 | raise ValueError('the required dataset is not prepared')
24 |
25 | return train_loader, val_loader
26 |
27 |
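28 | # Usage sketch: `args` comes from opts.py and must provide at least `dataset`,
29 | # `data_path`, `arch`, `batch_size`, `batch_size_partnet` and `workers`
30 | # (see the dataset modules above). The valid process names are
31 | # 'image_classifier', 'partnet', 'download_proposals' and 'part_classifiers'.
32 | #
33 | #   train_loader, val_loader = generate_dataloader(args, 'image_classifier')
34 | #   train_loader, val_loader = generate_dataloader(args, 'part_classifiers', part_index=1)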
--------------------------------------------------------------------------------
/data/stanford_dogs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import torch
4 | import torchvision.transforms as transforms
5 | from data.folder_train import ImageFolder_train
6 | from data.folder_download import ImageFolder_download
7 |
8 | def split_train_test_images(data_dir):
9 | #data_dir = '/data/Stanford_dog/'
10 | raise ValueError('the procedure to generate split training and test images is not implemented')
11 |
12 | def stanford_dogs(args, process_name, part_index):
13 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
14 | std=[0.229, 0.224, 0.225])
15 | if process_name == 'image_classifier' or process_name == 'partnet' or process_name == 'download_proposals':
16 | traindir = os.path.join(args.data_path, 'splited_image/train')
17 | valdir = os.path.join(args.data_path, 'splited_image/val')
18 | if not os.path.isdir(traindir):
19 | print('the dog images are not split yet, splitting all images into train and val sets')
20 | split_train_test_images(args.data_path)
21 | if process_name == 'download_proposals':
22 | print('the image pre-process for process: download proposals is Resize 512 and Center Crop 448')
23 | train_dataset = ImageFolder_download(
24 | root=traindir,
25 | transform=transforms.Compose([
26 | transforms.Resize(512),
27 | transforms.CenterCrop(448),
28 | transforms.ToTensor(),
29 | normalize,
30 | ]),
31 | transform_keep=transforms.Compose([
32 | transforms.Resize(512),
33 | transforms.CenterCrop(448),
34 | transforms.ToTensor()
35 | ]),
36 | dataset_path=args.data_path
37 | )
38 | train_loader = torch.utils.data.DataLoader(
39 | train_dataset, batch_size=args.batch_size_partnet, shuffle=False, num_workers=args.workers,
40 | pin_memory=True, sampler=None
41 | )
42 | val_loader = torch.utils.data.DataLoader(
43 | ImageFolder_download(
44 | root=valdir,
45 | transform=transforms.Compose([
46 | transforms.Resize(512),
47 | transforms.CenterCrop(448),
48 | transforms.ToTensor(),
49 | normalize,
50 | ]),
51 | transform_keep=transforms.Compose([
52 | transforms.Resize(512),
53 | transforms.CenterCrop(448),
54 | transforms.ToTensor(),
55 | ]),
56 | dataset_path=args.data_path),
57 | batch_size=args.batch_size_partnet, shuffle=False,
58 | num_workers=args.workers, pin_memory=True
59 | )
60 | return train_loader, val_loader
61 | elif process_name == 'image_classifier':
62 | print('the image pre-process for process: image_classifier is Resize 512 and Random Crop 448')
63 | train_dataset = ImageFolder_train(
64 | traindir,
65 | transforms.Compose([
66 | transforms.Resize(512),
67 | transforms.RandomCrop(448),
68 | transforms.ToTensor(),
69 | normalize,
70 | ])
71 | )
72 | train_loader = torch.utils.data.DataLoader(
73 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
74 | pin_memory=True, sampler=None
75 | )
76 | val_loader = torch.utils.data.DataLoader(
77 | ImageFolder_train(valdir, transforms.Compose([
78 | transforms.Resize(512),
79 | transforms.CenterCrop(448),
80 | transforms.ToTensor(),
81 | normalize,
82 | ])),
83 | batch_size=args.batch_size, shuffle=False,
84 | num_workers=args.workers, pin_memory=True
85 | )
86 | return train_loader, val_loader
87 | elif process_name == 'partnet':
88 | print('the image pre-process for process: partnet is Resize 512 and Random Crop 448')
89 | train_dataset = ImageFolder_train(
90 | traindir,
91 | transforms.Compose([
92 | transforms.Resize(512),
93 | transforms.RandomCrop(448),
94 | transforms.ToTensor(),
95 | normalize,
96 | ])
97 | )
98 | train_loader = torch.utils.data.DataLoader(
99 | train_dataset, batch_size=args.batch_size_partnet, shuffle=True, num_workers=args.workers,
100 | pin_memory=True, sampler=None
101 | )
102 | val_loader = torch.utils.data.DataLoader(
103 | ImageFolder_train(valdir, transforms.Compose([
104 | transforms.Resize(512),
105 | transforms.CenterCrop(448),
106 | transforms.ToTensor(),
107 | normalize,
108 | ])),
109 | batch_size=args.batch_size_partnet, shuffle=False,
110 | num_workers=args.workers, pin_memory=True
111 | )
112 | return train_loader, val_loader
113 | elif process_name == 'part_classifiers':
114 | traindir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/train/'
115 | valdir = args.data_path + 'PartNet' + args.arch + '/part_' + str(part_index) + '/splited_image/val/'
116 | print('the image pre-process for process: part_classifier is Resize (448, 448) directly')
117 | train_dataset = ImageFolder_train(
118 | traindir,
119 | transforms.Compose([
120 | transforms.Resize((448, 448)),
121 | transforms.ToTensor(),
122 | normalize,
123 | ])
124 | )
125 | train_loader = torch.utils.data.DataLoader(
126 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers,
127 | pin_memory=True, sampler=None
128 | )
129 | val_loader = torch.utils.data.DataLoader(
130 | ImageFolder_train(valdir, transforms.Compose([
131 | transforms.Resize((448, 448)),
132 | transforms.ToTensor(),
133 | normalize,
134 | ])),
135 | batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
136 | pin_memory=True, sampler=None
137 | )
138 | return train_loader, val_loader
139 |
140 |
141 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | ##############################################################################
2 | # Pytorch PartNet
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Yabin Zhang
5 | ##############################################################################
6 |
7 | from opts import opts # The options for the project
8 | from process import Process1_Image_Classifier
9 | from process import Process2_PartNet
10 | from process import Process3_Download_Proposals
11 | from process import Process4_Part_Classifiers
12 | from process import Process5_Final_Result
13 | import os
14 | def main():
15 | global args
16 | args = opts()
17 | if args.test_only:
18 | os.environ['CUDA_VISIBLE_DEVICES'] = "0"
19 | print('download the detected part images in the disk')
20 | Process3_Download_Proposals(args) ### We use hooks in this process, so only one GPU should be used.
21 | os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3,4,5,6,7"
22 | print('only test the final results; please make sure all the needed files are prepared.')
23 | Process5_Final_Result(args)
24 | else:
25 | # Fine-tune an ImageNet pre-trained model on the fine-grained dataset
26 | os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3,4,5,6,7"
27 | Process1_Image_Classifier(args)
28 | # Train the PartNet based on the above fine-tuned model on the Fine-Grained dataset
29 | Process2_PartNet(args)
30 | # Download the proposals detected by the PartNet
31 | os.environ['CUDA_VISIBLE_DEVICES'] = "0"
32 | Process3_Download_Proposals(args) ### We use hooks in this process, so only one GPU should be used.
33 | os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3,4,5,6,7"
34 | # Train individual Classifier for each part
35 | Process4_Part_Classifiers(args)
36 | # Average the probabilities of the above models (image level + parts 1-3 + PartNet) for the final prediction
37 | Process5_Final_Result(args)
38 |
39 | if __name__ == '__main__':
40 | main()
41 |
42 |
43 |
44 |
45 |
46 |
--------------------------------------------------------------------------------
/models/DPP/Readme.md:
--------------------------------------------------------------------------------
1 | # Discretized Part Proposals (DPP)
2 | A C implementation of the DPP (Discretized Part Proposals) module.
3 | ## Compile Process
4 | Requirements: the same as https://github.com/jwyang/faster-rcnn.pytorch
5 | 
6 | Command to compile:
7 | ```bash
8 | python build.py
9 | ```
10 |
11 |
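12 | ## Usage sketch
13 | After compiling, the op is exercised through the wrapper in `functions/dpp.py`. A minimal sketch (assumptions: the repo root is on `PYTHONPATH`, 448x448 inputs give 28x28 feature maps at spatial scale 16, and a CUDA device is available, since the forward returns a CUDA tensor):
14 | 
15 | ```python
16 | import torch
17 | from models.DPP.functions.dpp import DPPFunction
18 | 
19 | # square_size=7 -> a 4x4 grid of squares; 4 * 4 * 28 = 448 proposals per image
20 | dpp_fn = DPPFunction(square_size=7, proposals_per_square=28,
21 |                      proposals_per_image=448, spatial_scale=16)
22 | rois = dpp_fn.forward(torch.randn(2, 512, 28, 28).cuda())  # (2 * 448, 5) ROI tensor
23 | ```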
--------------------------------------------------------------------------------
/models/DPP/_ext/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/_ext/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/DPP/_ext/dpp/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.ffi import _wrap_function
3 | from ._dpp import lib as _lib, ffi as _ffi
4 |
5 | __all__ = []
6 | def _import_symbols(locals):
7 | for symbol in dir(_lib):
8 | fn = getattr(_lib, symbol)
9 | if callable(fn):
10 | locals[symbol] = _wrap_function(fn, _ffi)
11 | else:
12 | locals[symbol] = fn
13 | __all__.append(symbol)
14 |
15 | _import_symbols(locals())
16 |
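17 | # Note: this loader is the boilerplate emitted by torch.utils.ffi's
18 | # create_extension; it wraps every callable exported by _dpp.so and
19 | # re-exports it from this package (hence `from .._ext import dpp` works
20 | # in functions/dpp.py).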
--------------------------------------------------------------------------------
/models/DPP/_ext/dpp/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/_ext/dpp/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/DPP/_ext/dpp/_dpp.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/_ext/dpp/_dpp.so
--------------------------------------------------------------------------------
/models/DPP/build.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import torch
4 | from torch.utils.ffi import create_extension
5 |
6 |
7 | sources = ['src/dpp.c']
8 | headers = ['src/dpp.h']
9 | defines = []
10 | with_cuda = False
11 |
12 | # if torch.cuda.is_available():
13 | # print('Including CUDA code.')
14 | # sources += ['src/roi_pooling_cuda.c']
15 | # headers += ['src/roi_pooling_cuda.h']
16 | # defines += [('WITH_CUDA', None)]
17 | # with_cuda = True
18 |
19 | this_file = os.path.dirname(os.path.realpath(__file__))
20 | print(this_file)
21 | #extra_objects = ['src/dpp.cu.o']
22 | #extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
23 |
24 | ffi = create_extension(
25 | '_ext.dpp',
26 | headers=headers,
27 | sources=sources,
28 | define_macros=defines,
29 | relative_to=__file__,
30 | with_cuda=with_cuda # ,
31 | # extra_objects=extra_objects
32 | )
33 |
34 | if __name__ == '__main__':
35 | ffi.build()
36 |
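37 | # Note: torch.utils.ffi was removed in PyTorch 1.0, so this script needs an
38 | # older PyTorch (the 0.x line used by jwyang/faster-rcnn.pytorch). A successful
39 | # build writes _ext/dpp/_dpp.so, which _ext/dpp/__init__.py loads on import.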
--------------------------------------------------------------------------------
/models/DPP/functions/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/functions/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/DPP/functions/__pycache__/dpp.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/functions/__pycache__/dpp.cpython-36.pyc
--------------------------------------------------------------------------------
/models/DPP/functions/dpp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from .._ext import dpp
4 | import pdb
5 | import time
6 |
7 | class DPPFunction(Function):
8 | def __init__(ctx, square_size, proposals_per_square, proposals_per_image, spatial_scale):
9 | ctx.square_size = square_size
10 | ctx.proposals_per_square = proposals_per_square
11 |         ctx.spatial_scale = spatial_scale
12 | ctx.output = torch.cuda.FloatTensor()
13 | ctx.proposals_per_image = proposals_per_image
14 | ctx.box_plan = torch.Tensor([[-1, -1, 1, 1],
15 | [-2, -2, 2, 2],
16 | [-1, -3, 1, 3],
17 | [-3, -1, 3, 1],
18 | [-3, -3, 3, 3],
19 | [-2, -4, 2, 4],
20 | [-4, -2, 4, 2],
21 | [-4, -4, 4, 4],
22 | [-3, -5, 3, 5],
23 | [-5, -3, 5, 3], # 10
24 | [-5, -5, 5, 5],
25 | [-4, -7, 4, 7],
26 | [-7, -4, 7, 4],
27 | [-6, -6, 6, 6],
28 | [-4, -8, 4, 8], # 15
29 | [-8, -4, 8, 4],
30 | [-7, -7, 7, 7],
31 | [-5, -10, 5, 10],
32 | [-10, -5, 10, 5],
33 | [-8, -8, 8, 8], # 20
34 | [-6, -11, 6, 11],
35 | [-11, -6, 11, 6],
36 | [-9, -9, 9, 9],
37 | [-7, -12, 7, 12],
38 |                                      [-12, -7, 12, 7],  # symmetric counterpart of [-7, -12, 7, 12]
39 | [-10, -10, 10, 10],
40 | [-7, -14, 7, 14],
41 | [-14, -7, 14, 7]])
42 |         ctx.square_num = int(28 / ctx.square_size)  # cells per side: 7 when square_size == 4, 4 when square_size == 7
43 | calculate_num = ctx.square_num * ctx.square_num * proposals_per_square
44 |         if ctx.proposals_per_image != calculate_num:
45 |             raise ValueError('the number of proposals generated by DPP should be', calculate_num)
46 |         if ctx.square_size != 4 and ctx.square_size != 7:
47 |             raise ValueError('the square size should be 4 or 7, but got:',
48 |                              ctx.square_size)
49 |         if ctx.proposals_per_square > 28:
50 |             raise ValueError('the number of proposals per square should be at most 28 (box_plan has 28 templates)')
51 |         if ctx.spatial_scale != 16:
52 |             raise ValueError('the spatial scale should be 16, but got:', ctx.spatial_scale)
53 |
54 | def forward(ctx, features):
55 | #timer= time.time()
56 | batch_size, num_channels, data_height, data_width = features.size()
57 | num_rois = batch_size * ctx.proposals_per_image
58 | output = torch.FloatTensor(num_rois, 5).fill_(1)
59 | histogram = torch.FloatTensor(data_height, data_width).fill_(0)
60 | score_sum = torch.FloatTensor(data_height, data_width).fill_(0)
61 |
62 | if features.is_cuda:
63 |             # the C kernel runs on the CPU, so move CUDA features to host memory first
64 |             dpp.dpp_forward(ctx.square_size, ctx.proposals_per_square, ctx.proposals_per_image, ctx.spatial_scale,
65 |                             ctx.box_plan, histogram, score_sum, output, features.cpu())
66 | # print('the dpp forward time is:', time.time() - timer)
67 | # print('the output of dpp is:', output)
68 | else:
69 |             dpp.dpp_forward(ctx.square_size, ctx.proposals_per_square, ctx.proposals_per_image, ctx.spatial_scale,
70 |                             ctx.box_plan, histogram, score_sum, output, features)
71 | # print('the dpp forward time is:', time.time() - timer)
72 | # print('the output of dpp is:', output)
73 |
74 | return output.cuda()
75 |
76 | def backward(ctx, grad_output):
77 |
78 |         return None  # proposal generation is non-differentiable; no gradient flows back
79 |
--------------------------------------------------------------------------------
/models/DPP/functions/test_dpp.py:
--------------------------------------------------------------------------------
1 | from .._ext import dpp
2 | import torch
3 | output = torch.FloatTensor(2744, 5).fill_(0)
4 | histogram = torch.FloatTensor(28, 28).fill_(0)
5 | score_sum = torch.FloatTensor(28, 28).fill_(0)
6 | features = torch.randn(2, 512, 28, 28)
7 | box_plan = torch.Tensor([[-1, -1, 1, 1],
8 | [-2, -2, 2, 2],
9 | [-1, -3, 1, 3],
10 | [-3, -1, 3, 1],
11 | [-3, -3, 3, 3],
12 | [-2, -4, 2, 4],
13 | [-4, -2, 4, 2],
14 | [-4, -4, 4, 4],
15 | [-3, -5, 3, 5],
16 | [-5, -3, 5, 3], # 10
17 | [-5, -5, 5, 5],
18 | [-4, -7, 4, 7],
19 | [-7, -4, 7, 4],
20 | [-6, -6, 6, 6],
21 | [-4, -8, 4, 8], # 15
22 | [-8, -4, 8, 4],
23 | [-7, -7, 7, 7],
24 | [-5, -10, 5, 10],
25 | [-10, -5, 10, 5],
26 | [-8, -8, 8, 8], # 20
27 | [-6, -11, 6, 11],
28 | [-11, -6, 11, 6],
29 | [-9, -9, 9, 9],
30 | [-7, -12, 7, 12],
31 |                          [-12, -7, 12, 7],  # symmetric counterpart of [-7, -12, 7, 12]
32 | [-10, -10, 10, 10],
33 | [-7, -14, 7, 14],
34 | [-14, -7, 14, 7]])
35 | dpp.dpp_forward(4, 28, 1372, 16,
36 | box_plan, histogram, score_sum, output, features)
37 |
38 | print(output)
--------------------------------------------------------------------------------
/models/DPP/modules/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/modules/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/DPP/modules/__pycache__/dpp.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/DPP/modules/__pycache__/dpp.cpython-36.pyc
--------------------------------------------------------------------------------
/models/DPP/modules/dpp.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from ..functions.dpp import DPPFunction
3 |
4 |
5 | class _DPP(Module):
6 | def __init__(self, square_size, proposals_per_square, proposals_per_image, spatial_scale):
7 | super(_DPP, self).__init__()
8 |
9 | self.square_size = int(square_size)
10 | self.proposals_per_square = int(proposals_per_square)
11 | self.proposals_per_image = int(proposals_per_image)
12 | self.spatial_scale = int(spatial_scale)
13 |
14 | def forward(self, features):
15 | return DPPFunction(self.square_size, self.proposals_per_square, self.proposals_per_image, self.spatial_scale)(features)
16 |
--------------------------------------------------------------------------------
/models/DPP/src/dpp.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 | #include <math.h>
3 | #include <stdio.h>
4 |
5 | int dpp_forward(int square_size, int proposals_per_square, int proposals_per_image, int spatital_scale, THFloatTensor * box_plan,
6 | THFloatTensor * histogram, THFloatTensor * score_sum, THFloatTensor * output, THFloatTensor * features)
7 | {
8 | // Grab the input tensor
9 | float * box_plan_flat = THFloatTensor_data(box_plan);
10 | float * histogram_flat = THFloatTensor_data(histogram);
11 | float * score_sum_flat = THFloatTensor_data(score_sum);
12 | float * output_flat = THFloatTensor_data(output);
13 | float * features_flat = THFloatTensor_data(features);
14 |
15 | int batch_size = THFloatTensor_size(features, 0);
16 |
17 | int data_height = THFloatTensor_size(features, 2);
18 | // data width
19 | int data_width = THFloatTensor_size(features, 3);
20 | // Number of channels
21 | int num_channels = THFloatTensor_size(features, 1);
22 |
23 |     int b, c, h, w, max_h, max_w, p; float max_value;  /* max_value compares float feature scores */
24 | int index_proposals = 0;
25 |
26 | for (b = 0; b < batch_size; b++)
27 | {
28 | // printf("here is the batch of %d", b);
29 | THFloatStorage_fill(THFloatTensor_storage(histogram), 0);
30 | // printf("here is the line after fill lalalla ");
31 | THFloatStorage_fill(THFloatTensor_storage(score_sum), 0);
32 |
33 | for (c =0; c < num_channels; c++)
34 | {
35 | // printf("histogram of %d", c);
36 | //find the max position for each channel, and add 1 to the histogram
37 | const int index_features = b * data_height * data_width * num_channels + c * data_height * data_width;
38 | max_w = 0;
39 | max_h = 0;
40 | max_value = 0;
41 |
42 | for (h=0; h < data_height; h++)
43 | {
44 | for (w=0; w< data_width; w++)
45 | {
46 | const int index_histogram = h * data_width + w;
47 | if(features_flat[index_features + index_histogram] > max_value)
48 | {
49 | max_value = features_flat[index_features + index_histogram];
50 | max_w = w;
51 | max_h = h;
52 | }
53 | }
54 | }
55 | histogram_flat[max_h * data_width + max_w] += 1;
56 | }
57 |
58 | /// calculate the score sum
59 | for (c =0; c < num_channels; c++)
60 | {
61 | // printf("score sum of %d", c);
62 | //add values at each position of the feature to the score sum
63 | const int index_features = b * data_height * data_width * num_channels + c * data_height * data_width;
64 |
65 | for (h=0; h < data_height; h++)
66 | {
67 | for (w=0; w< data_width; w++)
68 | {
69 | const int index_score_sum = h *data_width + w;
70 | score_sum_flat[index_score_sum] += features_flat[index_features + index_score_sum];
71 | }
72 | }
73 | }
74 |
75 | //
76 | int sub_h, sub_w;
77 |         for (h=0; h
[... content missing: the remainder of /models/DPP/src/dpp.c, /models/DPP/src/dpp.h, /models/model_construct.py, and lines 1-201 of /models/partnet_vgg.py ...]
--------------------------------------------------------------------------------
/models/partnet_vgg.py:
--------------------------------------------------------------------------------
202 | #         if self.proposals_per_square > 28:
203 | # raise ValueError('the proposals number for each image should below 28')
204 | #
205 | #
206 | # def forward(self, features):
207 | # timer = time.time()
208 | # batch_size, num_channels, data_height, data_width = features.size()
209 | # features_float = features.float()
210 | #
211 | # num_rois = batch_size * self.proposals_per_image
212 | # output = torch.Tensor(num_rois, 5).fill_(0)
213 | #
214 | # if self.spatital_scale != 16:
215 | # raise ValueError('the spatial scale should be 16, but you define is:', self.spatital_scale)
216 | # proposals_index = 0
217 | #
218 | # for i in range(batch_size):
219 | # roi_memory = torch.Tensor(28,28).zero_()
220 | # roi_score = torch.sum(features_float[i], 0)
221 | # output[i*self.proposals_per_image: (i+1)*self.proposals_per_image, 0] = i
222 | # for j in range(num_channels):
223 | # one_channel = features_float[i][j]
224 | # max_per_row, max_column_per_row = torch.max(one_channel, 1)
225 | # max_p, max_row = torch.max(max_per_row, 0)
226 | # x_center = max_row[0]
227 | # max_col = max_column_per_row[x_center]
228 | # y_center = max_col
229 | # roi_memory[x_center][y_center] = roi_memory[x_center][y_center] + 1
230 | # for x in range(self.square_num):
231 | # for y in range(self.square_num):
232 | # temp = roi_memory[x * self.square_size:(x+1)*self.square_size, y*self.square_size: (y+1)*self.square_size]
233 | # max_per_row, max_column_per_row = torch.max(temp, 1)
234 | # max_p, max_row = torch.max(max_per_row, 0)
235 | # x_center = max_row[0]
236 | # max_col = max_column_per_row[x_center]
237 | # y_center = max_col
238 | # find_repeat = torch.eq(temp, temp[x_center][y_center])
239 | # if torch.sum(find_repeat) > 1:
240 | # score_temp = roi_score[x * self.square_size:(x+1)*self.square_size, y*self.square_size:(y+1)*self.square_size]
241 | # max_per_row, max_column_per_row = torch.max(score_temp, 1)
242 | # max_p, max_row = torch.max(max_per_row, 0)
243 | # x_center = max_row[0]
244 | # max_col = max_column_per_row[x_center]
245 | # y_center = max_col
246 | # x_center = x_center + x * self.square_size
247 | # y_center = y_center + y * self.square_size
248 | # for k in range(self.proposals_per_square):
249 | # output[proposals_index][1:5] = torch.Tensor([x_center +self.box_plan[k][0], y_center+self.box_plan[k][1], x_center+self.box_plan[k][2], y_center+self.box_plan[k][3]])
250 | # proposals_index = proposals_index + 1
251 | # output[:, 1:5].mul_(self.spatital_scale)
252 | # output[:, 1:5].clamp_(0, 448)
253 | # #self.output.resize_(output.size()).copy_(output, non_blocking=False)
254 | # print('the dpp time is:', time.time() - timer)
255 | # return output.cuda()
256 | #
257 | #
258 | # def backward(self, grad_output):
259 | #
260 | # return None
261 |
262 | # class DPP(Module):
263 | # def __init__(self, square_size, proposals_per_square, proposals_per_image, spatial_scale):
264 | # super(DPP, self).__init__()
265 | #
266 | # self.square_size = int(square_size)
267 | # self.proposals_per_square = int(proposals_per_square)
268 | # self.proposals_per_image = int(proposals_per_image)
269 | # self.spatial_scale = int(spatial_scale)
270 | #
271 | # def forward(self, features):
272 | # return DPPFunction(self.square_size, self.proposals_per_square, self.proposals_per_image, self.spatial_scale)(features)
273 |
274 | class Classification_stream(Module):
275 | def __init__(self, proposal_num, num_classes):
276 | super(Classification_stream, self).__init__()
277 | self.proposals_num = proposal_num
278 | self.num_classes = num_classes
279 | self.classifier = nn.Sequential(
280 | nn.Linear(512 * 3 * 3, 4096),
281 | nn.ReLU(True),
282 | nn.Dropout(),
283 | nn.Linear(4096, self.num_classes+1),
284 | )
285 | self.softmax = nn.Softmax(2)
286 | def forward(self, x):
287 | x = self.classifier(x)
288 | x = x.view(-1, self.proposals_num, self.num_classes+1)
289 | x = self.softmax(x)
290 | x = x.narrow(2, 0, self.num_classes)
291 | return x
292 |
293 | class Detection_stream(Module):
294 | def __init__(self, proposal_num, part_num):
295 | super(Detection_stream, self).__init__()
296 | self.proposals_num = proposal_num
297 | self.part_num = part_num
298 | self.detector = nn.Sequential(
299 | nn.Linear(512*3*3, 4096),
300 | nn.ReLU(True),
301 | nn.Dropout(),
302 | nn.Linear(4096, self.part_num+1),
303 | )
304 | self.softmax_cls = nn.Softmax(2)
305 | self.softmax_nor = nn.Softmax(1)
306 | def forward(self, x):
307 | x = self.detector(x)
308 | x = x.view(-1, self.proposals_num, self.part_num+1)
309 | x = self.softmax_cls(x)
310 | x = x.narrow(2, 0, self.part_num)
311 | x = self.softmax_nor(x)
312 | return x
313 |
314 | class construct_partnet(nn.Module):
315 | def __init__(self, conv_model, args):
316 | super(construct_partnet, self).__init__()
317 | self.conv_model = conv_model
318 | self.roi_pool = _RoIPooling(3, 3, 1.0/16)
319 | self.args = args
320 | self.DPP = _DPP(args.square_size, args.proposals_per_square, args.proposals_num, args.stride)
321 | self.classification_stream = Classification_stream(args.proposals_num, args.num_classes)
322 | self.detection_stream = Detection_stream(args.proposals_num, args.num_part)
323 |
324 | def forward(self, x):
325 | x = self.conv_model(x)
326 | rois = self.DPP(x)
327 | x = self.roi_pool(x, rois)
328 | x = x.view(x.size(0), -1)
329 | x_c = self.classification_stream(x)
330 | x_d = self.detection_stream(x)
331 | x_c = x_c.transpose(1, 2)
332 | mix = torch.matmul(x_c, x_d)
333 | x = mix.mean(2)
334 |
335 | return x
336 |
337 | def vgg16_bn(args, **kwargs):
338 | """VGG 16-layer model (configuration "D") with batch normalization
339 |
340 | Args:
341 |         args: if args.pretrain, loads ImageNet-pretrained weights; if args.finetuned_model is set, also loads that fine-tuned checkpoint
342 | """
343 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
344 | if args.pretrain:
345 |         print('load the ImageNet pretrained model')
346 | pretrained_dict = model_zoo.load_url(model_urls['vgg16_bn'])
347 | model_dict = model.state_dict()
348 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict}
349 | model_dict.update(pretrained_dict_temp)
350 | model.load_state_dict(model_dict)
351 | print(args.finetuned_model)
352 | if args.finetuned_model != '':
353 |         print('load the model fine-tuned on the target dataset:', args.finetuned_model)
354 | pretrained_dict = torch.load(args.finetuned_model)['state_dict']
355 | pretrained_dict_temp = copy.deepcopy(pretrained_dict)
356 | model_state_dict = model.state_dict()
357 |
358 | for k_tmp in pretrained_dict_temp.keys():
359 | if k_tmp.find('module.base_conv') != -1:
360 | k = k_tmp.replace('module.base_conv.', '')
361 | pretrained_dict[k] = pretrained_dict.pop(k_tmp)
362 | # ipdb.set_trace()
363 | # print(pretrained_dict)
364 | pretrained_dict_temp2 = {k: v for k, v in pretrained_dict.items() if k in model_state_dict}
365 | # print(pretrained_dict_temp2)
366 | model_state_dict.update(pretrained_dict_temp2)
367 | model.load_state_dict(model_state_dict)
368 |
369 |         print('loaded the fine-tuned conv layers into the model')
370 |
371 |     ### turn the base model into PartNet by adding DPP + RoI pooling + the two-stream heads
372 | PartNet = construct_partnet(model, args)
373 |
374 | return PartNet
375 |
376 |
377 | def vgg19(pretrained=False, **kwargs):
378 | """VGG 19-layer model (configuration "E")
379 |
380 | Args:
381 | pretrained (bool): If True, returns a model pre-trained on ImageNet
382 | """
383 | model = VGG(make_layers(cfg['E']), **kwargs)
384 | if pretrained:
385 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
386 | return model
387 |
388 |
389 |
390 |
391 | def vgg19_bn(args, **kwargs):
392 | """VGG 19-layer model (configuration 'E') with batch normalization
393 |
394 | Args:
395 |         args: if args.pretrain, loads ImageNet-pretrained weights, then loads the image-level classifier fine-tuned on args.dataset
396 | """
397 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
398 | finetuned_model = args.dataset + '/Image_Classifier/model_best.pth.tar'
399 | if args.pretrain:
400 |         print('load the ImageNet pretrained model')
401 | pretrained_dict = model_zoo.load_url(model_urls['vgg19_bn'])
402 | model_dict = model.state_dict()
403 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict}
404 | model_dict.update(pretrained_dict_temp)
405 | model.load_state_dict(model_dict)
406 | if finetuned_model != '':
407 |         print('load the model fine-tuned on the target dataset:', finetuned_model)
408 | pretrained_dict = torch.load(finetuned_model)['state_dict']
409 | pretrained_dict_temp = copy.deepcopy(pretrained_dict)
410 | model_state_dict = model.state_dict()
411 |
412 | for k_tmp in pretrained_dict_temp.keys():
413 | if k_tmp.find('module.base_conv') != -1:
414 | k = k_tmp.replace('module.base_conv.', '')
415 | pretrained_dict[k] = pretrained_dict.pop(k_tmp)
416 | pretrained_dict_temp2 = {k: v for k, v in pretrained_dict.items() if k in model_state_dict}
417 | model_state_dict.update(pretrained_dict_temp2)
418 | model.load_state_dict(model_state_dict)
419 | PartNet = construct_partnet(model, args)
420 | return PartNet
421 |
422 |
423 | def partnet_vgg(args, **kwargs):
424 | print("==> creating model '{}' ".format(args.arch))
425 | if args.arch == 'vgg19_bn':
426 | return vgg19_bn(args)
427 | elif args.arch == 'vgg16_bn':
428 | return vgg16_bn(args)
429 | else:
430 | raise ValueError('Unrecognized model architecture', args.arch)
431 |
--------------------------------------------------------------------------------
/models/roi_pooling/Readme.md:
--------------------------------------------------------------------------------
1 | # ROI Pooling
2 | We use the released code from: https://github.com/jwyang/faster-rcnn.pytorch/tree/master/lib/model/roi_pooling
3 |
4 |
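5 | A minimal usage sketch, assuming the extension has been built with CUDA support; the pooled size (3, 3) and spatial scale 1/16 match how `construct_partnet` instantiates the module:
6 | ```python
7 | import torch
8 | from models.roi_pooling.modules.roi_pool import _RoIPooling  # assumes the repository root is on sys.path
9 |
10 | pool = _RoIPooling(3, 3, 1.0 / 16)           # pooled_height, pooled_width, spatial_scale
11 | feats = torch.randn(1, 512, 28, 28).cuda()   # stride-16 conv feature map
12 | # one RoI per row: [batch_idx, x1, y1, x2, y2] in input-image coordinates
13 | rois = torch.Tensor([[0, 0, 0, 64, 64]]).cuda()
14 | out = pool(feats, rois)                      # -> (1, 512, 3, 3)
15 | ```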
--------------------------------------------------------------------------------
/models/roi_pooling/_ext/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/_ext/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/roi_pooling/_ext/roi_pooling/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.ffi import _wrap_function
3 | from ._roi_pooling import lib as _lib, ffi as _ffi
4 |
5 | __all__ = []
6 | def _import_symbols(locals):
7 | for symbol in dir(_lib):
8 | fn = getattr(_lib, symbol)
9 | if callable(fn):
10 | locals[symbol] = _wrap_function(fn, _ffi)
11 | else:
12 | locals[symbol] = fn
13 | __all__.append(symbol)
14 |
15 | _import_symbols(locals())
16 |
--------------------------------------------------------------------------------
/models/roi_pooling/_ext/roi_pooling/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/_ext/roi_pooling/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/roi_pooling/_ext/roi_pooling/_roi_pooling.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/_ext/roi_pooling/_roi_pooling.so
--------------------------------------------------------------------------------
/models/roi_pooling/build.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import torch
4 | from torch.utils.ffi import create_extension
5 |
6 |
7 | sources = ['src/roi_pooling.c']
8 | headers = ['src/roi_pooling.h']
9 | defines = []
10 | with_cuda = False
11 |
12 | if torch.cuda.is_available():
13 | print('Including CUDA code.')
14 | sources += ['src/roi_pooling_cuda.c']
15 | headers += ['src/roi_pooling_cuda.h']
16 | defines += [('WITH_CUDA', None)]
17 | with_cuda = True
18 |
19 | this_file = os.path.dirname(os.path.realpath(__file__))
20 | print(this_file)
21 | extra_objects = ['src/roi_pooling.cu.o']
22 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
23 |
24 | ffi = create_extension(
25 | '_ext.roi_pooling',
26 | headers=headers,
27 | sources=sources,
28 | define_macros=defines,
29 | relative_to=__file__,
30 | with_cuda=with_cuda,
31 | extra_objects=extra_objects
32 | )
33 |
34 | if __name__ == '__main__':
35 | ffi.build()
36 |
--------------------------------------------------------------------------------
/models/roi_pooling/functions/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/functions/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/roi_pooling/functions/__pycache__/roi_pool.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/functions/__pycache__/roi_pool.cpython-36.pyc
--------------------------------------------------------------------------------
/models/roi_pooling/functions/roi_pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from .._ext import roi_pooling
4 | # import time
5 |
6 | class RoIPoolFunction(Function):
7 | def __init__(ctx, pooled_height, pooled_width, spatial_scale):
8 | ctx.pooled_width = pooled_width
9 | ctx.pooled_height = pooled_height
10 | ctx.spatial_scale = spatial_scale
11 | ctx.feature_size = None
12 |
13 | def forward(ctx, features, rois):
14 | # timer = time.time()
15 | ctx.feature_size = features.size()
16 | batch_size, num_channels, data_height, data_width = ctx.feature_size
17 | num_rois = rois.size(0)
18 | output = features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_()
19 | ctx.argmax = features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_().int()
20 | ctx.rois = rois
21 | if not features.is_cuda:
22 | _features = features.permute(0, 2, 3, 1)
23 | roi_pooling.roi_pooling_forward(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
24 | _features, rois, output)
25 | else:
26 | roi_pooling.roi_pooling_forward_cuda(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
27 | features, rois, output, ctx.argmax)
28 | # print('the forward time of roipooling is:', time.time() - timer)
29 | return output
30 |
31 | def backward(ctx, grad_output):
32 | assert(ctx.feature_size is not None and grad_output.is_cuda)
33 | batch_size, num_channels, data_height, data_width = ctx.feature_size
34 | grad_input = grad_output.new(batch_size, num_channels, data_height, data_width).zero_()
35 |
36 | roi_pooling.roi_pooling_backward_cuda(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
37 | grad_output, ctx.rois, grad_input, ctx.argmax)
38 |
39 | return grad_input, None
40 |
--------------------------------------------------------------------------------
/models/roi_pooling/modules/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/modules/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/roi_pooling/modules/__pycache__/roi_pool.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/modules/__pycache__/roi_pool.cpython-36.pyc
--------------------------------------------------------------------------------
/models/roi_pooling/modules/__pycache__/roi_pool_py.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/modules/__pycache__/roi_pool_py.cpython-36.pyc
--------------------------------------------------------------------------------
/models/roi_pooling/modules/roi_pool.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from ..functions.roi_pool import RoIPoolFunction
3 |
4 |
5 | class _RoIPooling(Module):
6 | def __init__(self, pooled_height, pooled_width, spatial_scale):
7 | super(_RoIPooling, self).__init__()
8 |
9 | self.pooled_width = int(pooled_width)
10 | self.pooled_height = int(pooled_height)
11 | self.spatial_scale = float(spatial_scale)
12 |
13 | def forward(self, features, rois):
14 | return RoIPoolFunction(self.pooled_height, self.pooled_width, self.spatial_scale)(features, rois)
15 |
--------------------------------------------------------------------------------
/models/roi_pooling/src/roi_pooling.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 | #include <math.h>
3 |
4 | int roi_pooling_forward(int pooled_height, int pooled_width, float spatial_scale,
5 | THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output)
6 | {
7 | // Grab the input tensor
8 | float * data_flat = THFloatTensor_data(features);
9 | float * rois_flat = THFloatTensor_data(rois);
10 |
11 | float * output_flat = THFloatTensor_data(output);
12 |
13 | // Number of ROIs
14 | int num_rois = THFloatTensor_size(rois, 0);
15 | int size_rois = THFloatTensor_size(rois, 1);
16 | // batch size
17 | int batch_size = THFloatTensor_size(features, 0);
18 | if(batch_size != 1)
19 | {
20 | return 0;
21 | }
22 | // data height
23 | int data_height = THFloatTensor_size(features, 1);
24 | // data width
25 | int data_width = THFloatTensor_size(features, 2);
26 | // Number of channels
27 | int num_channels = THFloatTensor_size(features, 3);
28 |
29 |     // Initialize all elements of the output tensor to -1 (any pooled value overwrites it).
30 | THFloatStorage_fill(THFloatTensor_storage(output), -1);
31 |
32 | // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R
33 | int index_roi = 0;
34 | int index_output = 0;
35 | int n;
36 | for (n = 0; n < num_rois; ++n)
37 | {
38 | int roi_batch_ind = rois_flat[index_roi + 0];
39 | int roi_start_w = round(rois_flat[index_roi + 1] * spatial_scale);
40 | int roi_start_h = round(rois_flat[index_roi + 2] * spatial_scale);
41 | int roi_end_w = round(rois_flat[index_roi + 3] * spatial_scale);
42 | int roi_end_h = round(rois_flat[index_roi + 4] * spatial_scale);
43 | // CHECK_GE(roi_batch_ind, 0);
44 | // CHECK_LT(roi_batch_ind, batch_size);
45 |
46 | int roi_height = fmaxf(roi_end_h - roi_start_h + 1, 1);
47 | int roi_width = fmaxf(roi_end_w - roi_start_w + 1, 1);
48 | float bin_size_h = (float)(roi_height) / (float)(pooled_height);
49 | float bin_size_w = (float)(roi_width) / (float)(pooled_width);
50 |
51 | int index_data = roi_batch_ind * data_height * data_width * num_channels;
52 | const int output_area = pooled_width * pooled_height;
53 |
54 | int c, ph, pw;
55 | for (ph = 0; ph < pooled_height; ++ph)
56 | {
57 | for (pw = 0; pw < pooled_width; ++pw)
58 | {
59 | int hstart = (floor((float)(ph) * bin_size_h));
60 | int wstart = (floor((float)(pw) * bin_size_w));
61 | int hend = (ceil((float)(ph + 1) * bin_size_h));
62 | int wend = (ceil((float)(pw + 1) * bin_size_w));
63 |
64 | hstart = fminf(fmaxf(hstart + roi_start_h, 0), data_height);
65 | hend = fminf(fmaxf(hend + roi_start_h, 0), data_height);
66 | wstart = fminf(fmaxf(wstart + roi_start_w, 0), data_width);
67 | wend = fminf(fmaxf(wend + roi_start_w, 0), data_width);
68 |
69 | const int pool_index = index_output + (ph * pooled_width + pw);
70 | int is_empty = (hend <= hstart) || (wend <= wstart);
71 | if (is_empty)
72 | {
73 | for (c = 0; c < num_channels * output_area; c += output_area)
74 | {
75 | output_flat[pool_index + c] = 0;
76 | }
77 | }
78 | else
79 | {
80 | int h, w, c;
81 | for (h = hstart; h < hend; ++h)
82 | {
83 | for (w = wstart; w < wend; ++w)
84 | {
85 | for (c = 0; c < num_channels; ++c)
86 | {
87 | const int index = (h * data_width + w) * num_channels + c;
88 | if (data_flat[index_data + index] > output_flat[pool_index + c * output_area])
89 | {
90 | output_flat[pool_index + c * output_area] = data_flat[index_data + index];
91 | }
92 | }
93 | }
94 | }
95 | }
96 | }
97 | }
98 |
99 | // Increment ROI index
100 | index_roi += size_rois;
101 | index_output += pooled_height * pooled_width * num_channels;
102 | }
103 | return 1;
104 | }
--------------------------------------------------------------------------------
/models/roi_pooling/src/roi_pooling.cu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gorilla-Lab-SCUT/PartNet/c5c5488615ea3aadc8b5ecc90d46ef26dba2db3b/models/roi_pooling/src/roi_pooling.cu.o
--------------------------------------------------------------------------------
/models/roi_pooling/src/roi_pooling.h:
--------------------------------------------------------------------------------
1 | int roi_pooling_forward(int pooled_height, int pooled_width, float spatial_scale,
2 | THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output);
--------------------------------------------------------------------------------
/models/roi_pooling/src/roi_pooling_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 | #include <math.h>
3 | #include "roi_pooling_kernel.h"
4 |
5 | extern THCState *state;
6 |
7 | int roi_pooling_forward_cuda(int pooled_height, int pooled_width, float spatial_scale,
8 | THCudaTensor * features, THCudaTensor * rois, THCudaTensor * output, THCudaIntTensor * argmax)
9 | {
10 | // Grab the input tensor
11 | float * data_flat = THCudaTensor_data(state, features);
12 | float * rois_flat = THCudaTensor_data(state, rois);
13 |
14 | float * output_flat = THCudaTensor_data(state, output);
15 | int * argmax_flat = THCudaIntTensor_data(state, argmax);
16 |
17 | // Number of ROIs
18 | int num_rois = THCudaTensor_size(state, rois, 0);
19 | int size_rois = THCudaTensor_size(state, rois, 1);
20 | if (size_rois != 5)
21 | {
22 | return 0;
23 | }
24 |
25 | // batch size
26 | // int batch_size = THCudaTensor_size(state, features, 0);
27 | // if (batch_size != 1)
28 | // {
29 | // return 0;
30 | // }
31 | // data height
32 | int data_height = THCudaTensor_size(state, features, 2);
33 | // data width
34 | int data_width = THCudaTensor_size(state, features, 3);
35 | // Number of channels
36 | int num_channels = THCudaTensor_size(state, features, 1);
37 |
38 | cudaStream_t stream = THCState_getCurrentStream(state);
39 |
40 | ROIPoolForwardLaucher(
41 | data_flat, spatial_scale, num_rois, data_height,
42 | data_width, num_channels, pooled_height,
43 | pooled_width, rois_flat,
44 | output_flat, argmax_flat, stream);
45 |
46 | return 1;
47 | }
48 |
49 | int roi_pooling_backward_cuda(int pooled_height, int pooled_width, float spatial_scale,
50 | THCudaTensor * top_grad, THCudaTensor * rois, THCudaTensor * bottom_grad, THCudaIntTensor * argmax)
51 | {
52 | // Grab the input tensor
53 | float * top_grad_flat = THCudaTensor_data(state, top_grad);
54 | float * rois_flat = THCudaTensor_data(state, rois);
55 |
56 | float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad);
57 | int * argmax_flat = THCudaIntTensor_data(state, argmax);
58 |
59 | // Number of ROIs
60 | int num_rois = THCudaTensor_size(state, rois, 0);
61 | int size_rois = THCudaTensor_size(state, rois, 1);
62 | if (size_rois != 5)
63 | {
64 | return 0;
65 | }
66 |
67 | // batch size
68 | int batch_size = THCudaTensor_size(state, bottom_grad, 0);
69 | // if (batch_size != 1)
70 | // {
71 | // return 0;
72 | // }
73 | // data height
74 | int data_height = THCudaTensor_size(state, bottom_grad, 2);
75 | // data width
76 | int data_width = THCudaTensor_size(state, bottom_grad, 3);
77 | // Number of channels
78 | int num_channels = THCudaTensor_size(state, bottom_grad, 1);
79 |
80 | cudaStream_t stream = THCState_getCurrentStream(state);
81 | ROIPoolBackwardLaucher(
82 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height,
83 | data_width, num_channels, pooled_height,
84 | pooled_width, rois_flat,
85 | bottom_grad_flat, argmax_flat, stream);
86 |
87 | return 1;
88 | }
89 |
--------------------------------------------------------------------------------
/models/roi_pooling/src/roi_pooling_cuda.h:
--------------------------------------------------------------------------------
1 | int roi_pooling_forward_cuda(int pooled_height, int pooled_width, float spatial_scale,
2 | THCudaTensor * features, THCudaTensor * rois, THCudaTensor * output, THCudaIntTensor * argmax);
3 |
4 | int roi_pooling_backward_cuda(int pooled_height, int pooled_width, float spatial_scale,
5 | THCudaTensor * top_grad, THCudaTensor * rois, THCudaTensor * bottom_grad, THCudaIntTensor * argmax);
--------------------------------------------------------------------------------
/models/roi_pooling/src/roi_pooling_kernel.cu:
--------------------------------------------------------------------------------
1 | // #ifdef __cplusplus
2 | // extern "C" {
3 | // #endif
4 |
5 | #include <stdio.h>
6 | #include <math.h>
7 | #include <float.h>
8 | #include <cuda.h>
9 | #include "roi_pooling_kernel.h"
10 |
11 |
12 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
13 |
14 | #define CUDA_1D_KERNEL_LOOP(i, n) \
15 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
16 | i += blockDim.x * gridDim.x)
17 |
18 | // CUDA: grid stride looping
19 | #define CUDA_KERNEL_LOOP(i, n) \
20 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
21 | i < (n); \
22 | i += blockDim.x * gridDim.x)
23 |
24 | __global__ void ROIPoolForward(const int nthreads, const float* bottom_data,
25 | const float spatial_scale, const int height, const int width,
26 | const int channels, const int pooled_height, const int pooled_width,
27 | const float* bottom_rois, float* top_data, int* argmax_data)
28 | {
29 | CUDA_KERNEL_LOOP(index, nthreads)
30 | {
31 | // (n, c, ph, pw) is an element in the pooled output
32 | // int n = index;
33 | // int pw = n % pooled_width;
34 | // n /= pooled_width;
35 | // int ph = n % pooled_height;
36 | // n /= pooled_height;
37 | // int c = n % channels;
38 | // n /= channels;
39 | int pw = index % pooled_width;
40 | int ph = (index / pooled_width) % pooled_height;
41 | int c = (index / pooled_width / pooled_height) % channels;
42 | int n = index / pooled_width / pooled_height / channels;
43 |
44 | // bottom_rois += n * 5;
45 | int roi_batch_ind = bottom_rois[n * 5 + 0];
46 | int roi_start_w = round(bottom_rois[n * 5 + 1] * spatial_scale);
47 | int roi_start_h = round(bottom_rois[n * 5 + 2] * spatial_scale);
48 | int roi_end_w = round(bottom_rois[n * 5 + 3] * spatial_scale);
49 | int roi_end_h = round(bottom_rois[n * 5 + 4] * spatial_scale);
50 |
51 | // Force malformed ROIs to be 1x1
52 | int roi_width = fmaxf(roi_end_w - roi_start_w + 1, 1);
53 | int roi_height = fmaxf(roi_end_h - roi_start_h + 1, 1);
54 | float bin_size_h = (float)(roi_height) / (float)(pooled_height);
55 | float bin_size_w = (float)(roi_width) / (float)(pooled_width);
56 |
57 | int hstart = (int)(floor((float)(ph) * bin_size_h));
58 | int wstart = (int)(floor((float)(pw) * bin_size_w));
59 | int hend = (int)(ceil((float)(ph + 1) * bin_size_h));
60 | int wend = (int)(ceil((float)(pw + 1) * bin_size_w));
61 |
62 | // Add roi offsets and clip to input boundaries
63 | hstart = fminf(fmaxf(hstart + roi_start_h, 0), height);
64 | hend = fminf(fmaxf(hend + roi_start_h, 0), height);
65 | wstart = fminf(fmaxf(wstart + roi_start_w, 0), width);
66 | wend = fminf(fmaxf(wend + roi_start_w, 0), width);
67 | bool is_empty = (hend <= hstart) || (wend <= wstart);
68 |
69 | // Define an empty pooling region to be zero
70 | float maxval = is_empty ? 0 : -FLT_MAX;
71 | // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
72 | int maxidx = -1;
73 | // bottom_data += roi_batch_ind * channels * height * width;
74 |
75 | int bottom_data_batch_offset = roi_batch_ind * channels * height * width;
76 | int bottom_data_offset = bottom_data_batch_offset + c * height * width;
77 |
78 | for (int h = hstart; h < hend; ++h) {
79 | for (int w = wstart; w < wend; ++w) {
80 | // int bottom_index = (h * width + w) * channels + c;
81 | // int bottom_index = (c * height + h) * width + w;
82 | int bottom_index = h * width + w;
83 | if (bottom_data[bottom_data_offset + bottom_index] > maxval) {
84 | maxval = bottom_data[bottom_data_offset + bottom_index];
85 | maxidx = bottom_data_offset + bottom_index;
86 | }
87 | }
88 | }
89 | top_data[index] = maxval;
90 | if (argmax_data != NULL)
91 | argmax_data[index] = maxidx;
92 | }
93 | }
94 |
95 | int ROIPoolForwardLaucher(
96 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height,
97 | const int width, const int channels, const int pooled_height,
98 | const int pooled_width, const float* bottom_rois,
99 | float* top_data, int* argmax_data, cudaStream_t stream)
100 | {
101 | const int kThreadsPerBlock = 1024;
102 | int output_size = num_rois * pooled_height * pooled_width * channels;
103 | cudaError_t err;
104 |
105 | ROIPoolForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(
106 | output_size, bottom_data, spatial_scale, height, width, channels, pooled_height,
107 | pooled_width, bottom_rois, top_data, argmax_data);
108 |
109 | // dim3 blocks(DIVUP(output_size, kThreadsPerBlock),
110 | // DIVUP(output_size, kThreadsPerBlock));
111 | // dim3 threads(kThreadsPerBlock);
112 | //
113 | // ROIPoolForward<<>>(
114 | // output_size, bottom_data, spatial_scale, height, width, channels, pooled_height,
115 | // pooled_width, bottom_rois, top_data, argmax_data);
116 |
117 | err = cudaGetLastError();
118 | if(cudaSuccess != err)
119 | {
120 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
121 | exit( -1 );
122 | }
123 |
124 | return 1;
125 | }
126 |
127 |
128 | __global__ void ROIPoolBackward(const int nthreads, const float* top_diff,
129 | const int* argmax_data, const int num_rois, const float spatial_scale,
130 | const int height, const int width, const int channels,
131 | const int pooled_height, const int pooled_width, float* bottom_diff,
132 | const float* bottom_rois) {
133 | CUDA_1D_KERNEL_LOOP(index, nthreads)
134 | {
135 |
136 | // (n, c, ph, pw) is an element in the pooled output
137 | int n = index;
138 | int w = n % width;
139 | n /= width;
140 | int h = n % height;
141 | n /= height;
142 | int c = n % channels;
143 | n /= channels;
144 |
145 | float gradient = 0;
146 | // Accumulate gradient over all ROIs that pooled this element
147 | for (int roi_n = 0; roi_n < num_rois; ++roi_n)
148 | {
149 | const float* offset_bottom_rois = bottom_rois + roi_n * 5;
150 | int roi_batch_ind = offset_bottom_rois[0];
151 | // Skip if ROI's batch index doesn't match n
152 | if (n != roi_batch_ind) {
153 | continue;
154 | }
155 |
156 | int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
157 | int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
158 | int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
159 | int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
160 |
161 | // Skip if ROI doesn't include (h, w)
162 | const bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
163 | h >= roi_start_h && h <= roi_end_h);
164 | if (!in_roi) {
165 | continue;
166 | }
167 |
168 | int offset = roi_n * pooled_height * pooled_width * channels;
169 | const float* offset_top_diff = top_diff + offset;
170 | const int* offset_argmax_data = argmax_data + offset;
171 |
172 | // Compute feasible set of pooled units that could have pooled
173 | // this bottom unit
174 |
175 | // Force malformed ROIs to be 1x1
176 | int roi_width = fmaxf(roi_end_w - roi_start_w + 1, 1);
177 | int roi_height = fmaxf(roi_end_h - roi_start_h + 1, 1);
178 |
179 | float bin_size_h = (float)(roi_height) / (float)(pooled_height);
180 | float bin_size_w = (float)(roi_width) / (float)(pooled_width);
181 |
182 | int phstart = floor((float)(h - roi_start_h) / bin_size_h);
183 | int phend = ceil((float)(h - roi_start_h + 1) / bin_size_h);
184 | int pwstart = floor((float)(w - roi_start_w) / bin_size_w);
185 | int pwend = ceil((float)(w - roi_start_w + 1) / bin_size_w);
186 |
187 | phstart = fminf(fmaxf(phstart, 0), pooled_height);
188 | phend = fminf(fmaxf(phend, 0), pooled_height);
189 | pwstart = fminf(fmaxf(pwstart, 0), pooled_width);
190 | pwend = fminf(fmaxf(pwend, 0), pooled_width);
191 |
192 | for (int ph = phstart; ph < phend; ++ph) {
193 | for (int pw = pwstart; pw < pwend; ++pw) {
194 | if (offset_argmax_data[(c * pooled_height + ph) * pooled_width + pw] == index)
195 | {
196 | gradient += offset_top_diff[(c * pooled_height + ph) * pooled_width + pw];
197 | }
198 | }
199 | }
200 | }
201 | bottom_diff[index] = gradient;
202 | }
203 | }
204 |
205 | int ROIPoolBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois,
206 | const int height, const int width, const int channels, const int pooled_height,
207 | const int pooled_width, const float* bottom_rois,
208 | float* bottom_diff, const int* argmax_data, cudaStream_t stream)
209 | {
210 | const int kThreadsPerBlock = 1024;
211 | int output_size = batch_size * height * width * channels;
212 | cudaError_t err;
213 |
214 | ROIPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(
215 | output_size, top_diff, argmax_data, num_rois, spatial_scale, height, width, channels, pooled_height,
216 | pooled_width, bottom_diff, bottom_rois);
217 |
218 | // dim3 blocks(DIVUP(output_size, kThreadsPerBlock),
219 | // DIVUP(output_size, kThreadsPerBlock));
220 | // dim3 threads(kThreadsPerBlock);
221 | //
222 | // ROIPoolBackward<<>>(
223 | // output_size, top_diff, argmax_data, num_rois, spatial_scale, height, width, channels, pooled_height,
224 | // pooled_width, bottom_diff, bottom_rois);
225 |
226 | err = cudaGetLastError();
227 | if(cudaSuccess != err)
228 | {
229 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
230 | exit( -1 );
231 | }
232 |
233 | return 1;
234 | }
235 |
236 |
237 | // #ifdef __cplusplus
238 | // }
239 | // #endif
240 |
--------------------------------------------------------------------------------
/models/roi_pooling/src/roi_pooling_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _ROI_POOLING_KERNEL
2 | #define _ROI_POOLING_KERNEL
3 |
4 | #ifdef __cplusplus
5 | extern "C" {
6 | #endif
7 |
8 | int ROIPoolForwardLaucher(
9 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height,
10 | const int width, const int channels, const int pooled_height,
11 | const int pooled_width, const float* bottom_rois,
12 | float* top_data, int* argmax_data, cudaStream_t stream);
13 |
14 |
15 | int ROIPoolBackwardLaucher(const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois,
16 | const int height, const int width, const int channels, const int pooled_height,
17 | const int pooled_width, const float* bottom_rois,
18 | float* bottom_diff, const int* argmax_data, cudaStream_t stream);
19 |
20 | #ifdef __cplusplus
21 | }
22 | #endif
23 |
24 | #endif
25 |
26 |
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | ####################################################
2 | ## Prepare the vgg model for the image level classifier and the part level classifier.
3 | ###################################################
4 | import torch.nn as nn
5 | import torch.utils.model_zoo as model_zoo
6 | import math
7 | import copy
8 | import torch
9 |
10 | __all__ = [
11 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
12 | 'vgg19_bn', 'vgg19',
13 | ]
14 |
15 |
16 | model_urls = {
17 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
18 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
19 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
20 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
21 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
22 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
23 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
24 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
25 | }
26 |
27 |
28 | class VGG(nn.Module):
29 |
30 | def __init__(self, features, num_classes=1000):
31 | super(VGG, self).__init__()
32 | self.features = features
33 | # self.classifier = nn.Sequential(
34 | # nn.Linear(512 * 7 * 7, 4096),
35 | # nn.ReLU(True),
36 | # nn.Dropout(),
37 | # nn.Linear(4096, 4096),
38 | # nn.ReLU(True),
39 | # nn.Dropout(),
40 | # nn.Linear(4096, num_classes),
41 | # )
42 | self.avgpool = nn.AvgPool2d(28)
43 | self._initialize_weights()
44 |
45 | def forward(self, x):
46 | x = self.features(x)
47 | x = self.avgpool(x)
48 | x = x.view(x.size(0), -1)
49 | # x = self.classifier(x)
50 | return x
51 |
52 | def _initialize_weights(self):
53 | for m in self.modules():
54 | if isinstance(m, nn.Conv2d):
55 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
56 | m.weight.data.normal_(0, math.sqrt(2. / n))
57 | if m.bias is not None:
58 | m.bias.data.zero_()
59 | elif isinstance(m, nn.BatchNorm2d):
60 | m.weight.data.fill_(1)
61 | m.bias.data.zero_()
62 | elif isinstance(m, nn.Linear):
63 | m.weight.data.normal_(0, 0.01)
64 | m.bias.data.zero_()
65 |
66 |
67 | def make_layers(cfg, batch_norm=False):
68 | layers = []
69 | in_channels = 3
70 | print('the vgg network structure has been changed')
71 | for v in cfg:
72 | if v == 'M':
73 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
74 | else:
75 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
76 | if batch_norm:
77 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
78 | else:
79 | layers += [conv2d, nn.ReLU(inplace=True)]
80 | in_channels = v
81 | return nn.Sequential(*layers)
82 |
83 |
84 | cfg = {
85 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
86 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
87 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
88 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
89 | }
90 |
91 |
92 | def vgg11(pretrained=False, **kwargs):
93 | """VGG 11-layer model (configuration "A")
94 |
95 | Args:
96 | pretrained (bool): If True, returns a model pre-trained on ImageNet
97 | """
98 | model = VGG(make_layers(cfg['A']), **kwargs)
99 | if pretrained:
100 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
101 | return model
102 |
103 |
104 | def vgg11_bn(pretrained=False, **kwargs):
105 | """VGG 11-layer model (configuration "A") with batch normalization
106 |
107 | Args:
108 | pretrained (bool): If True, returns a model pre-trained on ImageNet
109 | """
110 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
111 | if pretrained:
112 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
113 | return model
114 |
115 |
116 | def vgg13(pretrained=False, **kwargs):
117 | """VGG 13-layer model (configuration "B")
118 |
119 | Args:
120 | pretrained (bool): If True, returns a model pre-trained on ImageNet
121 | """
122 | model = VGG(make_layers(cfg['B']), **kwargs)
123 | if pretrained:
124 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
125 | return model
126 |
127 |
128 | def vgg13_bn(pretrained=False, **kwargs):
129 | """VGG 13-layer model (configuration "B") with batch normalization
130 |
131 | Args:
132 | pretrained (bool): If True, returns a model pre-trained on ImageNet
133 | """
134 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
135 | if pretrained:
136 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
137 | return model
138 |
139 |
140 | def vgg16(pretrained=False, **kwargs):
141 | """VGG 16-layer model (configuration "D")
142 |
143 | Args:
144 | pretrained (bool): If True, returns a model pre-trained on ImageNet
145 | """
146 | model = VGG(make_layers(cfg['D']), **kwargs)
147 | if pretrained:
148 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
149 | return model
150 |
151 |
152 | def vgg16_bn(args, **kwargs):
153 | """VGG 16-layer model (configuration "D") with batch normalization
154 |
155 | Args:
156 |         args: if args.pretrain, loads ImageNet-pretrained weights; if args.pretrained_model is set, also loads that checkpoint
157 | """
158 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
159 |
160 | if args.pretrain:
161 |         print('load the ImageNet pretrained model')
162 | pretrained_dict = model_zoo.load_url(model_urls['vgg16_bn'])
163 | model_dict = model.state_dict()
164 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict}
165 | model_dict.update(pretrained_dict_temp)
166 | model.load_state_dict(model_dict)
167 | if args.pretrained_model != '':
168 | print('load the pretrained model from:', args.pretrained_model)
169 | pretrained_dict = torch.load(args.pretrained_model)['state_dict']
170 | pretrained_dict_temp = copy.deepcopy(pretrained_dict)
171 | model_state_dict = model.state_dict()
172 |
173 | for k_tmp in pretrained_dict_temp.keys():
174 | if k_tmp.find('module.') != -1:
175 | k = k_tmp.replace('module.', '')
176 | pretrained_dict[k] = pretrained_dict.pop(k_tmp)
177 | pretrained_dict_temp2 = {k: v for k, v in pretrained_dict.items() if k in model_state_dict}
178 | print(pretrained_dict_temp2)
179 | model_state_dict.update(pretrained_dict_temp2)
180 | model.load_state_dict(model_state_dict)
181 |
182 | source_model = Share_convs(model, 512, args.num_classes_t)
183 | return source_model
184 |
185 |
186 | def vgg19(pretrained=False, **kwargs):
187 | """VGG 19-layer model (configuration "E")
188 |
189 | Args:
190 | pretrained (bool): If True, returns a model pre-trained on ImageNet
191 | """
192 | model = VGG(make_layers(cfg['E']), **kwargs)
193 | if pretrained:
194 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
195 | return model
196 |
197 |
198 | class Share_convs(nn.Module):
199 | def __init__(self, base_conv, convout_dimension, num_class):
200 | super(Share_convs, self).__init__()
201 | self.base_conv = base_conv
202 | self.fc = nn.Linear(convout_dimension, num_class)
203 | def forward(self, x):
204 | x = self.base_conv(x)
205 | x = self.fc(x)
206 | return x
207 |
208 | def vgg19_bn(args, **kwargs):
209 | """VGG 19-layer model (configuration 'E') with batch normalization
210 |
211 | Args:
212 |         args: if args.pretrain, loads ImageNet-pretrained weights; if args.pretrained_model is set, also loads that checkpoint
213 | """
214 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
215 |
216 | if args.pretrain:
217 |         print('load the ImageNet pretrained model')
218 | pretrained_dict = model_zoo.load_url(model_urls['vgg19_bn'])
219 | model_dict = model.state_dict()
220 | pretrained_dict_temp = {k: v for k, v in pretrained_dict.items() if k in model_dict}
221 | model_dict.update(pretrained_dict_temp)
222 | model.load_state_dict(model_dict)
223 | if args.pretrained_model != '':
224 | print('load the pretrained model from:', args.pretrained_model)
225 | pretrained_dict = torch.load(args.pretrained_model)['state_dict']
226 | pretrained_dict_temp = copy.deepcopy(pretrained_dict)
227 | model_state_dict = model.state_dict()
228 |
229 | for k_tmp in pretrained_dict_temp.keys():
230 | if k_tmp.find('module.') != -1:
231 | k = k_tmp.replace('module.', '')
232 | pretrained_dict[k] = pretrained_dict.pop(k_tmp)
233 |
234 | pretrained_dict_temp2 = {k: v for k, v in pretrained_dict.items() if k in model_state_dict}
235 |
236 | model_state_dict.update(pretrained_dict_temp2)
237 | model.load_state_dict(model_state_dict)
238 |
239 | source_model = Share_convs(model, 512, args.num_classes)
240 | return source_model
241 |
242 |
243 | def vgg(args, **kwargs):
244 | print("==> creating model '{}' ".format(args.arch))
245 | if args.arch == 'vgg19_bn':
246 | return vgg19_bn(args)
247 | elif args.arch == 'vgg16_bn':
248 | return vgg16_bn(args)
249 | else:
250 | raise ValueError('Unrecognized model architecture', args.arch)
251 |
--------------------------------------------------------------------------------
/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def opts():
5 |     parser = argparse.ArgumentParser(description='Train PartNet on fine-grained datasets',
6 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7 | parser.add_argument('--data_path', type=str, default='/home/lab-zhangyabin/project/fine-grained/CUB_200_2011/',
8 | help='Root of the data set')
9 | parser.add_argument('--dataset', type=str, default='cub200',
10 | help='choose between flowers/cub200')
11 | # Optimization options
12 | parser.add_argument('--epochs', type=int, default=160, help='Number of epochs to train')
13 | parser.add_argument('--schedule', type=int, nargs='+', default=[80, 120],
14 | help='Decrease learning rate at these epochs.')
15 | parser.add_argument('--epochs_part', type=int, default=30, help='Number of epochs to train for the part classifier')
16 | parser.add_argument('--schedule_part', type=int, nargs='+', default=[10, 20],
17 | help='Decrease learning rate at these epochs in the training of part classifier.')
18 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size.')
19 |     parser.add_argument('--batch_size_partnet', type=int, default=64, help='Batch size for PartNet, set to 64 due to the GPU memory constraint')
20 | parser.add_argument('--lr', '--learning_rate', type=float, default=0.1, help='The initial learning rate.')
21 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
22 | parser.add_argument('--weight_decay', type=float, default=0.0001, help='Weight decay (L2 penalty).')
23 |
24 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
25 | # checkpoints
26 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
27 | help='manual epoch number (useful on restarts)')
28 |     parser.add_argument('--resume', type=str, default='', help='Checkpoint path to resume from (default: none)')
29 |     parser.add_argument('--pretrained_model', type=str, default='', help='Path to the ImageNet-pretrained modified model')
30 | parser.add_argument('--pretrain', action='store_true', help='whether using pretrained model')
31 | parser.add_argument('--test_only', action='store_true', help='Test only flag')
32 | # Architecture
33 | parser.add_argument('--arch', type=str, default='', help='Model name')
34 | parser.add_argument('--proposals_num', type=int, default=1372, help='the number of proposals for one image')
35 | parser.add_argument('--square_size', type=int, default=4, help='the side length of each cell in DPP module')
36 |     parser.add_argument('--proposals_per_square', type=int, default=28, help='the number of proposals per square')
37 | parser.add_argument('--stride', type=int, default=16, help='Stride of the used model')
38 |     parser.add_argument('--num_part', type=int, default=3, help='the number of parts to be detected in PartNet')
39 | parser.add_argument('--num_classes', type=int, default=200, help='the number of fine-grained classes')
40 | parser.add_argument('--num_select_proposals', type=int, default=50, help='the number of fine-grained classes')
41 |
42 | parser.add_argument('--svb', action='store_true', help='whether to apply SVB on the classifier')
43 | parser.add_argument('--svb_factor', type=float, default=1.5, help='svb factor in the SVB method')
44 | # i/o
45 | parser.add_argument('--log', type=str, default='./checkpoints', help='Log folder')
46 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
47 | help='number of data loading workers (default: 4)')
48 | parser.add_argument('--print_freq', '-p', default=10, type=int,
49 | metavar='N', help='print frequency (default: 10)')
50 | args = parser.parse_args()
51 | return args
52 |
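# A minimal sketch of how opts() is consumed (the flags mirror run.sh; the
# in-process sys.argv override is for illustration only):
if __name__ == '__main__':
    import sys
    sys.argv = ['opts.py', '--dataset', 'cub200', '--arch', 'vgg19_bn',
                '--data_path', '/data/CUB_200_2011/', '--schedule', '80', '120']
    args = opts()
    print(args.schedule, args.lr, args.num_classes)  # [80, 120] 0.1 200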
--------------------------------------------------------------------------------
/process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import shutil
4 | import torch.optim
5 | import torch.nn as nn
6 | import torch.backends.cudnn as cudnn
7 | from models.model_construct import Model_Construct
8 | from data.prepare_data import generate_dataloader
9 | from trainer import train # For the training process
10 | from trainer import validate # For the validate (test) process
11 | from trainer import download_part_proposals
12 | from trainer import download_scores
13 | from trainer import svb
14 | from trainer import svb_det
15 | import time
16 |
17 | def Process1_Image_Classifier(args):
18 |
19 | log_now = args.dataset + '/Image_Classifier'
20 | process_name = 'image_classifier'
21 | if os.path.isfile(log_now + '/final.txt'):
22 | print('the Process1_Image_Classifier is finished')
23 | return
24 | best_prec1 = 0
25 | model = Model_Construct(args, process_name)
26 | model = torch.nn.DataParallel(model).cuda()
27 | criterion = nn.CrossEntropyLoss().cuda()
28 | optimizer = torch.optim.SGD([
29 | {'params': model.module.base_conv.parameters(), 'name': 'pre-trained'},
30 | {'params': model.module.fc.parameters(), 'lr': args.lr, 'name': 'new-added'}
31 | ],
32 | lr=args.lr,
33 | momentum=args.momentum,
34 | weight_decay=args.weight_decay)
35 | start_epoch = args.start_epoch
36 | if args.resume:
37 | if os.path.isfile(args.resume):
38 | print("==> loading checkpoints '{}'".format(args.resume))
39 | checkpoint = torch.load(args.resume)
40 | start_epoch = checkpoint['epoch']
41 | best_prec1 = checkpoint['best_prec1']
42 | model.load_state_dict(checkpoint['state_dict'])
43 | optimizer.load_state_dict(checkpoint['optimizer'])
44 | print("==> loaded checkpoint '{}'(epoch {})"
45 | .format(args.resume, checkpoint['epoch']))
46 | args.resume = ''
47 | else:
48 | raise ValueError('The checkpoint to resume from does not exist', args.resume)
49 | else:
50 | if not os.path.isdir(log_now):
51 | os.makedirs(log_now)
52 | log = open(os.path.join(log_now, 'log.txt'), 'w')
53 | state = {k: v for k, v in args._get_kwargs()}
54 | log.write(json.dumps(state) + '\n')
55 | log.close()
56 | cudnn.benchmark = True
57 | train_loader, val_loader = generate_dataloader(args, process_name, -1)
58 | if args.test_only:
59 | return validate(val_loader, model, criterion, 2000, log_now, process_name, args)
60 |
61 | for epoch in range(start_epoch, args.epochs):
62 | # train for one epoch
63 | train(train_loader, model, criterion, optimizer, epoch, log_now, process_name, args)
64 | # evaluate on the val data
65 | prec1 = validate(val_loader, model, criterion, epoch, log_now, process_name, args)
66 | # record the best prec1 and save checkpoint
67 | is_best = prec1 > best_prec1
68 | best_prec1 = max(prec1, best_prec1)
69 | if is_best:
70 | log = open(os.path.join(log_now, 'log.txt'), 'a')
71 | log.write(
72 | "best acc %3f" % (best_prec1))
73 | log.close()
74 | save_checkpoint({
75 | 'epoch': epoch + 1,
76 | 'arch': args.arch,
77 | 'state_dict': model.state_dict(),
78 | 'best_prec1': best_prec1,
79 | 'optimizer' : optimizer.state_dict(),
80 | }, is_best, log_now)
81 | #download_scores(val_loader, model, log_now, process_name, args)
82 | log = open(os.path.join(log_now, 'final.txt'), 'w')
83 | log.write(
84 | "best acc %3f" % (best_prec1))
85 | log.close()
86 |
87 |
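# Note: every Process*_ stage in this file follows the same pattern as above:
# it is skipped when '<log_now>/final.txt' already exists, '--resume' restores
# a stage that died mid-training from its checkpoint, and final.txt is written
# on completion, so a driver such as main.py can re-run all stages idempotently.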
88 | def Process2_PartNet(args):
89 | log_now = args.dataset + '/PartNet'
90 | process_name = 'partnet'
91 | if os.path.isfile(log_now + '/final.txt'):
92 | print('the Process2_PartNet is finished')
93 | return
94 | best_prec1 = 0
95 | model = Model_Construct(args, process_name)
96 | model = torch.nn.DataParallel(model).cuda()
97 | criterion = nn.BCELoss().cuda()
98 | # print(model)
99 | # print('the learning rate for the new added layer is set to 1e-3 to slow down the speed of learning.')
100 | optimizer = torch.optim.SGD([
101 | {'params': model.module.conv_model.parameters(), 'name': 'pre-trained'},
102 | {'params': model.module.classification_stream.parameters(), 'name': 'new-added'},
103 | {'params': model.module.detection_stream.parameters(), 'name': 'new-added'}
104 | ],
105 | lr=args.lr,
106 | momentum=args.momentum,
107 | weight_decay=args.weight_decay)
108 | start_epoch = args.start_epoch
109 | if args.resume:
110 | if os.path.isfile(args.resume):
111 | print("==> loading checkpoints '{}'".format(args.resume))
112 | checkpoint = torch.load(args.resume)
113 | start_epoch = checkpoint['epoch']
114 | best_prec1 = checkpoint['best_prec1']
115 | model.load_state_dict(checkpoint['state_dict'])
116 | optimizer.load_state_dict(checkpoint['optimizer'])
117 | print("==> loaded checkpoint '{}'(epoch {})"
118 | .format(args.resume, checkpoint['epoch']))
119 | args.resume = ''
120 | else:
121 | raise ValueError('The checkpoint to resume from does not exist', args.resume)
122 | else:
123 | if not os.path.isdir(log_now):
124 | os.makedirs(log_now)
125 | log = open(os.path.join(log_now, 'log.txt'), 'w')
126 | state = {k: v for k, v in args._get_kwargs()}
127 | log.write(json.dumps(state) + '\n')
128 | log.close()
129 | cudnn.benchmark = True
130 | train_loader, val_loader = generate_dataloader(args, process_name, -1)
131 | if args.test_only:
132 | return validate(val_loader, model, criterion, 2000, log_now, process_name, args)
133 | for epoch in range(start_epoch, args.epochs):
134 | # train for one epoch
135 | train(train_loader, model, criterion, optimizer, epoch, log_now, process_name, args)
136 | # evaluate on the val data
137 | prec1 = validate(val_loader, model, criterion, epoch, log_now, process_name, args)
138 | # record the best prec1 and save checkpoint
139 | is_best = prec1 > best_prec1
140 | best_prec1 = max(prec1, best_prec1)
141 | if is_best:
142 | log = open(os.path.join(log_now, 'log.txt'), 'a')
143 | log.write(
144 | "best acc %3f" % (best_prec1))
145 | log.close()
146 | save_checkpoint({
147 | 'epoch': epoch + 1,
148 | 'arch': args.arch,
149 | 'state_dict': model.state_dict(),
150 | 'best_prec1': best_prec1,
151 | 'optimizer': optimizer.state_dict(),
152 | }, is_best, log_now)
153 | svb_timer = time.time()
154 | if args.svb and epoch != (args.epochs - 1):
155 | svb(model, args)
156 | print('the SVB constraint is applied on the classification stream.')
157 | svb_det(model, args) # and on the detection stream as well
158 | print('the svb time is: ', time.time() - svb_timer)
159 | #download_scores(val_loader, model, log_now, process_name, args)
160 | log = open(os.path.join(log_now, 'final.txt'), 'w')
161 | log.write(
162 | "best acc %3f" % (best_prec1))
163 | log.close()
164 |
165 | def Process3_Download_Proposals(args):
166 | log_now = args.dataset + '/Download_Proposals'
167 | process_name = 'download_proposals'
168 | if os.path.isfile(log_now + '/final.txt'):
169 | print('the Process3_download proposals is finished')
170 | return
171 |
172 | model = Model_Construct(args, process_name)
173 | model = torch.nn.DataParallel(model).cuda()
174 |
175 | optimizer = torch.optim.SGD([
176 | {'params': model.module.conv_model.parameters(), 'name': 'pre-trained'},
177 | {'params': model.module.classification_stream.parameters(), 'name': 'new-added'},
178 | {'params': model.module.detection_stream.parameters(), 'name': 'new-added'}
179 | ],
180 | lr=args.lr,
181 | momentum=args.momentum,
182 | weight_decay=args.weight_decay)
183 | log_partnet_model = args.dataset + '/PartNet/model_best.pth.tar'
184 | checkpoint = torch.load(log_partnet_model)
185 | model.load_state_dict(checkpoint['state_dict'])
186 | print('load the pre-trained partnet model from:', log_partnet_model)
187 |
188 | if args.resume:
189 | if os.path.isfile(args.resume):
190 | print("==> loading checkpoints '{}'".format(args.resume))
191 | checkpoint = torch.load(args.resume)
192 | start_epoch = checkpoint['epoch']
193 | best_prec1 = checkpoint['best_prec1']
194 | model.load_state_dict(checkpoint['state_dict'])
195 | optimizer.load_state_dict(checkpoint['optimizer'])
196 | print("==> loaded checkpoint '{}'(epoch {})"
197 | .format(args.resume, checkpoint['epoch']))
198 | args.resume = ''
199 | else:
200 | raise ValueError('The checkpoint to resume from does not exist', args.resume)
201 | else:
202 | if not os.path.isdir(log_now):
203 | os.makedirs(log_now)
204 | log = open(os.path.join(log_now, 'log.txt'), 'w')
205 | state = {k: v for k, v in args._get_kwargs()}
206 | log.write(json.dumps(state) + '\n')
207 | log.close()
208 |
209 | cudnn.benchmark = True
210 | train_loader, val_loader = generate_dataloader(args, process_name)
211 |
212 | for epoch in range(1):
213 |
214 | download_part_proposals(train_loader, model, epoch, log_now, process_name, 'train', args)
215 |
216 | best_prec1 = download_part_proposals(val_loader, model, epoch, log_now, process_name, 'val', args)
217 |
218 | log = open(os.path.join(log_now, 'final.txt'), 'w')
219 | log.write(
220 | "best acc %3f" % (best_prec1))
221 | log.close()
222 |
223 |
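# Note: Process3 produces data on disk rather than a model. For every image,
# download_part_proposals (trainer.py) crops the args.num_select_proposals
# highest-scoring proposals per part for the train split (a single crop for
# val) into '<data_path>/PartNet<arch>/part_<k>/...'; Process4 below then
# presumably trains one classifier per part from those folders.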
224 | def Process4_Part_Classifiers(args):
225 | for i in range(args.num_part): # if the process is interrupted inside this loop, resuming needs extra care
226 | log_now = args.dataset + '/Part_Classifiers_' + str(i)
227 | process_name = 'part_classifiers'
228 | if os.path.isfile(log_now + '/final.txt'):
229 | print('the Process4_Part_Classifier is finished', i)
230 | continue
231 | best_prec1 = 0
232 | model = Model_Construct(args, process_name)
233 | model = torch.nn.DataParallel(model).cuda()
234 | criterion = nn.CrossEntropyLoss().cuda()
235 | optimizer = torch.optim.SGD([
236 | {'params': model.module.base_conv.parameters(), 'name': 'pre-trained'},
237 | {'params': model.module.fc.parameters(), 'lr': args.lr, 'name': 'new-added'}
238 | ],
239 | lr=args.lr,
240 | momentum=args.momentum,
241 | weight_decay=args.weight_decay)
242 | log_image_model = args.dataset + '/Image_Classifier/model_best.pth.tar'
243 | checkpoint = torch.load(log_image_model)
244 | model.load_state_dict(checkpoint['state_dict'])
245 | print('load the cub fine-tuned model from:', log_image_model)
246 | start_epoch = args.start_epoch
247 | if args.resume:
248 | if os.path.isfile(args.resume):
249 | print("==> loading checkpoints '{}'".format(args.resume))
250 | checkpoint = torch.load(args.resume)
251 | start_epoch = checkpoint['epoch']
252 | best_prec1 = checkpoint['best_prec1']
253 | model.load_state_dict(checkpoint['state_dict'])
254 | optimizer.load_state_dict(checkpoint['optimizer'])
255 | print("==> loaded checkpoint '{}'(epoch {})"
256 | .format(args.resume, checkpoint['epoch']))
257 | args.resume = ''
258 | else:
259 | raise ValueError('The checkpoint to resume from does not exist', args.resume)
260 | else:
261 | if not os.path.isdir(log_now):
262 | os.makedirs(log_now)
263 | log = open(os.path.join(log_now, 'log.txt'), 'w')
264 | state = {k: v for k, v in args._get_kwargs()}
265 | log.write(json.dumps(state) + '\n')
266 | log.close()
267 | cudnn.benchmark = True
268 | train_loader, val_loader = generate_dataloader(args, process_name, i)
269 | if args.test_only:
270 | validate(val_loader, model, criterion, 2000, log_now, process_name, args)
271 | for epoch in range(start_epoch, args.epochs_part):
272 | # train for one epoch
273 | train(train_loader, model, criterion, optimizer, epoch, log_now, process_name, args)
274 | # evaluate on the val data
275 | prec1 = validate(val_loader, model, criterion, epoch, log_now, process_name, args)
276 | # record the best prec1 and save checkpoint
277 | is_best = prec1 > best_prec1
278 | best_prec1 = max(prec1, best_prec1)
279 | if is_best:
280 | log = open(os.path.join(log_now, 'log.txt'), 'a')
281 | log.write(
282 | "best acc %3f" % (best_prec1))
283 | log.close()
284 | save_checkpoint({
285 | 'epoch': epoch + 1,
286 | 'arch': args.arch,
287 | 'state_dict': model.state_dict(),
288 | 'best_prec1': best_prec1,
289 | 'optimizer' : optimizer.state_dict(),
290 | }, is_best, log_now)
291 | #download_scores(val_loader, model, log_now, process_name, args)
292 | log = open(os.path.join(log_now, 'final.txt'), 'w')
293 | log.write(
294 | "best acc %3f" % (best_prec1))
295 | log.close()
296 |
297 | def Process5_Final_Result(args):
298 | ############################# Image Level Classifier #############################
299 | log_now = args.dataset + '/Image_Classifier'
300 | process_name = 'image_classifier'
301 | model = Model_Construct(args, process_name)
302 | model = torch.nn.DataParallel(model).cuda()
303 | pre_trained_model = log_now + '/model_best.pth.tar'
304 | checkpoint = torch.load(pre_trained_model)
305 | model.load_state_dict(checkpoint['state_dict'])
306 | train_loader, val_loader = generate_dataloader(args, process_name, -1)
307 | download_scores(val_loader, model, log_now, process_name, args)
308 | ############################# PartNet ############################################
309 | log_now = args.dataset + '/PartNet'
310 | process_name = 'partnet'
311 | model = Model_Construct(args, process_name)
312 | model = torch.nn.DataParallel(model).cuda()
313 | pre_trained_model = log_now + '/model_best.pth.tar'
314 | checkpoint = torch.load(pre_trained_model)
315 | model.load_state_dict(checkpoint['state_dict'])
316 | train_loader, val_loader = generate_dataloader(args, process_name)
317 | download_scores(val_loader, model, log_now, process_name, args)
318 | ############################# Three Part Level Classifiers #######################
319 | for i in range(args.num_part): # if the process is interrupted inside this loop, resuming needs extra care
320 | log_now = args.dataset + '/Part_Classifiers_' + str(i)
321 | process_name = 'part_classifiers'
322 | model = Model_Construct(args, process_name)
323 | model = torch.nn.DataParallel(model).cuda()
324 | pre_trained_model = log_now + '/model_best.pth.tar'
325 | checkpoint = torch.load(pre_trained_model)
326 | model.load_state_dict(checkpoint['state_dict'])
327 | train_loader, val_loader = generate_dataloader(args, process_name, i)
328 | download_scores(val_loader, model, log_now, process_name, args)
329 |
330 |
331 | log_image = args.dataset + '/Image_Classifier'
332 | process_image = 'image_classifier'
333 |
334 | log_partnet = args.dataset + '/PartNet'
335 | process_partnet = 'partnet'
336 |
337 | log_part0 = args.dataset + '/Part_Classifiers_' + str(0)
338 | process_part0 = 'part_classifiers'
339 |
340 | log_part1 = args.dataset + '/Part_Classifiers_' + str(1)
341 | process_part1 = 'part_classifiers'
342 |
343 | log_part2 = args.dataset + '/Part_Classifiers_' + str(2)
344 | process_part2 = 'part_classifiers'
345 |
346 | image_table = torch.load(log_image + '/' + process_image + '.pth.tar')
347 | image_probability = image_table['scores']
348 | labels = image_table['labels']
349 | partnet_table = torch.load(log_partnet + '/' + process_partnet + '.pth.tar')
350 | partnet_probability = partnet_table['scores']
351 | #######################
352 | part0_table = torch.load(log_part0 + '/' + process_part0 + '.pth.tar')
353 | part0_probability = part0_table['scores']
354 | ##########################
355 | part1_table = torch.load(log_part1 + '/' + process_part1 + '.pth.tar')
356 | part1_probability = part1_table['scores']
357 | ##########################
358 | part2_table = torch.load(log_part2 + '/' + process_part2 + '.pth.tar')
359 | part2_probability = part2_table['scores']
360 | ##########################
361 |
362 | probabilities_group = []
363 | probabilities_group.append(image_probability)
364 | probabilities_group.append(part0_probability)
365 | probabilities_group.append(part1_probability)
366 | probabilities_group.append(part2_probability)
367 | probabilities_group.append(partnet_probability)
368 | count = 0
369 | for i in range(len(labels)):
370 | # sum the five classifiers' scores (argmax of the sum equals argmax of the average)
371 | probability = probabilities_group[0][i].clone()
372 | for j in range(1, len(probabilities_group)):
373 | probability = probability + probabilities_group[j][i]
374 | label = labels[i]
375 | value, index = probability.sort(0, descending=True)
376 | if index[0] == label:
377 | count = count + 1
378 | top1 = count / len(labels)
379 | print('the final top-1 accuracy obtained by averaging the image, part0/1/2 and partnet scores is', top1)
380 |
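# The loop above sums the five score tensors per image and takes the argmax.
# A vectorized sketch of the same computation (assuming, as produced by
# download_scores, 1-D CPU probability tensors and 0-dim label tensors):
def _averaged_top1(probabilities_group, labels):
    summed = sum(torch.stack(group) for group in probabilities_group)
    preds = summed.argmax(dim=1)
    targets = torch.stack([label.cpu() for label in labels]).view(-1)
    return (preds == targets).float().mean().item()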
381 | def save_checkpoint(state, is_best, log_now):
382 | filename = 'checkpoint.pth.tar'
383 | dir_save_file = os.path.join(log_now, filename)
384 | torch.save(state, dir_save_file)
385 | if is_best:
386 | shutil.copyfile(dir_save_file, os.path.join(log_now, 'model_best.pth.tar'))
387 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | ########################################## Use the following command to test the final results ####################################
3 | python main.py --batch_size 128 --batch_size_partnet 64 --momentum 0.9 --weight_decay 1e-4 --data_path /data/flower102/ --test_only \
4 | --proposals_num 1372 --square_size 4 --proposals_per_square 28 --workers 8 --lr 0.1 --svb --svb_factor 1.5 \
5 | --dataset flowers --print_freq 1 --arch vgg19_bn --num_part 3 --epochs_part 30 --num_classes 102 --num_select_proposals 50 \
6 | --epochs 160 --pretrained_model /home/lab-zhang.yabin/PAFGN/pytorch-PartNet/pretrain_vgg_imagenet/vgg19-bn-modified/model_best.pth.tar
7 |
8 | #python main.py --batch_size 128 --batch_size_partnet 64 --momentum 0.9 --weight_decay 1e-4 --data_path /data/CUB_200_2011/ --test_only \
9 | # --proposals_num 1372 --square_size 4 --proposals_per_square 28 --workers 8 --lr 0.1 --svb --svb_factor 1.5 --schedule 80 120 --schedule_part 10 20 \
10 | # --dataset cub200 --print_freq 1 --arch vgg19_bn --num_part 3 --epochs_part 30 --num_classes 200 --num_select_proposals 50 \
11 | # --epochs 160 --pretrained_model /home/lab-zhang.yabin/PAFGN/pytorch-PartNet/pretrain_vgg_imagenet/vgg19-bn-modified/model_best.pth.tar
12 |
13 |
14 | ########################################### use the following command to train the PartNet #######################################
15 | #python main.py --batch_size 128 --batch_size_partnet 64 --momentum 0.9 --weight_decay 1e-4 --data_path /data/flower102/ \
16 | # --proposals_num 1372 --square_size 4 --proposals_per_square 28 --workers 8 --lr 0.1 --svb --svb_factor 1.5 \
17 | # --dataset flowers --print_freq 1 --arch vgg19_bn --num_part 3 --epochs_part 30 --num_classes 102 --num_select_proposals 50 \
18 | # --epochs 160 --pretrained_model /home/lab-zhang.yabin/PAFGN/pytorch-PartNet/pretrain_vgg_imagenet/vgg19-bn-modified/model_best.pth.tar \
19 |
20 |
21 | #python main.py --batch_size 128 --batch_size_partnet 64 --momentum 0.9 --weight_decay 1e-4 --data_path /data/CUB_200_2011/ \
22 | # --proposals_num 1372 --square_size 4 --proposals_per_square 28 --workers 8 --lr 0.1 --svb --svb_factor 1.5 --schedule 80 120 --schedule_part 10 20 \
23 | # --dataset cub200 --print_freq 1 --arch vgg19_bn --num_part 3 --epochs_part 30 --num_classes 200 --num_select_proposals 50 \
24 | # --epochs 160 --pretrained_model /home/lab-zhang.yabin/PAFGN/pytorch-PartNet/pretrain_vgg_imagenet/vgg19-bn-modified/model_best.pth.tar
25 |
26 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import os
4 | # import ipdb # debugging helper, only needed for the commented ipdb.set_trace() calls below
5 | from PIL import Image
6 | import torch.nn as nn
7 | import torchvision.transforms as transforms
8 | import numpy as np
9 |
10 | def train(train_loader, model, criterion, optimizer, epoch, log_now, process_name, args):
11 | batch_time = AverageMeter()
12 | data_time = AverageMeter()
13 | losses = AverageMeter()
14 | top1 = AverageMeter()
15 | top5 = AverageMeter()
16 | model.train()
17 | adjust_learning_rate(optimizer, epoch, process_name, args)
18 | end = time.time()
19 | for i, (input, target, target_loss) in enumerate(train_loader):
20 | data_time.update(time.time() - end)
21 | if process_name == 'partnet':
22 | target = target.cuda(non_blocking=True)
23 | target_loss = target_loss.cuda(non_blocking=True)
24 | target_var = torch.autograd.Variable(target_loss)
25 | elif process_name == 'image_classifier' or process_name == 'part_classifiers':
26 | target = target.cuda(non_blocking=True)
27 | target_var = torch.autograd.Variable(target)
28 | else:
29 | raise ValueError('the required process type is not supported')
30 | input_var = torch.autograd.Variable(input)
31 | # print(target_var)
32 | # ipdb.set_trace()
33 | output = model(input_var)
34 | # print(output)
35 | loss = criterion(output, target_var)
36 | # measure accuracy and record loss
37 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
38 | losses.update(loss.item(), input.size(0))
39 | top1.update(prec1[0], input.size(0))
40 | top5.update(prec5[0], input.size(0))
41 |
42 | #compute gradient and do SGD step
43 | optimizer.zero_grad()
44 | loss.backward()
45 | optimizer.step()
46 | batch_time.update(time.time() - end)
47 | end = time.time()
48 | if i % args.print_freq == 0:
49 | print('Train: [{0}][{1}/{2}]\t'
50 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
51 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
52 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
53 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
54 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
55 | epoch, i, len(train_loader), batch_time=batch_time,
56 | data_time=data_time, loss=losses, top1=top1, top5=top5))
57 | log = open(os.path.join(log_now, 'log.txt'), 'a')
58 | log.write("\n")
59 | log.write("Train:epoch: %d, loss: %4f, Top1 acc: %3f, Top5 acc: %3f" % (epoch, losses.avg, top1.avg, top5.avg))
60 | log.close()
61 |
62 |
63 | def validate(val_loader, model, criterion, epoch, log_now, process_name, args):
64 | batch_time = AverageMeter()
65 | losses = AverageMeter()
66 | top1 = AverageMeter()
67 | top5 = AverageMeter()
68 | # switch to evaluate mode
69 | model.eval()
70 |
71 | end = time.time()
72 | for i, (input, target, target_loss) in enumerate(val_loader):
73 | target = target.cuda(non_blocking=True)
74 | input_var = torch.autograd.Variable(input)
75 | if process_name == 'partnet':
76 | target_loss = target_loss.cuda(non_blocking=True)
77 | target_var = torch.autograd.Variable(target_loss)
78 | elif process_name == 'image_classifier' or process_name == 'part_classifiers':
79 | target_var = torch.autograd.Variable(target)
80 | else:
81 | raise ValueError('the required process type is not supported')
82 | # compute output
83 | with torch.no_grad():
84 | output = model(input_var)
85 | loss = criterion(output, target_var)
86 | # measure accuracy and record loss
87 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
88 | losses.update(loss.item(), input.size(0))
89 | top1.update(prec1[0], input.size(0))
90 | top5.update(prec5[0], input.size(0))
91 |
92 | # measure elapsed time
93 | batch_time.update(time.time() - end)
94 | end = time.time()
95 |
96 | if i % args.print_freq == 0:
97 | print('Test: [{0}][{1}/{2}]\t'
98 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
99 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
100 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
101 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
102 | epoch, i, len(val_loader), batch_time=batch_time, loss=losses,
103 | top1=top1, top5=top5))
104 |
105 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
106 | .format(top1=top1, top5=top5))
107 | log = open(os.path.join(log_now, 'log.txt'), 'a')
108 | log.write("\n")
109 | log.write(" Val:epoch: %d, loss: %4f, Top1 acc: %3f, Top5 acc: %3f" %\
110 | (epoch, losses.avg, top1.avg, top5.avg))
111 | log.close()
112 | return top1.avg
113 |
114 | def download_part_proposals(data_loader, model, epoch, log_now, process_name, train_or_val, args):
115 | batch_time = AverageMeter()
116 | data_time = AverageMeter()
117 |
118 | top1 = AverageMeter()
119 | top5 = AverageMeter()
120 | model.eval()
121 | # adjust_learning_rate(optimizer, epoch, args)
122 | end = time.time()
123 | tensor2image = transforms.ToPILImage()
124 | if train_or_val == 'train':
125 | num_selected = args.num_select_proposals
126 | elif train_or_val == 'val':
127 | num_selected = 1
128 | else:
129 | raise ValueError('only train or val modes are accepted')
130 | for i, (input, input_keep, target, target_loss, path_image) in enumerate(data_loader):
131 | data_time.update(time.time() - end)
132 | target = target.cuda(non_blocking=True)
133 | target_loss = target_loss.cuda(non_blocking=True)
134 | input_var = torch.autograd.Variable(input)
135 | target_var = torch.autograd.Variable(target_loss)
136 |
137 | proposals_scores_h = torch.Tensor(input.size(0), args.proposals_num, args.num_part + 1).fill_(
138 | 0) # all the proposals scores
139 | all_proposals = torch.Tensor(input.size(0) * args.proposals_num, 5).fill_(0) # all the proposals; column 0 is the batch index, columns 1-4 the box coordinates
140 |
141 | def hook_scores(module, inputdata, output):
142 | proposals_scores_h.copy_(output.data)
143 |
144 | def hook_proposals(module, inputdata, output):
145 | all_proposals.copy_(output.data)
146 |
147 | handle_scores = model.module.detection_stream.softmax_cls.register_forward_hook(hook_scores)
148 | handle_proposals = model.module.DPP.register_forward_hook(hook_proposals)
149 | with torch.no_grad():
150 | output = model(input_var)
151 |
152 | handle_proposals.remove() # remove the hooks once their buffers are filled
153 | handle_scores.remove()
154 | # print(output)
155 | for j in range(input.size(0)):
156 | real_image = tensor2image(input_keep[j])
157 | scores_for_image = proposals_scores_h[j]
158 | value, sort = torch.sort(scores_for_image, 0, descending=True) # scores sorted from large to small
159 | # print('value:', value)
160 | # print('sort:', sort)
161 | proposals_one = all_proposals[j*args.proposals_num:(j+1)*args.proposals_num, 1:5]
162 | # print(sort)
163 | #check whether the dir file exist, if not, create one.
164 | img_dir = path_image[j]
165 |
166 | # ipdb.set_trace()
167 | for num_p in range(0, args.num_part+1):
168 | last_dash = img_dir.rfind('/')
169 | image_folder = args.data_path + 'PartNet' + args.arch + '/part_' + str(num_p) + '/' + img_dir[0:last_dash]
170 | if not os.path.isdir(image_folder):
171 | os.makedirs(image_folder)
172 | image_name = img_dir[last_dash:]
173 | last_point = image_name.find('.')
174 | name = image_name[0:last_point]
175 | one_part_sort = sort[:, num_p]
176 | for k in range(num_selected): # for each proposals
177 | select_proposals = np.array(proposals_one[one_part_sort[k]])
178 | cropped_image = real_image.crop((select_proposals[0], select_proposals[1], select_proposals[2], select_proposals[3]))
179 | dir_to_save = image_folder + name + '_' + str(k) + image_name[last_point:]
180 | cropped_image.save(dir_to_save)
181 |
182 | # measure accuracy and record loss
183 | prec1, prec5 = accuracy(output.data, target, topk=(1,5))
184 | top1.update(prec1[0], input.size(0))
185 | top5.update(prec5[0], input.size(0))
186 |
187 | batch_time.update(time.time() - end)
188 | end = time.time()
189 | if i % args.print_freq == 0:
190 | print(train_or_val + ': [{0}][{1}/{2}]\t'
191 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
192 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
193 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
194 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
195 | epoch, i, len(data_loader), batch_time=batch_time,
196 | data_time=data_time, top1=top1, top5=top5))
197 | log = open(os.path.join(log_now, 'log.txt'), 'a')
198 | log.write("\n")
199 | log.write(train_or_val + ":epoch: %d, Top1 acc: %3f, Top5 acc: %3f" % (epoch, top1.avg, top5.avg))
200 | log.close()
201 | return top1.avg
202 |
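# The function above harvests intermediate outputs with PyTorch forward hooks:
# a hook registered on a submodule fires on every forward pass and can copy the
# submodule's output into a pre-allocated buffer. A self-contained sketch of
# the pattern (the names here are illustrative, not from this repository):
def _hook_demo():
    layer = nn.Linear(4, 2)
    captured = torch.Tensor(1, 2)  # buffer filled by the hook
    handle = layer.register_forward_hook(lambda m, inp, out: captured.copy_(out.data))
    layer(torch.randn(1, 4))
    handle.remove()  # remove the hook once the buffer is filled
    return captured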
203 | def download_scores(val_loader, model, log_now, process_name, args):
204 | if not os.path.isdir(log_now):
205 | raise ValueError('the requested log dir does not exist')
206 | file_to_save_or_load = log_now + '/' + process_name + '.pth.tar'
207 | probabilities = []
208 | labels = []
209 | batch_time = AverageMeter()
210 | top1 = AverageMeter()
211 | top5 = AverageMeter()
212 | # switch to evaluate mode
213 | model.eval()
214 | end = time.time()
215 | softmax = nn.Softmax(dim=1)
216 | for i, (input, target, target_loss) in enumerate(val_loader):
217 | target = target.cuda(non_blocking=True)
218 | input_var = torch.autograd.Variable(input)
219 |
220 | # compute output
221 | with torch.no_grad():
222 | output = model(input_var)
223 | # print(output)
224 | if process_name == 'partnet':
225 | output = torch.nn.functional.normalize(output, p=1, dim=1)
226 | else:
227 | output = softmax(output)
228 |
229 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
230 | for j in range(input.size(0)):
231 | # maybe here need to sub tensor to save memory.
232 | probabilities.append(output.data[j].cpu().clone())
233 | labels.append(target[j])
234 |
235 | top1.update(prec1[0], input.size(0))
236 | top5.update(prec5[0], input.size(0))
237 |
238 | # measure elapsed time
239 | batch_time.update(time.time() - end)
240 | end = time.time()
241 | print('Test: [{0}][{1}/{2}]\t'
242 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
243 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
244 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
245 | 1, i, len(val_loader), batch_time=batch_time,
246 | top1=top1, top5=top5))
247 |
248 | log = open(os.path.join(log_now, 'log.txt'), 'a')
249 | log.write("\n")
250 | log.write(process_name)
251 | log.write(" Val:epoch: %d, Top1 acc: %3f, Top5 acc: %3f" % \
252 | (1, top1.avg, top5.avg))
253 | log.close()
254 | torch.save({'scores': probabilities, 'labels': labels}, file_to_save_or_load)
255 |
256 | return probabilities, labels
257 |
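# Note: the table saved above is what Process5_Final_Result (process.py) loads
# back: {'scores': a list of per-image 1-D probability tensors (L1-normalized
# for partnet, softmax-normalized otherwise), 'labels': the matching targets}.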
258 | def svb(model, args):
259 | print('the layer used for svb is', model.module.classification_stream.classifier[3])
260 | svb_model = model.module.classification_stream.classifier[3]
261 | tmpbatchM = svb_model.weight.data.t().clone()
262 | tmpU, tmpS, tmpV = torch.svd(tmpbatchM)
263 | for idx in range(0, tmpS.size(0)):
264 | if tmpS[idx] > args.svb_factor:
265 | tmpS[idx] = args.svb_factor
266 | elif tmpS[idx] < 1 / args.svb_factor:
267 | tmpS[idx] = 1 / args.svb_factor
268 | tmpbatchM = torch.mm(torch.mm(tmpU, torch.diag(tmpS.cuda())), tmpV.t()).t().contiguous()
269 | svb_model.weight.data.copy_(tmpbatchM.view_as(svb_model.weight.data))
270 |
271 | def svb_det(model, args): # it is not used in our experiments
272 | print('the layer used for svb is', model.module.detection_stream.detector[3])
273 | svb_model = model.module.detection_stream.detector[3]
274 | tmpbatchM = svb_model.weight.data.t().clone()
275 | tmpU, tmpS, tmpV = torch.svd(tmpbatchM)
276 | for idx in range(0, tmpS.size(0)):
277 | if tmpS[idx] > args.svb_factor:
278 | tmpS[idx] = args.svb_factor
279 | elif tmpS[idx] < 1 / args.svb_factor:
280 | tmpS[idx] = 1 / args.svb_factor
281 | tmpbatchM = torch.mm(torch.mm(tmpU, torch.diag(tmpS.cuda())), tmpV.t()).t().contiguous()
282 | svb_model.weight.data.copy_(tmpbatchM.view_as(svb_model.weight.data))
283 |
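# Both functions above implement singular value bounding (SVB): decompose the
# transposed weight as U diag(S) V^T, clamp every singular value into
# [1/svb_factor, svb_factor], and write the rebuilt matrix back. A standalone
# CPU sketch of the same operation on a random matrix:
def _svb_demo(factor=1.5):
    w = torch.randn(64, 64)
    u, s, v = torch.svd(w)
    s = s.clamp(1.0 / factor, factor)  # bound the spectrum into [1/f, f]
    return u.mm(torch.diag(s)).mm(v.t())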
284 | class AverageMeter(object):
285 | """Computes and stores the average and current value"""
286 | def __init__(self):
287 | self.reset()
288 |
289 | def reset(self):
290 | self.val = 0
291 | self.avg = 0
292 | self.sum = 0
293 | self.count = 0
294 |
295 | def update(self, val, n=1):
296 | self.val = val
297 | self.sum += val * n
298 | self.count += n
299 | self.avg = self.sum / self.count
300 |
301 |
302 | def adjust_learning_rate(optimizer, epoch, process_name, args):
303 | """Adjust the learning rate according to the epoch"""
304 | if process_name == 'part_classifiers':
305 | schedule = args.schedule_part
306 | else:
307 | # 'partnet' and 'image_classifier' share the same schedule
308 | schedule = args.schedule
309 | # exp selects the decay step for the new-added layers: the learning rate
310 | # is multiplied by gamma at each milestone in the schedule. The pre-trained
311 | # layers always use exp_pre = 2, i.e. they stay at lr * gamma^2 for the
312 | # whole run.
313 | exp = 2 if epoch >= schedule[1] else 1 if epoch >= schedule[0] else 0
314 | exp_pre = 2
315 |
316 | # print(exp)
317 | lr = args.lr * (args.gamma ** exp)
318 | lr_pre = args.lr * (args.gamma ** exp_pre)
319 | print('LR for new-added', lr)
320 | print('LR for old', lr_pre)
321 | for param_group in optimizer.param_groups:
322 | if param_group['name'] == 'pre-trained':
323 | param_group['lr'] = lr_pre
324 | else:
325 | param_group['lr'] = lr
326 |
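# With the defaults (lr=0.1, gamma=0.1, schedule=[80, 120]) the new-added
# layers therefore run at 0.1 for epochs 0-79, 0.01 for epochs 80-119 and
# 0.001 from epoch 120 on, while the pre-trained layers stay at
# lr * gamma^2 = 0.001 throughout.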
327 |
328 | def accuracy(output, target, topk=(1,)):
329 | """Computes the precision@k for the specified values of k"""
330 | maxk = max(topk)
331 | batch_size = target.size(0)
332 | _, pred = output.topk(maxk, 1, True, True)
333 | pred = pred.t()
334 | correct = pred.eq(target.view(1, -1).expand_as(pred))
335 |
336 | res = []
337 | for k in topk:
338 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
339 | res.append(correct_k.mul_(100.0 / batch_size))
340 | return res
341 |
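# A tiny worked check of accuracy() (illustrative values): sample 0 is correct
# at k=1 and sample 1 only becomes correct at k=2, so the call returns
# [tensor([50.]), tensor([100.])].
def _accuracy_demo():
    out = torch.tensor([[0.10, 0.70, 0.20],
                        [0.80, 0.05, 0.15]])
    tgt = torch.tensor([1, 2])
    return accuracy(out, tgt, topk=(1, 2))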
--------------------------------------------------------------------------------