├── .gitignore
├── README.md
├── configs
    ├── README.md
    ├── dataset.yml
    ├── example.yml
    └── prototype.yml
├── data
    ├── FCDB_sliding_windows.json
    ├── flickr_pro_train.pkl
    └── flickr_pro_val.pkl
├── requirements.txt
├── setup.py
├── tests
    ├── conftest.py
    └── test_evaluation.py
├── tools
    ├── create_dataset.py
    ├── evaluate.py
    └── train.py
└── vfn
    ├── __init__.py
    ├── config
        ├── __init__.py
        └── parser.py
    ├── data
        ├── FCDB.py
        ├── FlickrPro.py
        ├── ICDB.py
        ├── __init__.py
        ├── dataset.py
        ├── evaluation.py
        ├── image_downloader.py
        └── ioutils.py
    ├── network
        ├── __init__.py
        ├── backbones.py
        ├── losses.py
        └── models.py
    └── utils
        ├── __init__.py
        └── visualization.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | datasets/
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 | 
52 | # Translations
53 | *.mo
54 | *.pot
55 | 
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 | 
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 | 
65 | # Scrapy stuff:
66 | .scrapy
67 | 
68 | # Sphinx documentation
69 | docs/_build/
70 | 
71 | # PyBuilder
72 | target/
73 | 
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 | 
77 | # pyenv
78 | .python-version
79 | 
80 | # celery beat schedule file
81 | celerybeat-schedule
82 | 
83 | # SageMath parsed files
84 | *.sage.py
85 | 
86 | # Environments
87 | .env
88 | .venv
89 | .idea
90 | env/
91 | venv/
92 | ENV/
93 | env.bak/
94 | venv.bak/
95 | 
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 | 
100 | # Rope project settings
101 | .ropeproject
102 | 
103 | # mkdocs documentation
104 | /site
105 | 
106 | # mypy
107 | .mypy_cache/
108 | 
109 | # PyCharm
110 | .idea/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-view-finding-network
2 | 
3 | This is a PyTorch implementation of the [view finding network](https://github.com/yiling-chen/view-finding-network) method.
4 | 
5 | ## Getting Started
6 | 
7 | ### Installation
8 | 
9 | ```bash
10 | git clone https://github.com/remorsecs/pytorch-view-finding-network
11 | 
12 | cd pytorch-view-finding-network/
13 | 
14 | pip install -r requirements.txt
15 | 
16 | python setup.py build develop
17 | ```
18 | 
19 | ### Usage
20 | 
21 | #### Configuration
22 | 
23 | The configuration file is parsed by the `ConfigParser` class in `vfn/config/parser.py`.
24 | 
25 | You can start from `configs/prototype.yml` or use `configs/example.yml` as a reference.
26 | 
27 | The supported format is documented in `configs/README.md`.
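
As a quick reference, here is a minimal sketch (mirroring `tools/train.py` and `tools/evaluate.py`, and assuming it is run from the repository root) of how a parsed configuration is consumed:

```python
from vfn.config.parser import ConfigParser

configs = ConfigParser('configs/example.yml')  # load and parse the YAML configuration
device = configs.parse_device()                # e.g. torch.device('cuda:0')
model = configs.parse_model().to(device)       # backbone (AlexNet/VGG) wrapped in ViewFindingNet
```
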
28 | 
29 | 
30 | #### Start Visdom Server
31 | 
32 | We use Visdom for data visualization.
33 | 
34 | Start a Visdom server from the command line before training:
35 | 
36 | ```bash
37 | visdom
38 | ```
39 | 
40 | It will launch a Visdom server and print its PID. The Visdom server can be accessed by opening
41 | `http://localhost:8097` (the default port) in a browser.
42 | 
43 | You can visit the [official site](https://github.com/facebookresearch/visdom#usage)
44 | for more information.
45 | 
46 | 
47 | #### Train
48 | 
49 | ```bash
50 | cd tools/
51 | 
52 | python train.py -c '../configs/example.yml'
53 | ```
54 | 
55 | #### Evaluate
56 | 
57 | ```bash
58 | cd tools/
59 | 
60 | python evaluate.py -c '../configs/example.yml'
61 | ```
62 | 
--------------------------------------------------------------------------------
/configs/README.md:
--------------------------------------------------------------------------------
1 | # Configuration
2 | 
3 | ## Supported Format
4 | 
5 | ```yaml
6 | 
7 | ---
8 | checkpoint:
9 |   root_dir: # Path to checkpoint root, type: str
10 |   prefix: # Prefix for each checkpoint file, type: str
11 | 
12 | weight: # Path to trained model for evaluation, type: str
13 | 
14 | device: # Computing devices, follow the `torch.device` argument. type: str
15 |         # See: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device
16 | 
17 | model:
18 |   backbone:
19 |     name: # Name for class in module `vfn.network.backbones`, supports 'AlexNet', 'VGG', type: str
20 |     pretrained: # To load pretrained parameters, type: bool
21 | 
22 | train:
23 |   num_epochs: # Number of epochs, type: int
24 |   viz: # Determine to launch the visdom server, type: bool
25 | 
26 |   optimizer:
27 |     name: # Name for class `torch.optim`, supports 'Adam', type: str
28 |     # The rest of arguments defined here will pass to `torch.optim`.
29 | 
30 |   loss:
31 |     name: # Name for loss function, supports 'hinge', type: str
32 | 
33 |   dataset:
34 |     name: # Name for dataset class, supports 'FlickrPro', type: str
35 |     root_dir: # Path to dataset root, type: str
36 |     gulpio_dir: # Path to GulpIO dataset root, type: str
37 | 
38 |   dataloader:
39 |     # The arguments defined here will pass to `torch.utils.data.DataLoader`.
40 | 
41 | validation:
42 |   viz: # Determine to launch the visdom server, type: bool
43 | 
44 |   dataset:
45 |     name: # Name for dataset class, supports 'FlickrPro', type: str
46 |     root_dir: # Path to dataset root, type: str
47 |     gulpio_dir: # Path to GulpIO dataset root, type: str
48 | 
49 |   dataloader:
50 |     # The arguments defined here will pass to `torch.utils.data.DataLoader`.
51 | 
52 | evaluate:
53 |   FCDB:
54 |     root_dir: # Path to FCDB dataset, type: str
55 |     download: # Determine to download FCDB dataset, type: bool
56 | 
57 |   ICDB:
58 |     root_dir: # Path to ICDB dataset, type: str
59 |     download: # Determine to download ICDB dataset, type: bool
60 | ...
61 | 62 | ``` -------------------------------------------------------------------------------- /configs/dataset.yml: -------------------------------------------------------------------------------- 1 | --- 2 | FlickrPro: 3 | train: 4 | meta: '../data/flickr_pro_train.pkl' 5 | root_dir: '../datasets/FlickrPro/src' 6 | gulpio_dir: '../datasets/FlickrPro/train' 7 | download: False 8 | val: 9 | meta: '../data/flickr_pro_val.pkl' 10 | root_dir: '../datasets/FlickrPro/src' 11 | gulpio_dir: '../datasets/FlickrPro/val' 12 | download: False 13 | 14 | FCDB: 15 | test: 16 | root_dir: '../datasets/FCDB' 17 | download: False 18 | 19 | ICDB: 20 | test: 21 | root_dir: '../datasets/ICDB' 22 | download: False 23 | ... -------------------------------------------------------------------------------- /configs/example.yml: -------------------------------------------------------------------------------- 1 | --- 2 | checkpoint: 3 | root_dir: 'ckpt/exp01' 4 | prefix: 'exp01' 5 | 6 | weight: '../ckpt/exp01/exp01_AlexNet_15.pth' 7 | 8 | device: 'cuda:0' # 'cuda:0', 'cuda:1', ... or 'cpu' 9 | 10 | model: 11 | backbone: 12 | name: 'AlexNet' 13 | pretrained: True 14 | 15 | train: 16 | num_epochs: 15 17 | viz: True 18 | 19 | optimizer: 20 | name: 'Adam' 21 | lr: 0.001 22 | 23 | loss: 24 | name: 'hinge' 25 | 26 | dataset: 27 | name: 'FlickrPro' 28 | root_dir: '../datasets/FlickrPro/src/' 29 | gulpio_dir: '../datasets/FlickrPro/train/' 30 | 31 | dataloader: 32 | batch_size: 50 33 | shuffle: False 34 | 35 | validation: 36 | viz: True 37 | 38 | dataset: 39 | name: 'FlickrPro' 40 | root_dir: '../datasets/FlickrPro/src/' 41 | gulpio_dir: '../datasets/FlickrPro/val/' 42 | 43 | dataloader: 44 | batch_size: 50 45 | shuffle: False 46 | 47 | 48 | evaluate: 49 | FCDB: 50 | root_dir: '../datasets/FCDB' 51 | download: True 52 | 53 | ICDB: 54 | root_dir: '../datasets/ICDB' 55 | download: True 56 | ... 57 | -------------------------------------------------------------------------------- /configs/prototype.yml: -------------------------------------------------------------------------------- 1 | --- 2 | checkpoint: 3 | root_dir: # Path to checkpoint root, type: str 4 | prefix: # Prefix for each checkpoint file, type: str 5 | 6 | weight: # Path to trained model for evaluation, type: str 7 | 8 | device: # Computing devices, follow the `torch.device` argument. type: str 9 | # See: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device 10 | 11 | model: 12 | backbone: 13 | name: # Name for class in module `viewfinder_benchmark.network.backbones`, supports 'AlexNet', 'VGG', type: str 14 | pretrained: # To load pretrained parameters, type: bool 15 | 16 | train: 17 | num_epochs: # Number of epochs, type: int 18 | viz: # Determine to launch the visdom server, type: bool 19 | 20 | optimizer: 21 | name: # Name for class `torch.optim`, supports 'Adam', type: str 22 | # The rest of arguments defined here will pass to `torch.optim`. 23 | 24 | loss: 25 | name: # Name for loss function, supports 'hinge', type: str 26 | 27 | dataset: 28 | name: # Name for dataset class, supports 'FlickrPro', type: str 29 | root_dir: # Path to dataset root, type: str 30 | gulpio_dir: # Path to GulpIO dataset root, type: str 31 | 32 | dataloader: 33 | # The arguments defined here will pass to `torch.utils.data.DataLoader`. 
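    # For example, configs/example.yml fills this section with:
    #   batch_size: 50
    #   shuffle: False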
34 | 35 | validation: 36 | viz: # Determine to launch the visdom server, type: bool 37 | 38 | dataset: 39 | name: # Name for dataset class, supports 'FlickrPro', type: str 40 | root_dir: # Path to dataset root, type: str 41 | gulpio_dir: # Path to GulpIO dataset root, type: str 42 | 43 | dataloader: 44 | # The arguments defined here will pass to `torch.utils.data.DataLoader`. 45 | 46 | evaluate: 47 | FCDB: 48 | root_dir: # Path to FCDB dataset, type: str 49 | download: # Determine to download FCDB dataset, type: bool 50 | 51 | ICDB: 52 | root_dir: # Path to ICDB dataset, type: str 53 | download: # Determine to download ICDB dataset, type: bool 54 | ... -------------------------------------------------------------------------------- /data/flickr_pro_train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aquastripe/pytorch-view-finding-network/5518f3ec80d9c9169bc1ac54223cb4d2ea372cae/data/flickr_pro_train.pkl -------------------------------------------------------------------------------- /data/flickr_pro_val.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aquastripe/pytorch-view-finding-network/5518f3ec80d9c9169bc1ac54223cb4d2ea372cae/data/flickr_pro_val.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.0 2 | torchvision>=0.3 3 | pytorch-ignite 4 | visdom 5 | scikit-learn 6 | pyyaml 7 | opencv-python 8 | gulpio 9 | pytest -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='pytorch-view-finding-network', 5 | version="0.1", 6 | author="remorsecs", 7 | url="https://github.com/remorsecs/pytorch-view-finding-network", 8 | license="MIT", 9 | packages=find_packages(exclude=("configs", "tests")), 10 | ) -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torchvision.transforms import transforms as T 4 | 5 | from tests.test_evaluation import SimpleImageCropper 6 | from vfn.data.evaluation import ImageCropperEvaluator 7 | from vfn.data.FCDB import FCDB 8 | 9 | 10 | def pytest_addoption(parser): 11 | parser.addoption('--fcdb', type=str, default='') 12 | 13 | 14 | @pytest.fixture(scope='module') 15 | def root_fcdb(request): 16 | return request.config.getoption('--fcdb') 17 | 18 | 19 | @pytest.fixture(scope='module') 20 | def evaluator_simple_net_on_FCDB(root_fcdb): 21 | model = SimpleImageCropper() 22 | dataset = FCDB(root_fcdb, download=False) 23 | device = torch.device('cuda:0') 24 | transforms = T.Compose([ 25 | T.ToPILImage(), 26 | T.Resize((16, 16)), 27 | T.ToTensor(), 28 | ]) 29 | return ImageCropperEvaluator(model, dataset, device, transforms) 30 | -------------------------------------------------------------------------------- /tests/test_evaluation.py: -------------------------------------------------------------------------------- 1 | # To run this test code: 2 | # $ py.test test_evaluation.py --fcdb=/path/to/FCDB/dataset 3 | from itertools import chain 4 | 5 | import pytest 6 | import torch 7 | import torch.nn as nn 8 | 9 | torch.manual_seed(0) 
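# The fixed seed above, together with the cuDNN settings below, keeps the randomly
# initialised SimpleImageCropper (and therefore the expected metric values asserted
# in the tests) reproducible across runs.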
10 | torch.backends.cudnn.deterministic = True 11 | torch.backends.cudnn.benchmark = False 12 | 13 | 14 | def generate_crops_v1(width, height, device=None): 15 | if device is None: 16 | device = torch.device('cpu') 17 | 18 | d_x = list(chain.from_iterable([i * j] * 5 for i in range(10, 0, -2) for j in range(5))) 19 | d_y = list(chain.from_iterable([i * j for i in range(5)] * 5 for j in range(10, 0, -2))) 20 | d_len = list(chain.from_iterable([i] * 25 for i in range(5, 10))) 21 | 22 | d_x = torch.tensor(d_x, dtype=torch.float32, device=device) * 0.01 23 | d_y = torch.tensor(d_y, dtype=torch.float32, device=device) * 0.01 24 | d_len = torch.tensor(d_len, dtype=torch.float32, device=device) * 0.1 25 | 26 | x = width * d_x 27 | y = height * d_y 28 | w = width * d_len 29 | h = height * d_len 30 | 31 | crops = torch.stack([x, y, w, h]).int().t_().tolist() 32 | return crops 33 | 34 | 35 | def generate_crops_v2(width, height, device=None): 36 | if device is None: 37 | device = torch.device('cpu') 38 | 39 | d_x = [0, 0, 0, 0, 0, 10, 10, 10, 10, 10, 20, 20, 20, 20, 20, 30, 30, 30, 30, 30, 40, 40, 40, 40, 40, 40 | 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 16, 16, 16, 16, 16, 24, 24, 24, 24, 24, 32, 32, 32, 32, 32, 41 | 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 24, 24, 24, 24, 24, 42 | 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 8, 8, 8, 8, 8, 12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 43 | 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 6, 6, 6, 6, 6, 8, 8, 8, 8, 8, ] 44 | d_y = [0, 10, 20, 30, 40, 0, 10, 20, 30, 40, 0, 10, 20, 30, 40, 0, 10, 20, 30, 40, 0, 10, 20, 30, 40, 45 | 0, 8, 16, 24, 32, 0, 8, 16, 24, 32, 0, 8, 16, 24, 32, 0, 8, 16, 24, 32, 0, 8, 16, 24, 32, 46 | 0, 6, 12, 18, 24, 0, 6, 12, 18, 24, 0, 6, 12, 18, 24, 0, 6, 12, 18, 24, 0, 6, 12, 18, 24, 47 | 0, 4, 8, 12, 16, 0, 4, 8, 12, 16, 0, 4, 8, 12, 16, 0, 4, 8, 12, 16, 0, 4, 8, 12, 16, 48 | 0, 2, 4, 6, 8, 0, 2, 4, 6, 8, 0, 2, 4, 6, 8, 0, 2, 4, 6, 8, 0, 2, 4, 6, 8, ] 49 | d_len = [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 50 | 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 51 | 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 52 | 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 53 | 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ] 54 | 55 | d_x = torch.tensor(d_x, dtype=torch.float32, device=device) * 0.01 56 | d_y = torch.tensor(d_y, dtype=torch.float32, device=device) * 0.01 57 | d_len = torch.tensor(d_len, dtype=torch.float32, device=device) * 0.1 58 | 59 | x = width * d_x 60 | y = height * d_y 61 | w = width * d_len 62 | h = height * d_len 63 | 64 | crops = torch.stack([x, y, w, h]).int().t_().tolist() 65 | return crops 66 | 67 | 68 | def generate_crops_v3(width, height): 69 | d_x = [0, 0, 0, 0, 0, 10, 10, 10, 10, 10, 20, 20, 20, 20, 20, 30, 30, 30, 30, 30, 40, 40, 40, 40, 40, 70 | 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 16, 16, 16, 16, 16, 24, 24, 24, 24, 24, 32, 32, 32, 32, 32, 71 | 0, 0, 0, 0, 0, 6, 6, 6, 6, 6, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 24, 24, 24, 24, 24, 72 | 0, 0, 0, 0, 0, 4, 4, 4, 4, 4, 8, 8, 8, 8, 8, 12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 73 | 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 6, 6, 6, 6, 6, 8, 8, 8, 8, 8, ] 74 | d_y = [0, 10, 20, 30, 40, 0, 10, 20, 30, 40, 0, 10, 20, 30, 40, 0, 10, 20, 30, 40, 0, 10, 20, 30, 40, 75 | 0, 8, 16, 24, 32, 0, 8, 16, 24, 32, 0, 8, 16, 24, 32, 0, 8, 16, 24, 32, 0, 8, 16, 24, 32, 76 | 0, 6, 12, 18, 24, 0, 6, 12, 18, 24, 0, 6, 12, 18, 24, 0, 6, 12, 
18, 24, 0, 6, 12, 18, 24, 77 | 0, 4, 8, 12, 16, 0, 4, 8, 12, 16, 0, 4, 8, 12, 16, 0, 4, 8, 12, 16, 0, 4, 8, 12, 16, 78 | 0, 2, 4, 6, 8, 0, 2, 4, 6, 8, 0, 2, 4, 6, 8, 0, 2, 4, 6, 8, 0, 2, 4, 6, 8, ] 79 | d_len = [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 80 | 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 81 | 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 82 | 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 83 | 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, ] 84 | x = list(map(lambda item: int(item * width * 0.01), d_x)) 85 | y = list(map(lambda item: int(item * height * 0.01), d_y)) 86 | w = list(map(lambda item: int(item * width * 0.1), d_len)) 87 | h = list(map(lambda item: int(item * height * 0.1), d_len)) 88 | 89 | return list(zip(x, y, w, h)) 90 | 91 | 92 | class SimpleImageCropper(nn.Module): 93 | 94 | def __init__(self): 95 | super(SimpleImageCropper, self).__init__() 96 | self.fc1 = nn.Sequential( 97 | nn.Flatten(), 98 | nn.Linear(16 * 16 * 3, 1), 99 | ) 100 | 101 | def forward(self, x): 102 | x = x.view(-1, 16 * 16 * 3) 103 | return self.fc1(x) 104 | 105 | 106 | def test_evaluate_iou(evaluator_simple_net_on_FCDB): 107 | iou_expected = torch.tensor([0.4262]) 108 | iou_actual = evaluator_simple_net_on_FCDB.intersection_over_union 109 | assert torch.allclose(iou_expected, iou_actual, atol=1e-04) 110 | 111 | 112 | def test_evaluate_boundary_displacement(evaluator_simple_net_on_FCDB): 113 | boundary_displacement_expected = torch.tensor([0.1602]) 114 | boundary_displacement_actual = evaluator_simple_net_on_FCDB.boundary_displacement 115 | assert torch.allclose(boundary_displacement_expected, boundary_displacement_actual, atol=1e-04) 116 | 117 | 118 | def test_alpha_recall(evaluator_simple_net_on_FCDB): 119 | alpha_recall_expected = torch.tensor([3.7681]) 120 | alpha_recall_actual = evaluator_simple_net_on_FCDB.alpha_recall 121 | assert torch.allclose(alpha_recall_expected, alpha_recall_actual, atol=1e-04) 122 | 123 | 124 | def test_all_metrics(evaluator_simple_net_on_FCDB): 125 | iou_expected = torch.tensor([0.4262]) 126 | iou_actual = evaluator_simple_net_on_FCDB.intersection_over_union 127 | 128 | boundary_displacement_expected = torch.tensor([0.1602]) 129 | boundary_displacement_actual = evaluator_simple_net_on_FCDB.boundary_displacement 130 | 131 | alpha_recall_expected = torch.tensor([3.7681]) 132 | alpha_recall_actual = evaluator_simple_net_on_FCDB.alpha_recall 133 | 134 | assert torch.allclose(iou_expected, iou_actual, atol=1e-04) 135 | assert torch.allclose(boundary_displacement_expected, boundary_displacement_actual, atol=1e-04) 136 | assert torch.allclose(alpha_recall_expected, alpha_recall_actual, atol=1e-04) 137 | -------------------------------------------------------------------------------- /tools/create_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import cv2 5 | import argparse 6 | import shutil 7 | import multiprocessing 8 | import yaml 9 | from gulpio.fileio import GulpIngestor 10 | from gulpio.dataset import GulpDirectory 11 | from gulpio.loader import DataLoader 12 | from vfn.data.dataset import ImagePairListAdapter, ImagePairVisDataset 13 | from vfn.data.FlickrPro import FlickrPro 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser( 18 | description='Create image pair data for training 
a view finding network using GulpIO.')
19 |     parser.add_argument('-w', '--workers', type=int, default=-1,
20 |                         help="num of workers. -x uses (all - x) cores [-1 default].")
21 |     parser.add_argument('-c', '--config', type=str, default='../configs/dataset.yml',
22 |                         help='configuration file')
23 |     parser.add_argument('-N', '--name', type=str, default='FlickrPro',
24 |                         help='Dataset name [default: "FlickrPro"]')
25 |     parser.add_argument('-r', '--root_folder', type=str,
26 |                         help='root folder of GulpIO data')
27 |     parser.add_argument('-n', '--images_per_chunk', type=int, default=2048,
28 |                         help='number of images in one chunk [default: 2048]')
29 |     parser.add_argument('-S', '--image_size', type=int, default=-1,
30 |                         help='size of smaller edge of resized frames [default: -1 (no resizing)]')
31 |     parser.add_argument('-s', '--shuffle', action='store_true',
32 |                         help='Shuffle the dataset before ingestion [default: False]')
33 |     parser.add_argument('-v', '--viz', action='store_true',
34 |                         help='Visualize the dataset [default: False]')
35 |     args = parser.parse_args()
36 | 
37 |     return args
38 | 
39 | 
40 | def parse_config(config_path, name):
41 |     with open(config_path, 'r') as f:
42 |         config = yaml.load(f, Loader=yaml.BaseLoader)
43 | 
44 |     if name not in config:
45 |         raise Exception('Unrecognized dataset {}'.format(name))
46 | 
47 |     return config
48 | 
49 | 
50 | def check_existing_dataset(data_path):
51 |     gd = GulpDirectory(data_path)
52 |     if gd.num_chunks > 0:
53 |         print("Found existing dataset containing {} chunks in {}.".format(gd.num_chunks, data_path))
54 |         print("Erase the existing dataset and create a new one? (y/n)")
55 |         action = input()
56 |         while action not in ['y', 'n']:
57 |             action = input("Enter (y/n):")
58 |         if action == 'y':
59 |             shutil.rmtree(data_path)
60 |         elif action == 'n':
61 |             print("Exiting. Nothing changed.")
62 | sys.exit() 63 | 64 | 65 | if __name__ == "__main__": 66 | args = parse_args() 67 | num_workers = args.workers 68 | 69 | if num_workers < 0: 70 | num_workers = multiprocessing.cpu_count() + num_workers 71 | 72 | config = parse_config(args.config, args.name) 73 | 74 | # image_list_file = args.image_list 75 | images_per_chunk = args.images_per_chunk 76 | img_size = args.image_size 77 | viz = args.viz 78 | 79 | if viz: 80 | dataset = ImagePairVisDataset(args.root_folder) 81 | loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=False) 82 | print("data size:", len(dataset)) 83 | 84 | for img_full, img_crop in loader: 85 | cv2.imshow("Image Full", img_full[0]) 86 | cv2.imshow("Image Crop", img_crop[0]) 87 | key = cv2.waitKey() 88 | if key == 27: 89 | break 90 | else: 91 | for subset in ['train', 'val']: 92 | check_existing_dataset(config[args.name][subset]['gulpio_dir']) 93 | 94 | adapter = ImagePairListAdapter( 95 | FlickrPro( 96 | config[args.name][subset]['root_dir'], 97 | config[args.name][subset]['meta'], 98 | config[args.name][subset]['download'] 99 | ), 100 | # args.shuffle 101 | ) 102 | 103 | ingestor = GulpIngestor(adapter, config[args.name][subset]['gulpio_dir'], images_per_chunk, num_workers) 104 | ingestor() # call to trigger ingestion 105 | 106 | -------------------------------------------------------------------------------- /tools/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torchvision.transforms import transforms 4 | 5 | from vfn.config.parser import ConfigParser 6 | from vfn.data.evaluation import ImageCropperEvaluator 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('-c', '--config', type=str, help='Path to config file (.yml)', default='../configs/example.yml') 12 | args = parser.parse_args() 13 | 14 | configs = ConfigParser(args.config) 15 | 16 | testsets = [ 17 | configs.parse_FCDB(), 18 | configs.parse_ICDB(subset_selector=1), 19 | configs.parse_ICDB(subset_selector=2), 20 | configs.parse_ICDB(subset_selector=3), 21 | ] 22 | device = configs.parse_device() 23 | model = configs.parse_model().to(device) 24 | weight = torch.load(configs.configs['weight'], map_location=lambda storage, loc: storage) 25 | model.load_state_dict(weight) 26 | data_transforms = transforms.Compose([ 27 | transforms.ToPILImage(), 28 | transforms.Resize((224, 224)), 29 | transforms.ToTensor(), 30 | ]) 31 | 32 | for testset in testsets: 33 | evaluator = ImageCropperEvaluator(model, testset, device, data_transforms) 34 | print('Evaluate on {}'.format(testset)) 35 | print('Average overlap ratio: {:.4f}'.format(evaluator.intersection_over_union)) 36 | print('Average boundary displacement: {:.4f}'.format(evaluator.boundary_displacement)) 37 | print('Alpha recall: {:.4f}'.format(evaluator.alpha_recall)) 38 | print('Total image evaluated: {}'.format(evaluator.num_evaluated_images)) 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import torch.cuda 5 | 6 | from ignite.engine import Events, Engine 7 | from ignite.handlers import ModelCheckpoint 8 | from tqdm import tqdm 9 | from visdom import Visdom 10 | 11 | from vfn.config.parser import ConfigParser 12 | from vfn.utils.visualization import 
plot_bbox 13 | 14 | 15 | class Trainer(object): 16 | 17 | def __init__(self, configs): 18 | self.configs = configs 19 | self.viz = self.configs.configs['train']['viz'] 20 | self._init_config_settings() 21 | self._init_logger() 22 | self._init_trainer() 23 | self._init_validator() 24 | 25 | def _init_config_settings(self): 26 | self.device = self.configs.parse_device() 27 | self.num_epochs = self.configs.configs['train']['num_epochs'] 28 | self.model_name = self.configs.get_model_name() 29 | self.model = self.configs.parse_model().to(self.device) 30 | self.data_loaders = self.configs.parse_dataloader() 31 | self.optimizer = self.configs.parse_optimizer() 32 | self.optimizer = self.optimizer(self.model.parameters()) 33 | self.loss_fn = self.configs.parse_loss_function() 34 | 35 | def _init_logger(self): 36 | self.desc = 'Loss: {:.6f}' 37 | self.pbar = tqdm( 38 | initial=0, 39 | leave=False, 40 | total=len(self.data_loaders['train']), 41 | desc=self.desc.format(0), 42 | ascii=True, 43 | ) 44 | self.log_interval = 1 45 | if self.viz: 46 | self.vis = Visdom() 47 | 48 | def _init_trainer(self): 49 | self.trainer = Engine(self._inference) 50 | self.trainer.add_event_handler(Events.EPOCH_STARTED, self._set_model_stage, is_train=True) 51 | self.trainer.add_event_handler(Events.ITERATION_COMPLETED, self._log_iteration, is_train=True) 52 | self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self._log_epoch, is_train=True) 53 | self.trainer.add_event_handler(Events.EPOCH_COMPLETED, self._run_validation) 54 | ckpt_handler = ModelCheckpoint( 55 | dirname=self.configs.configs['checkpoint']['root_dir'], 56 | filename_prefix=self.configs.configs['checkpoint']['prefix'], 57 | save_interval=1, 58 | n_saved=self.num_epochs, 59 | require_empty=False, 60 | ) 61 | self.trainer.add_event_handler(Events.EPOCH_COMPLETED, ckpt_handler, {self.model_name: self.model}) 62 | 63 | def _init_validator(self): 64 | self.validator = Engine(self._inference) 65 | self.validator.add_event_handler(Events.EPOCH_STARTED, self._set_model_stage) 66 | self.validator.add_event_handler(Events.ITERATION_COMPLETED, self._log_iteration) 67 | self.validator.add_event_handler(Events.EPOCH_COMPLETED, self._log_epoch) 68 | 69 | def _set_model_stage(self, engine, is_train=False): 70 | self.model.train(is_train) 71 | torch.set_grad_enabled(is_train) 72 | engine.state.is_train = is_train 73 | engine.state.cum_average_loss = 0 74 | 75 | def _inference(self, engine, batch): 76 | # fetch inputs and transfer to specific device 77 | image_raw, image_crop = batch 78 | image_raw, image_crop = image_raw.to(self.device), image_crop.to(self.device) 79 | 80 | # forward 81 | score_I = self.model(image_raw) 82 | score_C = self.model(image_crop) 83 | loss = self.loss_fn(score_I, score_C) 84 | 85 | # backward 86 | if engine.state.is_train: 87 | self.optimizer.zero_grad() 88 | loss.backward() 89 | self.optimizer.step() 90 | 91 | engine.state.iteration_loss = loss.mean().item() 92 | engine.state.cum_average_loss += loss.mean().item() 93 | 94 | def _log_iteration(self, engine, is_train=False): 95 | average_loss = engine.state.cum_average_loss / engine.state.iteration 96 | self.pbar.desc = self.desc.format(average_loss) 97 | self.pbar.update(self.log_interval) 98 | 99 | if is_train and self.viz: 100 | self.vis.line( 101 | Y=np.array([self.trainer.state.iteration_loss]), 102 | X=np.array([self.trainer.state.iteration]), 103 | win='loss-iteration', 104 | env=self.configs.configs['checkpoint']['prefix'], 105 | update='append', 106 | name='train', 107 | opts=dict( 
108 | title='Learning Curve', 109 | showlegend=True, 110 | xlabel='Iteration', 111 | ylabel='Loss', 112 | ) 113 | ) 114 | 115 | def _log_epoch(self, engine, is_train=False): 116 | stage = 'Training' if is_train else 'Validation' 117 | self.pbar.refresh() 118 | if is_train: 119 | tqdm.write('Epoch {}:'.format(self.trainer.state.epoch)) 120 | 121 | average_loss = engine.state.cum_average_loss / engine.state.iteration 122 | tqdm.write('{} Loss: {:.6f}'.format(stage, average_loss)) 123 | self.pbar.n = self.pbar.last_print_n = 0 124 | if self.viz: 125 | self.vis.line( 126 | Y=np.array([average_loss]), 127 | X=np.array([self.trainer.state.epoch]), 128 | win='loss-epoch', 129 | env=self.configs.configs['checkpoint']['prefix'], 130 | update='append', 131 | name=stage, 132 | opts=dict( 133 | title='Learning Curve', 134 | showlegend=True, 135 | xlabel='Epoch', 136 | ylabel='Loss', 137 | ) 138 | ) 139 | 140 | def _run_validation(self, engine): 141 | self.validator.run(self.data_loaders['val']) 142 | 143 | def run(self): 144 | self.trainer.run(self.data_loaders['train'], self.num_epochs) 145 | 146 | 147 | def main(): 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument('-c', '--config', type=str, help='Path to config file (.yml)', default='../configs/example.yml') 150 | args = parser.parse_args() 151 | 152 | configs = ConfigParser(args.config) 153 | trainer = Trainer(configs) 154 | trainer.run() 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /vfn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aquastripe/pytorch-view-finding-network/5518f3ec80d9c9169bc1ac54223cb4d2ea372cae/vfn/__init__.py -------------------------------------------------------------------------------- /vfn/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aquastripe/pytorch-view-finding-network/5518f3ec80d9c9169bc1ac54223cb4d2ea372cae/vfn/config/__init__.py -------------------------------------------------------------------------------- /vfn/config/parser.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.optim as optim 5 | import yaml 6 | 7 | from torch.utils.data import DataLoader 8 | from torchvision import transforms 9 | 10 | import vfn.network.backbones as backbones 11 | import vfn.network.losses as losses 12 | import vfn.network.models as models 13 | 14 | from vfn.data.FCDB import FCDB 15 | from vfn.data.ICDB import ICDB 16 | from vfn.data.dataset import ImagePairDataset 17 | 18 | 19 | class ConfigParser: 20 | 21 | def __init__(self, config_file): 22 | self._load(config_file) 23 | self._init_name() 24 | self.input_dim = 0 25 | 26 | def _init_name(self): 27 | # All `*_name` are disposable because the `pop()` operation will modify the origin `configs`. 
28 | self.backbone_name = self.configs['model']['backbone'].pop('name') 29 | self.optimizer_name = self.configs['train']['optimizer'].pop('name') 30 | self.loss_name = self.configs['train']['loss'].pop('name') 31 | self.dataset_name = self.configs['train']['dataset'].pop('name') 32 | 33 | def _load(self, config_file): 34 | with open(config_file, 'r') as f: 35 | self.configs = yaml.load(f, Loader=yaml.FullLoader) 36 | 37 | def get_model_name(self): 38 | return self.backbone_name 39 | 40 | def parse_model(self): 41 | backbone_model = None 42 | if self.backbone_name == 'AlexNet': 43 | backbone_model = backbones.AlexNet 44 | elif self.backbone_name == 'VGG': 45 | backbone_model = backbones.VGG 46 | 47 | backbone = backbone_model(**self.configs['model']['backbone']) 48 | model = models.ViewFindingNet(backbone=backbone) 49 | self.input_dim = backbone.input_dim() 50 | return model 51 | 52 | def parse_optimizer(self): 53 | optimizer_fn = None 54 | 55 | if self.optimizer_name == 'SGD': 56 | optimizer_fn = optim.SGD 57 | elif self.optimizer_name == 'Adam': 58 | optimizer_fn = optim.Adam 59 | 60 | optimizer = partial(optimizer_fn, **self.configs['train']['optimizer']) 61 | return optimizer 62 | 63 | def parse_loss_function(self): 64 | loss_fn = None 65 | 66 | if self.loss_name == 'hinge': 67 | loss_fn = losses.hinge_loss 68 | elif self.loss_name == 'ranknet': 69 | loss_fn = losses.ranknet_loss 70 | 71 | return loss_fn 72 | 73 | def parse_dataloader(self): 74 | 75 | # build data augmentation transforms 76 | data_transform = transforms.Compose([ 77 | transforms.ToPILImage(), 78 | transforms.Resize((self.input_dim, self.input_dim)), 79 | transforms.RandomHorizontalFlip(), 80 | transforms.ColorJitter(brightness=0.01, contrast=0.05), 81 | transforms.ToTensor(), 82 | ]) 83 | 84 | train_dataset = ImagePairDataset(self.configs['train']['dataset']['gulpio_dir'], 85 | data_transform) 86 | val_dataset = ImagePairDataset(self.configs['validation']['dataset']['gulpio_dir'], 87 | data_transform) 88 | 89 | print('train_size:', len(train_dataset)) 90 | print('val_size:', len(val_dataset)) 91 | 92 | data_loaders = dict( 93 | train=DataLoader(train_dataset, num_workers=8, **self.configs['train']['dataloader']), 94 | val=DataLoader(val_dataset, num_workers=8, **self.configs['validation']['dataloader']), 95 | ) 96 | return data_loaders 97 | 98 | def parse_device(self): 99 | return torch.device(self.configs['device']) 100 | 101 | def parse_FCDB(self): 102 | return FCDB(**self.configs['evaluate']['FCDB']) 103 | 104 | def parse_ICDB(self, subset_selector): 105 | return ICDB(subset=subset_selector, **self.configs['evaluate']['ICDB']) 106 | -------------------------------------------------------------------------------- /vfn/data/FCDB.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import os 5 | import cv2 6 | import json 7 | import shutil 8 | from tqdm import trange 9 | from torch.utils.data import Dataset 10 | from vfn.data.ioutils import download 11 | from vfn.data.image_downloader import ImageDownloader 12 | from vfn.data.evaluation import ImageCropperEvaluator 13 | 14 | 15 | class FCDB(Dataset): 16 | __meta_name = 'FCDB-%s.json' 17 | 18 | def __init__(self, root_dir, subset='testing', download=True): 19 | super(FCDB, self).__init__() 20 | assert subset in ['training', 'testing', 'all'], 'Unknown subset {}' % subset 21 | 22 | self.root_dir = root_dir 23 | self.subset = subset 24 | 
self.meta_file = os.path.join(root_dir, self.__meta_name % subset) 25 | 26 | if download: 27 | self._download(root_dir) 28 | 29 | self.img_list, self.img_sizes, self.annotations = self._fetch_metadata() 30 | 31 | def get_all_items(self): 32 | return self.img_list, self.img_sizes, self.annotations 33 | 34 | def __len__(self): 35 | return len(self.img_list) 36 | 37 | def __getitem__(self, index): 38 | return self.img_list[index], self.img_sizes[index], self.annotations[index] 39 | 40 | def __str__(self): 41 | return 'FCDB dataset' 42 | 43 | def _download(self, root_dir): 44 | if not os.path.isdir(root_dir): 45 | os.makedirs(root_dir) 46 | 47 | # download annotation file to root_dir 48 | url_fmt = \ 49 | 'https://raw.githubusercontent.com/yiling-chen/flickr-cropping-dataset/master/cropping_%s_set.json' 50 | if not os.path.exists(self.meta_file): 51 | print('Downloading FCDB annotation file...') 52 | if self.subset in ['training', 'testing']: 53 | anno_url = url_fmt % self.subset 54 | download(anno_url, os.path.join(self.root_dir, os.path.basename(anno_url))) 55 | shutil.move(os.path.join(self.root_dir, os.path.basename(anno_url)), self.meta_file) 56 | print() 57 | elif self.subset == 'all': 58 | # download both the training and testing sets 59 | anno_url = url_fmt % 'training' 60 | train_path = os.path.join(self.root_dir, os.path.basename(anno_url)) 61 | download(anno_url, train_path) 62 | print() 63 | anno_url = url_fmt % 'testing' 64 | test_path = os.path.join(self.root_dir, os.path.basename(anno_url)) 65 | download(anno_url, test_path) 66 | print() 67 | # Merge training and testing sets 68 | merge_dataset = json.load(open(train_path, 'r')) + json.load(open(test_path, 'r')) 69 | json.dump(merge_dataset, open(self.meta_file, 'w')) 70 | os.remove(train_path) 71 | os.remove(test_path) 72 | 73 | # Collect URLs and pass to ImageDownloader 74 | db = json.load(open(self.meta_file, 'r')) 75 | img_urls = [x['url'] for x in db] 76 | ImageDownloader.download(root_dir, img_urls) 77 | 78 | def _fetch_metadata(self): 79 | assert os.path.isfile(self.meta_file), "Metadata does not exist! Please download the FCDB dataset first!" 
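        # Each record in the downloaded JSON metadata is expected to carry at least a
        # photo 'url' (used to derive the local filename) and a ground-truth 'crop'
        # given as [x, y, w, h], which is all this method consumes below.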
80 | 81 | print('Reading metadata...') 82 | db = json.load(open(self.meta_file, 'r')) 83 | img_list, img_sizes, annotations = [], [], [] 84 | for i in trange(len(db)): 85 | # Some images might not be available on Flickr anymore, skip them 86 | img_path = os.path.join(self.root_dir, os.path.basename(db[i]['url'])) 87 | if not os.path.exists(img_path): 88 | continue 89 | img_list.append(img_path) 90 | annotations.append(db[i]['crop']) 91 | height, width = cv2.imread(img_path).shape[:2] 92 | img_sizes.append((width, height)) 93 | print('Unpacked', len(img_list), 'records.') 94 | 95 | return img_list, img_sizes, annotations 96 | 97 | 98 | def main(): 99 | db = FCDB("/mnt/Data-2/Projects/faster-view-finding-network/FCDB", subset='all') 100 | _, img_sizes, ground_truth = db.get_all_items() 101 | 102 | evaluator = ImageCropperEvaluator() 103 | # evaluate ground truth, this should get perfect results 104 | evaluator.evaluate(ground_truth, ground_truth, img_sizes) 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | -------------------------------------------------------------------------------- /vfn/data/FlickrPro.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import os 5 | import pickle 6 | 7 | from torch.utils.data import Dataset 8 | from tqdm import trange 9 | from vfn.data.image_downloader import ImageDownloader 10 | 11 | 12 | class FlickrPro(Dataset): 13 | 14 | def __init__(self, root_dir, meta_file, download=False): 15 | super(FlickrPro, self).__init__() 16 | self.root_dir = root_dir 17 | self.meta_file = meta_file 18 | self._fetch_metadata() 19 | 20 | if download: 21 | self._download_images() 22 | 23 | def __len__(self): 24 | return len(self.filenames) 25 | 26 | def __getitem__(self, i): 27 | return self.filenames[i], self.annotations[i] 28 | 29 | def get_all_items(self): 30 | return self.filenames, self.annotations, self.urls 31 | 32 | def _download_images(self): 33 | if not os.path.isdir(self.root_dir): 34 | os.makedirs(self.root_dir) 35 | 36 | print('Downloading FlickrPro images...') 37 | ImageDownloader.download(self.root_dir, self.urls) 38 | print('Done') 39 | 40 | def _fetch_metadata(self): 41 | assert os.path.isfile(self.meta_file), "Metadata does not exist! Please download the FlickrPro dataset first!" 
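        # The pickle stores 14 consecutive records per source photo (one per candidate
        # crop); each record carries the photo 'url' and a 'crop' in [x, y, w, h] form,
        # which is why the loop below walks the records in blocks of 14.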
42 | 43 | print('Reading metadata...') 44 | with open(self.meta_file, 'rb') as f: 45 | db = pickle.load(f) 46 | 47 | self.filenames = [] 48 | self.annotations = [] 49 | self.urls = [] 50 | 51 | for i in trange(len(db) // 14): 52 | url = db[i*14]['url'] 53 | self.urls.append(url) 54 | 55 | filename = os.path.join(self.root_dir, os.path.basename(url)) 56 | 57 | for j in range(14): 58 | self.filenames.append(filename) 59 | self.annotations.append(db[i*14 + j]['crop']) 60 | 61 | # print(len(self.filenames), len(self.urls), len(self.annotations)) 62 | print('Unpacked', len(db), 'records.') 63 | 64 | 65 | if __name__ == "__main__": 66 | print(os.getcwd()) 67 | flickr_pro = FlickrPro(root_dir="../../datasets/FlickrPro/src", 68 | meta_file="../../data/flickr_pro_train.pkl", 69 | download=False) 70 | print(flickr_pro[0]) 71 | -------------------------------------------------------------------------------- /vfn/data/ICDB.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import os 5 | import cv2 6 | from tqdm import trange 7 | from collections import defaultdict 8 | from torch.utils.data import Dataset 9 | from vfn.data.ioutils import download, extract 10 | from vfn.data.evaluation import ImageCropperEvaluator 11 | 12 | 13 | class ICDB(Dataset): 14 | __meta_name = 'cuhk_cropping.zip' 15 | 16 | def __init__(self, root_dir, subset=1, download=True): 17 | super(ICDB, self).__init__() 18 | # CUHK dataset contains three annotation subsets. 19 | assert subset in [1, 2, 3], 'Unknown subset: %d (Valid value: 1, 2, 3)' % subset 20 | 21 | self.subset = subset 22 | self.root_dir = root_dir 23 | self.image_dir = os.path.join(root_dir, 'All_Images') 24 | self.meta_file = os.path.join(root_dir, self.__meta_name) 25 | 26 | if download: 27 | self._download(root_dir) 28 | 29 | self.img_list, self.img_sizes, self.crops, self.category = self._fetch_metadata() 30 | self.img_groups, self.crop_groups, self. size_groups = self._group_metadata() 31 | 32 | def __len__(self): 33 | return len(self.img_list) 34 | 35 | def __getitem__(self, index): 36 | return self.img_list[index], self.img_sizes[index], self.crops[index], self.category[index] 37 | 38 | def __str__(self): 39 | return 'ICDB dataset (Subset-{})'.format(self.subset) 40 | 41 | def get_metadata_by_group(self, label): 42 | assert label in self.img_groups.keys(), 'Unknown category %s' % label 43 | return self.img_groups[label], self.crop_groups[label], self.size_groups[label] 44 | 45 | def get_categories(self): 46 | return self.img_groups.keys() 47 | 48 | def _download(self, root_dir): 49 | if not os.path.isdir(root_dir): 50 | os.makedirs(root_dir) 51 | 52 | # download annotation file to root_dir 53 | if not os.path.exists(self.meta_file): 54 | print('Downloading CUHK ICDB dataset...') 55 | anno_url = \ 56 | 'http://personal.ie.cuhk.edu.hk/~ccloy/files/datasets/cuhk_cropping.zip' 57 | download(anno_url, self.meta_file) 58 | 59 | if not os.path.isdir(self.image_dir): 60 | if not os.path.isfile(os.path.join(self.root_dir, 'All_Images.zip')): 61 | print('\nExtracting dataset...') 62 | extract(self.meta_file, root_dir) 63 | 64 | print('Extracting images...') 65 | extract(os.path.join(self.root_dir, 'All_Images.zip'), self.root_dir) 66 | 67 | def _fetch_metadata(self): 68 | annotation_file = os.path.join(self.root_dir, 'Cropping parameters.txt') 69 | assert os.path.exists(annotation_file), 'Parameter file does not exist!' 
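        # 'Cropping parameters.txt' groups four lines per image: a line with the image's
        # category and filename (separated by a backslash), followed by one crop per
        # annotator given as (y1, y2, x1, x2); `subset` selects which annotator's crop
        # is converted to (x, y, w, h) below.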
70 | 71 | with open(annotation_file, 'r') as f: 72 | lines = f.readlines() 73 | 74 | print('Reading metadata...') 75 | num_images = round(len(lines) / 4) 76 | img_list, img_sizes, annotations, category = [], [], [], [] 77 | for i in trange(num_images): 78 | label, filename = lines[i*4].strip().split('\\') 79 | crop = [int(x) for x in lines[i*4 + self.subset].split(' ')] 80 | # convert from (y1, y2, x1, x2) to (x, y, w, h) format 81 | annotations.append([crop[2], crop[0], crop[3] - crop[2], crop[1] - crop[0]]) 82 | img_path = os.path.join(self.image_dir, filename) 83 | img_list.append(img_path) 84 | height, width = cv2.imread(os.path.join(self.image_dir, filename)).shape[:2] 85 | img_sizes.append((width, height)) 86 | category.append(label) 87 | 88 | return img_list, img_sizes, annotations, category 89 | 90 | def _group_metadata(self): 91 | img_groups, crop_groups, size_groups = defaultdict(list), defaultdict(list), defaultdict(list) 92 | for img, size, crop, label in zip(self.img_list, self.img_sizes, self.crops, self.category): 93 | img_groups[label].append(img) 94 | crop_groups[label].append(crop) 95 | size_groups[label].append(size) 96 | return img_groups, crop_groups, size_groups 97 | 98 | 99 | def main(): 100 | db = ICDB("../../../ICDB") 101 | print(db[0]) 102 | 103 | _, crops, sizes = db.get_metadata_by_group('animal') 104 | print(db.get_categories()) 105 | 106 | evaluator = ImageCropperEvaluator() 107 | # evaluate ground truth, this should get perfect results 108 | evaluator.evaluate(crops, crops, sizes) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /vfn/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aquastripe/pytorch-view-finding-network/5518f3ec80d9c9169bc1ac54223cb4d2ea372cae/vfn/data/__init__.py -------------------------------------------------------------------------------- /vfn/data/dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | from gulpio.adapters import AbstractDatasetAdapter 7 | from gulpio.utils import ImageNotFound 8 | from gulpio.dataset import GulpDirectory, GulpIOEmptyFolder 9 | 10 | 11 | class ImagePairListAdapter(AbstractDatasetAdapter): 12 | def __init__(self, src_dataset, shuffle=True): 13 | self.source_dataset = src_dataset 14 | self.data = self._parse_dataset() 15 | if shuffle: 16 | random.shuffle(self.data) 17 | 18 | def __len__(self): 19 | return len(self.source_dataset) 20 | 21 | def _parse_dataset(self): 22 | filenames, annotations, _ = self.source_dataset.get_all_items() 23 | data = [] 24 | for i, (filename, annotation) in enumerate(zip(filenames, annotations)): 25 | data.append({'id': i, 26 | 'filename': filename, 27 | 'annotation': annotation}) 28 | return data 29 | 30 | def iter_data(self, slice_element=None): 31 | slice_element = slice_element or slice(0, len(self)) 32 | slice_data = self.data[slice_element] 33 | for item in slice_data: 34 | try: 35 | image = cv2.imread(item['filename']) 36 | if len(image.shape) == 2: 37 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) 38 | img_full = np.copy(image) 39 | x, y, w, h = item['annotation'] 40 | img_crop = np.copy(image[y:y+h, x:x+w, :]) 41 | if 0 in img_crop.shape: 42 | print(img_full.shape) 43 | print(item['annotation']) 44 | continue 45 | except ImageNotFound: 46 | 
print("Failed to read {}!".format(item['filename'])) 47 | continue # skip the item if image is not readable 48 | # Always encapsulate your data into a dict with (id, meta, frames) keys 49 | # which will be processed by gulpio ChunkWriter 50 | result = {'meta': item, 51 | 'frames': [img_full, img_crop], 52 | 'id': item['id']} 53 | yield result 54 | 55 | 56 | class ImagePairDataset(Dataset): 57 | def __init__(self, data_path, transforms): 58 | """Simple image pair data loader for GulpIO format. 59 | Args: 60 | data_path (str): path to GulpIO dataset folder 61 | is_va (bool): sets the necessary augmention procedure. 62 | transform (object): set of augmentation steps defined by 63 | Compose(). Default is None. 64 | target_transform (func): performs preprocessing on labels if 65 | defined. Default is None. 66 | """ 67 | self.gd = GulpDirectory(data_path) 68 | self.num_chunks = self.gd.num_chunks 69 | self.transforms = transforms 70 | if self.num_chunks == 0: 71 | raise (GulpIOEmptyFolder("Found 0 data binaries in subfolders " + 72 | "of: ".format(data_path))) 73 | 74 | self.data_path = data_path 75 | print("Found {} chunks in {}".format(self.num_chunks, self.data_path)) 76 | self.items = list(self.gd.merged_meta_dict.items()) 77 | 78 | def __getitem__(self, index): 79 | """ 80 | With the given index, it fetches frames. This function is called 81 | by PyTorch DataLoader threads. Each Dataloader thread loads a single 82 | batch by calling this function per instance. 83 | """ 84 | item_id, item_info = self.items[index] 85 | img, meta = self.gd[item_id] 86 | img_full, img_crop = img 87 | 88 | img_full = img_full[..., [2, 1, 0]] 89 | img_crop = img_crop[..., [2, 1, 0]] 90 | 91 | if self.transforms: 92 | img_full = self.transforms(img_full) 93 | img_crop = self.transforms(img_crop) 94 | 95 | return img_full, img_crop 96 | 97 | def __len__(self): 98 | return len(self.items) 99 | 100 | 101 | class ImagePairVisDataset(Dataset): 102 | def __init__(self, data_path): 103 | """Simple image pair data loader for GulpIO format. 104 | Args: 105 | data_path (str): path to GulpIO dataset folder 106 | """ 107 | self.gd = GulpDirectory(data_path) 108 | self.num_chunks = self.gd.num_chunks 109 | if self.num_chunks == 0: 110 | raise (GulpIOEmptyFolder("Found 0 data binaries in subfolders " + 111 | "of: ".format(data_path))) 112 | 113 | self.data_path = data_path 114 | print("Found {} chunks in {}".format(self.num_chunks, self.data_path)) 115 | self.items = list(self.gd.merged_meta_dict.items()) 116 | 117 | def __getitem__(self, index): 118 | """ 119 | With the given index, it fetches frames. This function is called 120 | by PyTorch DataLoader threads. Each Dataloader thread loads a single 121 | batch by calling this function per instance. 
122 | """ 123 | item_id, item_info = self.items[index] 124 | img, meta = self.gd[item_id] 125 | img_full, img_crop = img 126 | 127 | # img_full = cv2.cvtColor(img_full, cv2.COLOR_BGR2RGB) 128 | # img_crop = cv2.cvtColor(img_crop, cv2.COLOR_BGR2RGB) 129 | img_full = img_full[..., [2, 1, 0]] 130 | img_crop = img_crop[..., [2, 1, 0]] 131 | 132 | return img_full, img_crop 133 | 134 | def __len__(self): 135 | return len(self.items) 136 | -------------------------------------------------------------------------------- /vfn/data/evaluation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torchvision.transforms import transforms as T 5 | from tqdm import tqdm 6 | 7 | 8 | 9 | class ImageCropperEvaluator(object): 10 | 11 | def __init__(self, model, dataset, device, transforms=None, generate_crops_fn=None): 12 | with torch.no_grad(): 13 | self._predict(model, dataset, device, transforms, generate_crops_fn) 14 | 15 | self._intersection_over_union_value = None 16 | self._mean_intersection_over_union_value = None 17 | self._mean_boundary_displacement_value = None 18 | self._alpha_recall_value = None 19 | 20 | @property 21 | def num_evaluated_images(self): 22 | return len(self._ground_truths) 23 | 24 | @property 25 | def intersection_over_union(self): 26 | if self._mean_intersection_over_union_value is None: 27 | self._init_intersection_over_union_value() 28 | 29 | return self._mean_intersection_over_union_value 30 | 31 | def _init_intersection_over_union_value(self): 32 | x1, y1, w1, h1 = self._ground_truths.t() 33 | x2, y2, w2, h2 = self._predictions.t() 34 | inter_w = torch.min(x1 + w1, x2 + w2) - torch.max(x1, x2) 35 | inter_w.clamp_min_(0.0) 36 | inter_h = torch.min(y1 + h1, y2 + h2) - torch.max(y1, y2) 37 | inter_h.clamp_min_(0.0) 38 | intersection = inter_w * inter_h 39 | union = (w1 * h1) + (w2 * h2) - intersection 40 | self._intersection_over_union_value = intersection / union 41 | self._mean_intersection_over_union_value = torch.mean(self._intersection_over_union_value).cpu() 42 | 43 | @property 44 | def boundary_displacement(self): 45 | if self._mean_boundary_displacement_value is None: 46 | w, h = self._image_sizes.t() 47 | 48 | x11, y11, w1, h1 = self._ground_truths.t() 49 | y12 = y11 + h1 50 | x12 = x11 + w1 51 | 52 | x21, y21, w2, h2 = self._predictions.t() 53 | y22 = y21 + h2 54 | x22 = x21 + w2 55 | 56 | x_displacement = (torch.abs(x11 - x21) + torch.abs(x12 - x22)) / w 57 | y_displacement = (torch.abs(y11 - y21) + torch.abs(y12 - y22)) / h 58 | self._mean_boundary_displacement_value = torch.mean((x_displacement + y_displacement) / 4).cpu() 59 | 60 | return self._mean_boundary_displacement_value 61 | 62 | @property 63 | def alpha_recall(self, alpha=0.75): 64 | if self._alpha_recall_value is None: 65 | if self._mean_intersection_over_union_value is None: 66 | self._init_intersection_over_union_value() 67 | 68 | self._alpha_recall_value = torch.mean(100 * (self._intersection_over_union_value > alpha).float()).cpu() 69 | 70 | return self._alpha_recall_value 71 | 72 | def _predict(self, model, dataset, device, transforms, generate_crops_fn): 73 | model.eval() 74 | model.to(device) 75 | if transforms is None: 76 | transforms = T.Compose([ 77 | T.ToPILImage(), 78 | T.Resize((224, 224)), 79 | T.ToTensor(), 80 | ]) 81 | if generate_crops_fn is None: 82 | generate_crops_fn = generate_crops 83 | 84 | image_sizes = [] 85 | predictions = [] 86 | ground_truths = [] 87 | for filename, image_size, 
ground_truth in tqdm(dataset): 88 | image_sizes.append(image_size) 89 | ground_truths.append(ground_truth) 90 | 91 | image = cv2.imread(filename) 92 | image = image[..., [2, 1, 0]] 93 | width, height = image_size 94 | crops = [ground_truth] + generate_crops_fn(width, height) 95 | crop_images = self._generate_crop_images(image, crops, transforms).to(device) 96 | scores = model(crop_images) 97 | idx = scores.argmax().item() 98 | 99 | predictions.append(crops[idx]) 100 | 101 | self._image_sizes = torch.tensor(image_sizes, dtype=torch.float32, device=device) 102 | self._predictions = torch.tensor(predictions, dtype=torch.float32, device=device) 103 | self._ground_truths = torch.tensor(ground_truths, dtype=torch.float32, device=device) 104 | 105 | @staticmethod 106 | def _generate_crop_images(image, crops, transforms): 107 | crop_images = [] 108 | for crop in crops: 109 | x, y, w, h = crop 110 | crop_image = np.copy(image[y:y + h, x:x + w, :]) 111 | crop_image = transforms(crop_image) 112 | crop_images.append(crop_image) 113 | 114 | return torch.stack(crop_images) 115 | 116 | 117 | def generate_crops(width, height): 118 | crops = [] 119 | 120 | for scale in range(5, 10): 121 | scale /= 10 122 | w, h = width * scale, height * scale 123 | dw, dh = width - w, height - h 124 | dw, dh = dw / 5, dh / 5 125 | 126 | for w_idx in range(5): 127 | for h_idx in range(5): 128 | x, y = w_idx * dw, h_idx * dh 129 | crops.append([int(x), int(y), int(w), int(h)]) 130 | 131 | return crops 132 | -------------------------------------------------------------------------------- /vfn/data/image_downloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import os 5 | from tqdm import trange 6 | from vfn.data.ioutils import download 7 | 8 | 9 | class ImageDownloader(object): 10 | def __init__(self): 11 | super(ImageDownloader, self).__init__() 12 | 13 | @staticmethod 14 | def download(root_dir, image_urls): 15 | print('Downloading', len(image_urls), 'images to', root_dir) 16 | for i in trange(len(image_urls), ascii=True): 17 | url = image_urls[i] 18 | filepath = os.path.join(root_dir, os.path.basename(url)) 19 | 20 | if not os.path.exists(filepath): 21 | download(url, filepath, verbose=False) 22 | -------------------------------------------------------------------------------- /vfn/data/ioutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | 3 | import time 4 | import os 5 | import shutil 6 | import zipfile 7 | import sys 8 | import urllib 9 | if sys.version_info[0] == 3: 10 | from urllib.request import urlretrieve 11 | else: 12 | from urllib import urlretrieve 13 | 14 | 15 | def download(url, filename, verbose=True): 16 | r"""Download file from the internet. 17 | 18 | Args: 19 | url (string): URL of the internet file. 20 | filename (string): Path to store the downloaded file. 21 | verbose (bool): Show download progress. 
22 |     """
23 |     try:
24 |         if verbose:
25 |             return urlretrieve(url, filename, _reporthook)
26 |         else:
27 |             return urlretrieve(url, filename)
28 |     except urllib.error.HTTPError as e:
29 |         print(e)
30 | 
31 | 
32 | def _reporthook(count, block_size, total_size):
33 |     global start_time
34 |     if count == 0:
35 |         start_time = time.time()
36 |         return
37 |     duration = time.time() - start_time
38 |     progress_size = int(count * block_size)
39 |     speed = int(progress_size / (1024 * duration))
40 |     percent = int(count * block_size * 100 / total_size)
41 |     sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
42 |                      (percent, progress_size / (1024 * 1024), speed, duration))
43 |     sys.stdout.flush()
44 | 
45 | 
46 | def extract(filename, extract_dir):
47 |     r"""Extract zip file.
48 | 
49 |     Args:
50 |         filename (string): Path of the zip file.
51 |         extract_dir (string): Directory to store the extracted results.
52 |     """
53 |     if os.path.splitext(filename)[1] == '.zip':
54 |         if not os.path.isdir(extract_dir):
55 |             os.makedirs(extract_dir)
56 |         with zipfile.ZipFile(filename) as z:
57 |             z.extractall(extract_dir)
58 |     else:
59 |         raise Exception('Unsupported extension {} of the compressed file {}.'.format(
60 |             os.path.splitext(filename)[1], filename))
61 | 
62 | 
63 | def compress(dirname, save_file):
64 |     """Compress a folder to a zip file.
65 | 
66 |     Arguments:
67 |         dirname {string} -- Directory of all files to be compressed.
68 |         save_file {string} -- Path to store the zip file.
69 |     """
70 |     shutil.make_archive(save_file, 'zip', dirname)
71 | 
--------------------------------------------------------------------------------
/vfn/network/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aquastripe/pytorch-view-finding-network/5518f3ec80d9c9169bc1ac54223cb4d2ea372cae/vfn/network/__init__.py
--------------------------------------------------------------------------------
/vfn/network/backbones.py:
--------------------------------------------------------------------------------
1 | # CNN backbone for vfn
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.models as models
5 | 
6 | 
7 | class Backbone(nn.Module):
8 |     def __init__(self):
9 |         super().__init__()
10 |         pass
11 | 
12 |     def forward(self, *input):
13 |         pass
14 | 
15 |     def input_dim(self):
16 |         pass
17 | 
18 |     def output_dim(self):
19 |         pass
20 | 
21 | 
22 | # Example:
23 | # backbone = backbones.AlexNet(pretrained)
24 | # vfn = ViewFindingNet(backbone)
25 | # ...
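# (Sketch based on vfn/network/models.py: `scores = vfn(images)` on an
#  (N, 3, input_dim, input_dim) batch yields an (N, 1) tensor of view scores,
#  where a higher score means a better view.)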
26 | class AlexNet(Backbone): 27 | 28 | def __init__(self, pretrained=False): 29 | super().__init__() 30 | model = models.alexnet(pretrained) 31 | self.features = model.features 32 | self.avgpool = model.avgpool 33 | 34 | def forward(self, x: torch.Tensor): 35 | x = self.features(x) 36 | x = self.avgpool(x) 37 | return x 38 | 39 | def input_dim(self): 40 | return 227 41 | 42 | def output_dim(self): 43 | # x = torch.randn((1, 3, self.input_dim(), self.input_dim())) 44 | # y = self.forward(x) 45 | # y = y.view((-1)) 46 | # Output: 256 * 6 * 6 = 9216 47 | return 9216 48 | 49 | 50 | class VGG(Backbone): 51 | def __init__(self, pretrained=False): 52 | super().__init__() 53 | model = models.vgg16(pretrained) 54 | self.features = model.features 55 | self.avgpool = model.avgpool 56 | 57 | def forward(self, x: torch.Tensor): 58 | x = self.features(x) 59 | x = self.avgpool(x) 60 | return x 61 | 62 | def input_dim(self): 63 | return 224 64 | 65 | def output_dim(self): 66 | # x = torch.randn((1, 3, self.input_dim(), self.input_dim())) 67 | # y = self.forward(x) 68 | # y = y.view((-1)) 69 | # print(y.size()) 70 | # Output: 512 * 7 * 7 = 25088 71 | return 25088 72 | -------------------------------------------------------------------------------- /vfn/network/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def hinge_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 6 | # Ref: https://pytorch.org/docs/stable/nn.html#marginrankingloss 7 | criterion = nn.MarginRankingLoss(margin=1) 8 | return criterion(x, y, torch.ones_like(x)) 9 | 10 | 11 | def hinge_loss_v2(score_full: torch.Tensor, score_crop: torch.Tensor, g=1.) -> torch.Tensor: 12 | zeros = torch.zeros_like(score_crop) 13 | g = torch.tensor(g, device=zeros.device) 14 | return torch.mean(torch.max(zeros, g + score_crop - score_full)) 15 | 16 | 17 | def ranknet_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 18 | return nn.BCEWithLogitsLoss()(x, y) 19 | -------------------------------------------------------------------------------- /vfn/network/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from vfn.network.backbones import Backbone 4 | 5 | 6 | class ViewFindingNet(nn.Module): 7 | 8 | def __init__(self, backbone: Backbone): 9 | super().__init__() 10 | self.backbone = backbone 11 | self.fc1 = nn.Sequential( 12 | nn.Linear(self.backbone.output_dim(), 1000), 13 | nn.ReLU(True), 14 | ) 15 | self.fc2 = nn.Sequential( 16 | nn.Linear(1000, 1), 17 | ) 18 | 19 | for m in self.modules(): 20 | if isinstance(m, nn.Linear): 21 | nn.init.kaiming_normal_(m.weight) 22 | m.bias.data.zero_() 23 | 24 | def forward(self, image): 25 | x = self.backbone(image) # type: torch.Tensor 26 | x = x.view((x.size(0), -1)) 27 | x = self.fc1(x) 28 | score = self.fc2(x) 29 | return score 30 | 31 | -------------------------------------------------------------------------------- /vfn/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aquastripe/pytorch-view-finding-network/5518f3ec80d9c9169bc1ac54223cb4d2ea372cae/vfn/utils/__init__.py -------------------------------------------------------------------------------- /vfn/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from enum import Enum 5 | from typing 
import List, Union, Tuple 6 | from PIL import Image, ImageColor, ImageDraw 7 | from torchvision.transforms import ToPILImage 8 | 9 | 10 | class ColorType(Enum): 11 | SLIDING_WINDOWS = ImageColor.getrgb('lightgrey') 12 | GROUNDTRUTH = ImageColor.getrgb('lightpink') 13 | PREDICT = ImageColor.getrgb('greenyellow') 14 | 15 | 16 | def plot_bbox( 17 | image: Union[torch.Tensor, np.ndarray, Image.Image], 18 | bboxes: List[List[int]], 19 | bbox_type: ColorType, 20 | ) -> Image.Image: 21 | if isinstance(image, torch.Tensor): 22 | image = ToPILImage()(image) 23 | elif isinstance(image, np.ndarray): 24 | image = Image.fromarray(image) 25 | 26 | # here `image` is instance of `PIL.Image` 27 | 28 | draw = ImageDraw.Draw(image) 29 | for bbox in bboxes: 30 | x, y, w, h = bbox 31 | x1, y1 = x+w, y+h 32 | draw.rectangle([x, y, x1, y1], outline=bbox_type.value, width=2) 33 | 34 | return image 35 | --------------------------------------------------------------------------------
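For reference, a minimal sketch of how `plot_bbox` can be combined with the FCDB loader to visualise a ground-truth crop — the dataset path and output filename below are placeholders, not part of the repository:

```python
import cv2

from vfn.data.FCDB import FCDB
from vfn.utils.visualization import ColorType, plot_bbox

dataset = FCDB('../datasets/FCDB', download=False)   # assumes the images are already on disk
filename, image_size, ground_truth = dataset[0]      # (path, (width, height), [x, y, w, h])

image = cv2.imread(filename)[..., [2, 1, 0]]         # BGR -> RGB, as done elsewhere in the code base
canvas = plot_bbox(image, [ground_truth], ColorType.GROUNDTRUTH)
canvas.save('fcdb_sample.png')                       # plot_bbox returns a PIL.Image
```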