├── .flake8 ├── .github └── workflows │ └── pythonpackage.yml ├── .gitignore ├── LICENSE ├── README.md ├── bcnn ├── __init__.py ├── data.py ├── model.py └── trainer.py ├── config.py ├── main.py ├── requirements.txt └── scripts └── prepareData.sh /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = 3 | .git, 4 | data, 5 | ckpt, 6 | __pycache__, 7 | .mypy_cache, 8 | max-line-length = 120 9 | 10 | [mypy] 11 | ignore_missing_imports = True 12 | -------------------------------------------------------------------------------- /.github/workflows/pythonpackage.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | max-parallel: 4 11 | matrix: 12 | python-version: [3.6] 13 | 14 | steps: 15 | - uses: actions/checkout@v1 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v1 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install -r requirements.txt 24 | - name: Lint with flake8 25 | run: | 26 | pip install flake8 27 | # stop the build if there are Python syntax errors or undefined names 28 | flake8 --config=.flake8 --select=E901,E999,F821,F822,F823 --count --show-source --statistics bcnn/*.py config.py main.py 29 | # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide 30 | flake8 --config=.flake8 --count --exit-zero --max-complexity=10 --statistics bcnn/*.py config.py main.py 31 | - name: Type checking with mypy 32 | run: | 33 | pip install mypy 34 | mypy --ignore-missing-imports bcnn/*.py config.py main.py 35 | - name: Doc style via pydocstyle 36 | run: | 37 | pip install pydocstyle 38 | pydocstyle bcnn/*.py config.py main.py 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *ipynb* 2 | # .github 3 | .mypy_cache 4 | .vscode 5 | __pycache__ 6 | ckpt/* 7 | data/* 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Riddhiman Dasgupta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Bilinear ConvNets for Fine-Grained Recognition 3 | 4 | This is a [PyTorch](http://pytorch.org/) implementation of Bilinear CNNs as described in the paper [Bilinear CNN Models For Fine-Grained Visual Recognition](http://vis-www.cs.umass.edu/bcnn/) by Tsung-Yu Lin, Aruni Roy Chowdhury, and Subhransu Maji. On the [Caltech-UCSD Birds-200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) or CUB-200-2011 dataset, for the task of 200 class fine-grained bird species classification, this implementation reaches: 5 | 6 | - Accuracy of `84.29%` using the following training regime 7 | - Train only new bilinear classifier, keeping pre-trained layers frozen 8 | - Learning rate: 1e0, Weight Decay: 1e-8, Epochs: 55 9 | - Finetune all pretrained layers as well as bilinear layer jointly 10 | - Learning rate: 1e-2, Weight Decay: 1e-5, Epochs: 25 11 | - Common settings for both training runs 12 | - Optimizer: SGD, Momentum: 0.9, Batch Size: 64, GPUs: 4 13 | - These values are plugged into the config file as defaults 14 | - The original paper reports `84.00%` accuracy on CUB-200-2011 dataset using `VGG-D` pretrained model, which is similar to the `VGG-16` model that this implementation uses. 15 | - Minor differences exist, e.g. no SVM being used, and the L2 normalization is done differently. 16 | 17 | ## Requirements 18 | 19 | - Python (tested on **3.6.9**, should work on **3.5.0** onwards due to typing). 
20 | - Other dependencies are in `requirements.txt` 21 | - Currently works with PyTorch 1.1.0, but should work fine with newer versions. 22 | 23 | ## Usage 24 | 25 | The model class, the dataset loader, and a utility trainer class are packaged into the `bcnn` subfolder, from which the relevant modules can be imported. Dataset downloading and preprocessing are done via a shell script, and a Python driver script is provided to run the training/testing loop. 26 | 27 | - Use the script `scripts/prepareData.sh` which does the following: 28 | - **WARNING:** Some of these steps require [GNU Parallel](https://www.gnu.org/software/parallel/), which can be installed [via these methods](https://stackoverflow.com/questions/32093425/installing-gnu-parallel-without-root-permission) 29 | - Download the [CUB-200-2011 dataset](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) and extract it. 30 | - Preprocess the dataset, i.e. resize the smaller edge to 512 pixels while maintaining the aspect ratio. 31 | - A copy of the dataset is also created where images are cropped to their bounding boxes. 32 | - `main.py` is the driver script. It imports the relevant modules from the `bcnn` package, pre-trains and fine-tunes the model, and evaluates it on the test split. For a list of all command-line arguments, have a look at `config.py`. 33 | - Model checkpoints are saved as `checkpoint.pt` inside the directory given by the command-line argument `--savedir` (default: `./ckpt/test`). 34 | 35 | If you have a working Python3 environment, simply run the following sequence of steps: 36 | 37 | ```bash 38 | bash scripts/prepareData.sh 39 | pip install -r requirements.txt 40 | export CUDA_VISIBLE_DEVICES=0,1,2,3 41 | python main.py --gpus 1 2 3 4 --savedir ./ckpt/exp_test 42 | ``` 43 | 44 | ## Notes 45 | 46 | - (**Oct 12, 2019**) GPU memory consumption is not very high, which means batch size can be increased. 
However, that requires changing other hyperparameters such as learning rate. 47 | 48 | ## Acknowledgements 49 | 50 | [Tsung-Yu Lin](https://people.cs.umass.edu/~tsungyulin/) and [Aruni Roy Chowdhury](https://arunirc.github.io/) released the [original implementation](https://bitbucket.org/tsungyu/bcnn/src/master/) which was invaluable in understanding the model architecture. 51 | [Hao Mood](https://haomood.github.io/homepage/) also released a [PyTorch implementation](https://github.com/HaoMood/bilinear-cnn/) which was critical for finding the right hyperparameters to reach the accuracy reported in the paper. 52 | As usual, shout-out to the [Pytorch team](https://github.com/pytorch/pytorch#the-team) for the incredible library. 53 | 54 | ## Contact 55 | 56 | [Riddhiman Dasgupta](https://dasguptar.github.io/) 57 | *Please create an issue or submit a PR if you find any bugs!* 58 | 59 | ## License 60 | 61 | **MIT** 62 | -------------------------------------------------------------------------------- /bcnn/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize package by importing modules/classes.""" 2 | 3 | from .data import get_data_loader 4 | from .model import BilinearModel 5 | from .trainer import Trainer 6 | 7 | __all__ = [ 8 | 'get_data_loader', 9 | 'BilinearModel', 10 | 'Trainer', 11 | ] 12 | -------------------------------------------------------------------------------- /bcnn/data.py: -------------------------------------------------------------------------------- 1 | """Handle datasets, transforms, and data loaders.""" 2 | 3 | import argparse 4 | from pathlib import Path 5 | from typing import List 6 | 7 | from torch.utils.data import DataLoader 8 | 9 | import torchvision.transforms as transforms 10 | from torchvision.datasets import ImageFolder 11 | 12 | Transform = object 13 | 14 | common_transforms: List[Transform] = [ 15 | transforms.ToTensor(), 16 | transforms.Normalize( 17 | mean=[0.485, 0.456, 
0.406], 18 | std=[0.229, 0.224, 0.225] 19 | ), 20 | ] 21 | 22 | 23 | def get_data_loader(split: str, args: argparse.Namespace) -> DataLoader: 24 | """Return dataloader for specified split.""" 25 | split_transforms: List[Transform] 26 | if split == 'train': 27 | split_transforms = [ 28 | transforms.Resize(448), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.RandomRotation(5), 31 | transforms.RandomCrop(448), 32 | ] 33 | elif split == 'test': 34 | split_transforms = [ 35 | transforms.Resize(448), 36 | transforms.CenterCrop(448), 37 | ] 38 | 39 | dataset: ImageFolder = ImageFolder( 40 | root=Path(args.datadir) / split, 41 | transform=transforms.Compose(split_transforms + common_transforms) 42 | ) 43 | dataloader: DataLoader = DataLoader( 44 | dataset=dataset, 45 | batch_size=args.batchsize, 46 | shuffle=True if split == 'train' else False, 47 | num_workers=args.workers, 48 | pin_memory=True, 49 | ) 50 | return dataloader 51 | -------------------------------------------------------------------------------- /bcnn/model.py: -------------------------------------------------------------------------------- 1 | """Load model with pretrained weights and initialise new layers.""" 2 | 3 | from overrides import overrides 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | import torchvision.models as models 9 | 10 | 11 | class BilinearModel(nn.Module): 12 | """Load model with pretrained weights and initialise new layers.""" 13 | 14 | def __init__(self, num_classes: int = 200) -> None: 15 | """Load pretrained model, set new layers with specified number of layers.""" 16 | super(BilinearModel, self).__init__() 17 | model: nn.Module = models.vgg16(pretrained=True) 18 | self.features: nn.Module = nn.Sequential(*list(model.features)[:-1]) 19 | self.classifier: nn.Module = nn.Linear(512 ** 2, num_classes) 20 | nn.init.kaiming_normal_(self.classifier.weight.data) 21 | if self.classifier.bias is not None: 22 | nn.init.constant_(self.classifier.bias.data, val=0) 23 | 24 | 
@overrides 25 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 26 | """Extract input features, perform bilinear transform, project to # of classes and return.""" 27 | outputs: torch.Tensor = self.features(inputs) # extract features from pretrained base 28 | outputs = outputs.view(-1, 512, 28 ** 2) # reshape to batchsize * 512 * 28 ** 2 29 | outputs = torch.bmm(outputs, outputs.permute(0, 2, 1)) # bilinear product 30 | outputs = torch.div(outputs, 28 ** 2) # divide by 784 (= 28 ** 2 spatial locations) to normalize 31 | outputs = outputs.view(-1, 512 ** 2) # reshape to batchsize * 512 * 512 32 | outputs = torch.sign(outputs) * torch.sqrt(outputs + 1e-5) # signed square root normalization 33 | outputs = nn.functional.normalize(outputs, p=2, dim=1) # l2 normalization 34 | outputs = self.classifier(outputs) # linear projection 35 | return outputs 36 | -------------------------------------------------------------------------------- /bcnn/trainer.py: -------------------------------------------------------------------------------- 1 | """Trainer class to abstract rudimentary training loop.""" 2 | 3 | from typing import Tuple 4 | 5 | import torch 6 | from torch.nn import Module 7 | from torch.optim.optimizer import Optimizer 8 | from torch.utils.data import DataLoader 9 | 10 | from tqdm import tqdm 11 | 12 | 13 | class Trainer(object): 14 | """Trainer class to abstract rudimentary training loop.""" 15 | 16 | def __init__( 17 | self, 18 | model: Module, 19 | criterion: Module, 20 | optimizer: Optimizer, 21 | device: torch.device) -> None: 22 | """Set trainer class with model, criterion, optimizer. 
(Data is passed to train/eval).""" 23 | super(Trainer, self).__init__() 24 | self.model: Module = model 25 | self.criterion: Module = criterion 26 | self.optimizer: Optimizer = optimizer 27 | self.device: torch.device = device 28 | 29 | def train(self, loader: DataLoader) -> Tuple[float, float]: 30 | """Train model using batches from loader and return mean loss and accuracy (in percent).""" 31 | total_loss, total_acc = 0.0, 0.0 32 | self.model.train() 33 | for _, (inputs, targets) in tqdm(enumerate(loader), total=len(loader), desc='Training'): 34 | inputs = inputs.to(self.device) 35 | targets = targets.to(self.device) 36 | outputs = self.model(inputs) 37 | loss = self.criterion(outputs, targets) 38 | self.optimizer.zero_grad() 39 | loss.backward() 40 | self.optimizer.step() 41 | _, predicted = torch.max(outputs, 1) 42 | total_loss += loss.item() 43 | total_acc += (predicted == targets).float().sum().item() / targets.numel() 44 | return total_loss / len(loader), 100.0 * total_acc / len(loader) 45 | 46 | def test(self, loader: DataLoader) -> Tuple[float, float]: 47 | """Evaluate model using batches from loader and return mean loss and accuracy (in percent).""" 48 | with torch.no_grad(): 49 | total_loss, total_acc = 0.0, 0.0 50 | self.model.eval() 51 | for _, (inputs, targets) in tqdm(enumerate(loader), total=len(loader), desc='Testing '): 52 | inputs = inputs.to(self.device) 53 | targets = targets.to(self.device) 54 | outputs = self.model(inputs) 55 | loss = self.criterion(outputs, targets) 56 | _, predicted = torch.max(outputs, 1) 57 | total_loss += loss.item() 58 | total_acc += (predicted == targets).float().sum().item() / targets.numel() 59 | return total_loss / len(loader), 100.0 * total_acc / len(loader) 60 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """Argument parser for all programs.""" 2 | 3 | import argparse 4 | import pathlib 5 | 6 | 7 | def 
parse_args_main() -> argparse.Namespace: 8 | """Parse arguments for main training script.""" 9 | parser: argparse.ArgumentParser = argparse.ArgumentParser(description='Parse arguments for main.py') 10 | parser.add_argument('--datadir', type=pathlib.Path, default='./data/CUB_200_2011/original/', 11 | help='Path to folder containing data') 12 | parser.add_argument('--savedir', type=pathlib.Path, default='./ckpt/test', 13 | help='Directory for checkpointing to disk') 14 | parser.add_argument('--load', type=str, default='', 15 | help='Checkpoint filename to load state dicts') 16 | parser.add_argument('--batchsize', type=int, default=64, 17 | help='Batchsize for each GPU') 18 | parser.add_argument('--epochs', type=int, nargs='+', default=[55, 25], 19 | help='Number of epochs for partial and full finetuning') 20 | parser.add_argument('--lr', type=float, nargs='+', default=[1.0, 1e-2], 21 | help='Learning rate (multiplied by factor for new layers)') 22 | parser.add_argument('--wd', type=float, nargs='+', default=[1e-8, 1e-5], 23 | help='Weight decay (multiplied by factor for new layers)') 24 | parser.add_argument('--momentum', type=float, default=0.9, 25 | help='Momentum for SGD optimizer') 26 | parser.add_argument('--stepfactor', type=float, default=0.1, 27 | help='Step size for reducing learning rate') 28 | parser.add_argument('--patience', type=int, default=3, 29 | help='How long to wait before dropping LR') 30 | parser.add_argument('--gpus', type=int, default=[], nargs='+', 31 | help='Space separated list of gpus to use') 32 | parser.add_argument('--seed', type=int, default=12345, 33 | help='Random seed for reproducibility') 34 | parser.add_argument('--workers', type=int, default=8, 35 | help='Number of parallel data loader threads') 36 | args: argparse.Namespace = parser.parse_args() 37 | return args 38 | -------------------------------------------------------------------------------- /main.py: 
-------------------------------------------------------------------------------- 1 | """Bilinear CNN training script.""" 2 | 3 | import argparse 4 | import logging 5 | import pathlib 6 | import random 7 | from typing import Any, Dict, Tuple 8 | 9 | from bcnn import BilinearModel, Trainer, get_data_loader 10 | 11 | from config import parse_args_main 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | from torch.optim.lr_scheduler import ReduceLROnPlateau 17 | from torch.utils.data import DataLoader 18 | 19 | logging.basicConfig( 20 | level=logging.DEBUG, 21 | format="[%(asctime)s] %(levelname)s:%(name)s:%(message)s", 22 | handlers=[ 23 | logging.StreamHandler(), 24 | # logging.FileHandler('expt.log', mode='w') 25 | ]) 26 | logger = logging.getLogger() 27 | 28 | 29 | def checkpoint( 30 | trainer: Trainer, 31 | epoch: int, 32 | accuracy: float, 33 | savedir: pathlib.Path, 34 | config: argparse.Namespace) -> None: 35 | """Save a model checkpoint at specified location.""" 36 | checkpoint: Dict[str, Any] = { 37 | "model": trainer.model.state_dict(), 38 | "optim": trainer.optimizer.state_dict(), 39 | "epoch": epoch, 40 | "accuracy": accuracy, 41 | "config": config, 42 | } 43 | logger.debug("==> Checkpointing Model") 44 | torch.save(checkpoint, savedir / 'checkpoint.pt') 45 | 46 | 47 | def run_epochs_for_loop( 48 | trainer: Trainer, 49 | epochs: int, 50 | train_loader: DataLoader, 51 | test_loader: DataLoader, 52 | savedir: pathlib.Path, 53 | config: argparse.Namespace, 54 | scheduler: ReduceLROnPlateau = None): 55 | """Run train + evaluation loop for specified epochs. 56 | 57 | Save checkpoint to specified save folder when better optimum is found. 58 | If LR scheduler is specified, change LR accordingly. 
59 | """ 60 | best_acc: float = 0.0 61 | for epoch in range(epochs): 62 | (train_loss, train_acc) = trainer.train(train_loader) # type: Tuple[float, float] 63 | (test_loss, test_acc) = trainer.test(test_loader) # type: Tuple[float, float] 64 | logger.info("Epoch %d: TrainLoss %f \t TrainAcc %f" % (epoch, train_loss, train_acc)) 65 | logger.info("Epoch %d: TestLoss %f \t TestAcc %f" % (epoch, test_loss, test_acc)) 66 | if scheduler is not None: 67 | scheduler.step(test_acc) 68 | if test_acc > best_acc: 69 | best_acc = test_acc 70 | checkpoint(trainer, epoch, test_acc, savedir, config) 71 | 72 | 73 | def main(): 74 | """Train bilinear CNN.""" 75 | args: argparse.Namespace = parse_args_main() 76 | logger.debug(args) 77 | 78 | # random seeding 79 | torch.manual_seed(args.seed) 80 | random.seed(args.seed) 81 | if len(args.gpus) > 0: 82 | torch.cuda.manual_seed(args.seed) 83 | torch.backends.cudnn.benchmark = True 84 | device: torch.device = torch.device('cuda:0') 85 | 86 | args.savedir.mkdir(parents=True, exist_ok=True) 87 | 88 | train_loader: DataLoader = get_data_loader('train', args) 89 | test_loader: DataLoader = get_data_loader('test', args) 90 | 91 | model: nn.Module = BilinearModel(num_classes=200) 92 | model = torch.nn.DataParallel(model) 93 | criterion: nn.Module = nn.CrossEntropyLoss() 94 | model.to(device) 95 | criterion.to(device) 96 | 97 | logger.debug("==> PRETRAINING NEW BILINEAR LAYER ONLY") 98 | for param in model.module.features.parameters(): 99 | param.requires_grad = False 100 | optimizer: optim.optimizer.Optimizer = optim.SGD( 101 | model.module.classifier.parameters(), 102 | lr=args.lr[0], 103 | weight_decay=args.wd[0], 104 | momentum=args.momentum, 105 | nesterov=True, 106 | ) 107 | pretrainer: Trainer = Trainer( 108 | model, 109 | criterion, 110 | optimizer, 111 | device, 112 | ) 113 | scheduler: ReduceLROnPlateau = ReduceLROnPlateau( 114 | optimizer, 115 | mode='max', 116 | factor=args.stepfactor, 117 | patience=args.patience, 118 | 
verbose=True, 119 | threshold=1e-4, 120 | ) 121 | run_epochs_for_loop( 122 | trainer=pretrainer, 123 | epochs=args.epochs[0], 124 | train_loader=train_loader, 125 | test_loader=test_loader, 126 | savedir=args.savedir, 127 | config=args, 128 | scheduler=scheduler, 129 | ) 130 | 131 | logger.debug("==> FINE-TUNING OLDER LAYERS AS WELL") 132 | for param in model.module.features.parameters(): 133 | param.requires_grad = True 134 | optimizer: optim.optimizer.Optimizer = optim.SGD( 135 | model.parameters(), 136 | lr=args.lr[1], 137 | weight_decay=args.wd[1], 138 | momentum=args.momentum, 139 | nesterov=True, 140 | ) 141 | finetuner: Trainer = Trainer( 142 | model, 143 | criterion, 144 | optimizer, 145 | device, 146 | ) 147 | scheduler: ReduceLROnPlateau = ReduceLROnPlateau( 148 | optimizer, 149 | mode='max', 150 | factor=args.stepfactor, 151 | patience=args.patience, 152 | verbose=True, 153 | threshold=1e-4, 154 | ) 155 | run_epochs_for_loop( 156 | trainer=finetuner, 157 | epochs=args.epochs[1], 158 | train_loader=train_loader, 159 | test_loader=test_loader, 160 | savedir=args.savedir, 161 | config=args, 162 | scheduler=scheduler, 163 | ) 164 | 165 | 166 | if __name__ == "__main__": 167 | main() 168 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | overrides==1.9 2 | torch==1.1.0 3 | torchvision==0.3.0 4 | tqdm==4.35.0 5 | -------------------------------------------------------------------------------- /scripts/prepareData.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # move to data directory 4 | ORIGIN=$(pwd) 5 | mkdir -p data/; cd data/ 6 | 7 | if [ ! -d CUB_200_2011 ]; then 8 | if [ ! -f CUB_200_2011.tgz ]; then 9 | echo "==> DOWNLOADING THE DATASET ..." 
10 | wget -c http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz 11 | fi 12 | echo "==> EXTRACTING THE DATASET ..." 13 | tar -xzf CUB_200_2011.tgz 14 | if [ -f attributes.txt ]; then 15 | mv attributes.txt CUB_200_2011/ 16 | fi 17 | # rm CUB_200_2011.tgz 18 | fi 19 | 20 | cd CUB_200_2011/ 21 | ROOT=$(pwd) 22 | 23 | echo "==> GETTING DIRECTORY STRUCTURE ..." 24 | cd images/; find . -type d > $ROOT/dirs.txt; cd $ROOT; 25 | 26 | # join to get index, filename, class, split, bounding boxes 27 | join images.txt image_class_labels.txt | join - train_test_split.txt | join - bounding_boxes.txt > $ROOT/combined_details.txt 28 | 29 | function initDirectoryStructure(){ 30 | mkdir -p $1; cd $1; echo $(pwd) 31 | xargs mkdir -p < $ROOT/dirs.txt; 32 | cd ../; 33 | } 34 | 35 | echo "==> COPYING ORIGINAL FILES OVER ..." 36 | initDirectoryStructure original 37 | cd images/; 38 | find . -name "*jpg" | parallel -j0 --eta --bar cp {} ../original/{} 39 | cd ../ 40 | 41 | # print out filename with bbox dimensions with convert commands 42 | echo "==> CROPPING IMAGES TO BOUNDING BOXES ..." 43 | initDirectoryStructure cropped 44 | awk -F' ' '{print "convert images/"$2" -crop "$5+$7"x"$6+$8"+"$5"+"$6" +repage cropped/"$2}' $ROOT/combined_details.txt | parallel -j0 --eta --bar 45 | 46 | # resize images so that smaller dimension becomes 512 px, maintaining aspect ratio 47 | echo "==> RESIZING CROPPED IMAGES ..." 48 | initDirectoryStructure resized_cropped 49 | find cropped/ -name "*jpg" | parallel -j0 --eta --bar convert -resize "512^" {} resized_{} 50 | 51 | # resize images so that smaller dimension becomes 512 px, maintaining aspect ratio 52 | echo "==> RESIZING UNCROPPED IMAGES ..." 
53 | initDirectoryStructure resized_original 54 | find original/ -name "*jpg" | parallel -j0 --eta --bar convert -resize "512^" {} resized_{} 55 | 56 | echo "==> REMOVING UNRESIZED IMAGES" 57 | rm -rf $ROOT/original/* 58 | rm -rf $ROOT/cropped/* 59 | 60 | function splitDirectory(){ 61 | src=$1 62 | tgt=$2 63 | mkdir -p $tgt; cd $tgt 64 | initDirectoryStructure train 65 | initDirectoryStructure test 66 | cd ../ 67 | awk -v s="$src" -v t="$tgt" -F' ' '$4 == 1 {print "cp "s"/"$2" "t"/train/"$2}' $ROOT/combined_details.txt | parallel -j0 --eta --bar #--dryrun 68 | awk -v s="$src" -v t="$tgt" -F' ' '$4 == 0 {print "cp "s"/"$2" "t"/test/"$2}' $ROOT/combined_details.txt | parallel -j0 --eta --bar #--dryrun 69 | } 70 | 71 | echo "==> SPLITTING RESIZED AND UNCROPPED IMAGES INTO TRAIN/TEST FOLDERS" 72 | splitDirectory resized_original original 73 | echo "==> SPLITTING RESIZED AND CROPPED IMAGES INTO TRAIN/TEST FOLDERS" 74 | splitDirectory resized_cropped cropped 75 | 76 | echo "==> Cleaning up intermediate stuff" 77 | cd $ROOT 78 | rm -rf dirs.txt combined_details.txt resized_cropped/ resized_original/ 79 | 80 | # move back to original directory 81 | cd $ORIGIN 82 | --------------------------------------------------------------------------------
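The pooling steps in `BilinearModel.forward` (`bcnn/model.py`) can be hard to follow at 512 channels and 28 × 28 locations, so the following is a minimal pure-Python sketch of the same sequence on a tiny feature map: averaged outer product, signed square root, then L2 normalization. It is not part of the repository. The names `bilinear_pool` and `signed_sqrt` are hypothetical, and the use of `abs()` inside the square root is an assumption added here so that negative entries stay real-valued; the repository code applies `torch.sqrt(outputs + 1e-5)` directly, which is equivalent for the non-negative post-ReLU features it receives.

```python
import math
from typing import List


def signed_sqrt(value: float, eps: float) -> float:
    """Return sign(value) * sqrt(|value| + eps), with 0 mapped to 0."""
    if value == 0:
        return 0.0  # mirrors torch.sign(0) == 0
    return math.copysign(math.sqrt(abs(value) + eps), value)


def bilinear_pool(features: List[List[float]], eps: float = 1e-5) -> List[float]:
    """Pool a C x N feature map into an L2-normalized vector of length C * C.

    `features[c]` holds the N spatial activations of channel c
    (hypothetical layout, chosen to match the C x N view used in model.py).
    """
    channels = len(features)
    locations = len(features[0])
    pooled: List[float] = []
    for i in range(channels):
        for j in range(channels):
            # outer product between channels i and j, averaged over locations
            dot = sum(a * b for a, b in zip(features[i], features[j]))
            pooled.append(dot / locations)
    # signed square root normalization
    rooted = [signed_sqrt(v, eps) for v in pooled]
    # L2 normalization; fall back to 1.0 to avoid dividing by zero
    norm = math.sqrt(sum(v * v for v in rooted)) or 1.0
    return [v / norm for v in rooted]


if __name__ == "__main__":
    # two channels, two spatial locations
    print(bilinear_pool([[1.0, 0.0], [0.0, 1.0]]))
```

With `C = 512` and `N = 28 ** 2 = 784`, the pooled vector has `512 ** 2` entries, which matches the input size of the `nn.Linear(512 ** 2, num_classes)` classifier in the model.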