├── .gitignore ├── README.md ├── dataset.py ├── demo.ipynb ├── net.py ├── requirements.txt ├── train.py ├── trained_model └── buster_epoch_13.pth ├── utils.py └── visualizes ├── visualize1.png ├── visualize2.png └── visualize3.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is 
recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | .vscode/ 141 | .ipynb_checkpoints/ 142 | datasets/ 143 | logs/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # BusterNet_pytorch: Detecting Copy-Move Image Forgery with Source/Target Localization 3 | 4 | ### Introduction 5 | This repository is a PyTorch reimplementation of *BusterNet*, a deep neural architecture for image copy-move forgery detection (CMFD).
6 | 7 | In this repository, we release several resources related to the paper, including 8 | 9 | - a pretrained BusterNet model (trained model at epoch 13) 10 | - custom layers implemented in PyTorch 11 | - a Python demo notebook 12 | 13 | 14 | ### Example 15 | 16 | ![alt text](visualizes/visualize1.png "Visualize 1") 17 | ![alt text](visualizes/visualize2.png "Visualize 2") 18 | ![alt text](visualizes/visualize3.png "Visualize 3") 19 | 20 | ### Dataset 21 | USCISI-CMFD Dataset 22 | 23 | #### Introduction 24 | This copy-move forgery detection (CMFD) dataset relies on 25 | - [MIT SUN2012 Database](https://groups.csail.mit.edu/vision/SUN/) 26 | - [MS COCO Dataset](http://cocodataset.org/#home) 27 | 28 | More precisely, we synthesize a copy-move forgery sample using the following steps 29 | 30 | 1. select a sample from one of the two datasets above 31 | 2. select one of its object polygons 32 | 3. use both the sample image and the polygon mask to synthesize a sample 33 | 34 | A more detailed description can be found in the paper. 35 | 36 | #### Folder Content 37 | This USCISI-CMFD dataset folder contains the following things: 38 | 39 | * **api.py** - USCISI-CMFD dataset API 40 | * **USCISI-CMFD Dataset** - USCISI-CMFD LMDB dataset 41 | * Two versions are NOT included due to the repo size limit. Please right-click to download them from Google Drive.
42 | * [**USCISI-CMFD-Small**](https://drive.google.com/file/d/14WrmeVRTf9T0umSW6I267zBrsmCjCEIQ/view?usp=sharing) - 100 samples, ~40MB 43 | * [**USCISI-CMFD-Full**](https://drive.google.com/file/d/1gsx5c-oilsFEzX_j1zKTPP4yWEs6T385/view?usp=sharing) - 100K samples, ~100GB 44 | * After uncompressing the downloaded dataset, you should see the following files 45 | * **data.mdb** - sample LMDB data file 46 | * **samples.keys** - a file listing sample keys (each line is a key) 47 | * **lock.mdb** - sample LMDB lock file 48 | * **Demo.ipynb** - a Python notebook showing the usage of the API 49 | * **ReadMe.md** - this file 50 | 51 | **NOTE** due to the repository size limit, the full USCISI-CMFD dataset will be provided upon request. 52 | 53 | ### Training 54 | 1. Download the dataset to the folder 'datasets' using the links above. The dataset is owned by yue_wu[at]isi.edu; if you do not have access permission, please contact him. 55 | (Optional) Download pretrained VGG16 weights at [VGG16](https://download.pytorch.org/models/vgg16-397923af.pth) 56 | 2. Install the dependencies. 57 | ```pip install -r requirements.txt``` 58 | 3. Training: 59 | ```python train.py``` 60 | with custom arguments: 61 | ``` 62 | usage: Buster Net [-h] [-n NUM_WORKERS] [-b BATCH_SIZE] [--num_gpus NUM_GPUS] 63 | [--freeze_layers [FREEZE_LAYERS [FREEZE_LAYERS ...]]] 64 | [--lr LR] [--optim OPTIM] [--num_epochs NUM_EPOCHS] 65 | [--val_interval VAL_INTERVAL] 66 | [--save_interval SAVE_INTERVAL] 67 | [--es_min_delta ES_MIN_DELTA] [--es_patience ES_PATIENCE] 68 | [--lmdb_dir LMDB_DIR] [--log_path LOG_PATH] 69 | [-w LOAD_WEIGHTS] [--saved_path SAVED_PATH] 70 | ``` 71 | 4. Try running predictions in demo.ipynb 72 | 73 | ### Citation 74 | If you use the provided code or data in any publication, please kindly cite the following paper.
75 | 76 | @inproceedings{wu2018eccv, 77 | title={BusterNet: Detecting Image Copy-Move Forgery With Source/Target Localization}, 78 | author={Wu, Yue and AbdAlmageed, Wael and Natarajan, Prem}, 79 | booktitle={European Conference on Computer Vision (ECCV)}, 80 | year={2018}, 81 | organization={Springer}, 82 | } 83 | 84 | ### Contact 85 | - Name: Nguyen Thanh Dat 86 | - Email: ntdat017\[at\]gmail.com 87 | 88 | 89 | ### License 90 | The Software is made available for academic or non-commercial purposes only. The license is for a copy of the program for an unlimited term. Individuals requesting a license for commercial use must pay for a commercial license. 91 | 92 | USC Stevens Institute for Innovation 93 | University of Southern California 94 | 1150 S. Olive Street, Suite 2300 95 | Los Angeles, CA 90115, USA 96 | ATTN: Accounting 97 | 98 | DISCLAIMER. USC MAKES NO EXPRESS OR IMPLIED WARRANTIES, EITHER IN FACT OR BY OPERATION OF LAW, BY STATUTE OR OTHERWISE, AND USC SPECIFICALLY AND EXPRESSLY DISCLAIMS ANY EXPRESS OR IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE, VALIDITY OF THE SOFTWARE OR ANY OTHER INTELLECTUAL PROPERTY RIGHTS OR NON-INFRINGEMENT OF THE INTELLECTUAL PROPERTY OR OTHER RIGHTS OF ANY THIRD PARTY. SOFTWARE IS MADE AVAILABLE AS-IS. LIMITATION OF LIABILITY. TO THE MAXIMUM EXTENT PERMITTED BY LAW, IN NO EVENT WILL USC BE LIABLE TO ANY USER OF THIS CODE FOR ANY INCIDENTAL, CONSEQUENTIAL, EXEMPLARY OR PUNITIVE DAMAGES OF ANY KIND, LOST GOODWILL, LOST PROFITS, LOST BUSINESS AND/OR ANY INDIRECT ECONOMIC DAMAGES WHATSOEVER, REGARDLESS OF WHETHER SUCH DAMAGES ARISE FROM CLAIMS BASED UPON CONTRACT, NEGLIGENCE, TORT (INCLUDING STRICT LIABILITY OR OTHER LEGAL THEORY), A BREACH OF ANY WARRANTY OR TERM OF THIS AGREEMENT, AND REGARDLESS OF WHETHER USC WAS ADVISED OR HAD REASON TO KNOW OF THE POSSIBILITY OF INCURRING SUCH DAMAGES IN ADVANCE.
99 | 100 | For commercial license pricing and annual commercial update and support pricing, please contact: 101 | 102 | Rakesh Pandit USC Stevens Institute for Innovation 103 | University of Southern California 104 | 1150 S. Olive Street, Suite 2300 105 | Los Angeles, CA 90115, USA 106 | 107 | Tel: +1 213-821-3552 108 | Fax: +1 213-821-5001 109 | Email: rakeshvp@usc.edu and ccto: accounting@stevens.usc.edu 110 | 111 | 112 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import cv2 5 | import json 6 | import lmdb 7 | import pyarrow as pa 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | from torchvision import transforms 12 | 13 | class USCISIDataset(Dataset): 14 | def __init__(self, lmdb_dir, sample_file, transform=None, target_transform=None, differentiate_target=True): 15 | super(USCISIDataset, self).__init__() 16 | 17 | self.lmdb_dir = lmdb_dir 18 | self.sample_file = os.path.join(lmdb_dir, sample_file) 19 | self.transform = transform 20 | self.target_transform = target_transform 21 | self.differentiate_target = differentiate_target 22 | 23 | assert os.path.isdir(lmdb_dir) 24 | self.lmdb_dir = lmdb_dir 25 | assert os.path.isfile(self.sample_file) 26 | 27 | self._init_db() 28 | 29 | 30 | def _init_db(self): 31 | self.env = lmdb.open(self.lmdb_dir, subdir=os.path.isdir(self.lmdb_dir), 32 | readonly=True, lock=False, 33 | readahead=False, meminit=False) 34 | # with self.env.begin(write=False) as txn: 35 | # self.length = txn.stat()['entries'] - 1 36 | self.keys = self._load_sample_keys() 37 | 38 | def _load_sample_keys(self) : 39 | with open(self.sample_file, 'r') as f: 40 | keys = [line.strip() for line in f.readlines()] 41 | return keys 42 | 43 | def __getitem__(self, index: int): 44 | img, target = None, None 45 | env = self.env 46 | with env.begin(write=False) as txn: 47 | 
lut_str = txn.get(self.keys[index].encode()) 48 | img, cmd_mask, trans_mat = self._decode_lut_str(lut_str) 49 | 50 | if self.transform is not None: 51 | img = self.transform(img) 52 | if self.target_transform is not None: 53 | cmd_mask = cmd_mask.astype(np.uint8) 54 | cmd_mask = self.target_transform(cmd_mask) 55 | cmd_mask = (cmd_mask * 256).type(torch.int32) 56 | 57 | return img, cmd_mask, trans_mat 58 | 59 | def __len__(self) -> int: 60 | return len(self.keys) 61 | 62 | 63 | def _decode_lut_str(self, lut_str) : 64 | '''Decode a raw LMDB lut 65 | INPUT: 66 | lut_str = str, raw string retrieved from LMDB 67 | OUTPUT: 68 | image = np.ndarray, dtype='uint8', cmd image 69 | cmd_mask = np.ndarray, dtype='float32', cmd mask 70 | trans_mat = np.ndarray, dtype='float32', cmd transform matrix 71 | ''' 72 | # 1. get raw lut 73 | lut = json.loads(lut_str) 74 | # 2. reconstruct image 75 | image = self._get_image_from_lut(lut) 76 | # 3. reconstruct copy-move masks 77 | cmd_mask = self._get_mask_from_lut(lut) 78 | # 4. get transform matrix if necessary 79 | trans_mat = self._get_transmat_from_lut(lut) 80 | return ( image, cmd_mask, trans_mat ) 81 | 82 | 83 | def _get_image_from_lut( self, lut ) : 84 | '''Decode image array from LMDB lut 85 | INPUT: 86 | lut = dict, raw decoded lut retrieved from LMDB 87 | OUTPUT: 88 | image = np.ndarray, dtype='uint8' 89 | ''' 90 | image_jpeg_buffer = lut['image_jpeg_buffer'] 91 | image = cv2.imdecode( np.array(image_jpeg_buffer).astype('uint8').reshape([-1,1]), 1 ) 92 | return image 93 | 94 | def _get_mask_from_lut( self, lut ) : 95 | '''Decode copy-move mask from LMDB lut 96 | INPUT: 97 | lut = dict, raw decoded lut retrieved from LMDB 98 | OUTPUT: 99 | cmd_mask = np.ndarray, dtype='float32' 100 | shape of HxW, if differentiate_target=False 101 | shape of HxWx3, if differentiate_target=True 102 | NOTE: 103 | cmd_mask is encoded in the one-hot style, if differentiate_target=True.
104 | color channel, R, G, and B stand for TARGET, SOURCE, and BACKGROUND classes 105 | ''' 106 | def reconstruct( cnts, h, w, val=1 ) : 107 | rst = np.zeros([h,w], dtype='uint8') 108 | cv2.fillPoly( rst, cnts, val ) 109 | return rst 110 | h, w = lut['image_height'], lut['image_width'] 111 | src_cnts = [ np.array(cnts).reshape([-1,1,2]) for cnts in lut['source_contour'] ] 112 | src_mask = reconstruct( src_cnts, h, w, val = 1 ) 113 | tgt_cnts = [ np.array(cnts).reshape([-1,1,2]) for cnts in lut['target_contour'] ] 114 | tgt_mask = reconstruct( tgt_cnts, h, w, val = 1 ) 115 | if ( self.differentiate_target ) : 116 | # 3-class target 117 | background = np.ones([h,w]).astype('uint8') - np.maximum(src_mask, tgt_mask) 118 | cmd_mask = np.dstack( [tgt_mask, src_mask, background ] ).astype(np.float32) 119 | else : 120 | # 2-class target 121 | cmd_mask = np.maximum(src_mask, tgt_mask).astype(np.float32) 122 | return cmd_mask 123 | def _get_transmat_from_lut( self, lut ) : 124 | '''Decode transform matrix between SOURCE and TARGET 125 | INPUT: 126 | lut = dict, raw decoded lut retrieved from LMDB 127 | OUTPUT: 128 | trans_mat = np.ndarray, dtype='float32', size of 3x3 129 | ''' 130 | trans_mat = lut['transform_matrix'] 131 | return np.array(trans_mat).reshape([3,3]) 132 | 133 | 134 | 135 | # def collater(data): 136 | # imgs = [s['img'] for s in data] 137 | # annots = [s['annot'] for s in data] 138 | # scales = [s['scale'] for s in data] 139 | 140 | # imgs = torch.from_numpy(np.stack(imgs, axis=0)) 141 | 142 | # max_num_annots = max(annot.shape[0] for annot in annots) 143 | 144 | # if max_num_annots > 0: 145 | 146 | # annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1 147 | 148 | # for idx, annot in enumerate(annots): 149 | # if annot.shape[0] > 0: 150 | # annot_padded[idx, :annot.shape[0], :] = annot 151 | # else: 152 | # annot_padded = torch.ones((len(annots), 1, 5)) * -1 153 | 154 | # imgs = imgs.permute(0, 3, 1, 2) 155 | 156 | # return {'img': imgs, 'annot': 
annot_padded, 'scale': scales} -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import os\n", 21 | "import numpy as np\n", 22 | "from torchvision import transforms\n", 23 | "\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "import torch \n", 26 | "import torch.nn as nn\n", 27 | "from torch.utils.data import DataLoader\n", 28 | "\n", 29 | "from net import BusterNet\n", 30 | "from dataset import USCISIDataset" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "\n", 40 | "def viz(imgs_np, simi_out_np, mani_out_np, mask_out_np, index):\n", 41 | " fig = plt.figure(figsize=(20, 80))\n", 42 | "\n", 43 | " img = imgs_np[index]\n", 44 | " img = imgs_np[index] * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]\n", 45 | "\n", 46 | " ax = fig.add_subplot(1, 4, 1)\n", 47 | " plt.imshow(img)\n", 48 | " plt.title('Original Image')\n", 49 | "\n", 50 | " simi_pred = (simi_out_np[index] * 2).astype(np.uint8)\n", 51 | " ax = fig.add_subplot(1, 4, 2)\n", 52 | " plt.imshow(simi_pred[:,:, 0])\n", 53 | " plt.title('Similarity mask')\n", 54 | "\n", 55 | " mani_pred = (mani_out_np[index] * 2).astype(np.uint8)\n", 56 | " ax = fig.add_subplot(1, 4, 3)\n", 57 | " plt.imshow(mani_pred[:,:, 0])\n", 58 | " plt.title('Manipulation mask')\n", 59 | "\n", 60 | " mask_pred = (mask_out_np[index] * 2).astype(np.uint8) * 255\n", 61 | " ax = fig.add_subplot(1, 4, 4)\n", 62 | " plt.imshow(mask_pred)\n", 63 | " plt.title('Output mask')\n", 64 | "\n", 65 | " 
plt.show()" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 6, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "# Download dataset to ./datasets/\n", 75 | "lmdb_dir = './datasets/USCISI-CMFD'\n", 76 | "test_file = 'test.keys'\n", 77 | "input_size = 256\n", 78 | "\n", 79 | "transform = transforms.Compose([\n", 80 | " transforms.ToPILImage(),\n", 81 | " transforms.Resize((input_size, input_size)),\n", 82 | " transforms.ToTensor(),\n", 83 | " transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 84 | " std=[0.229, 0.224, 0.225]),\n", 85 | "])\n", 86 | "target_transform = transforms.Compose([\n", 87 | " transforms.ToPILImage(),\n", 88 | " transforms.Resize((input_size, input_size)),\n", 89 | " transforms.ToTensor(),\n", 90 | "])\n", 91 | "test_set = USCISIDataset(lmdb_dir, test_file, transform, target_transform)\n", 92 | "\n", 93 | "test_params = {'batch_size': 16,\n", 94 | " 'shuffle': False,\n", 95 | " 'drop_last': True,\n", 96 | " # 'collate_fn': collater,\n", 97 | " 'num_workers': 1}\n", 98 | "\n", 99 | "test_generator = DataLoader(test_set, **test_params)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 7, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "data = next(iter(test_generator))\n", 109 | "\n", 110 | "imgs, gts, _= data" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 8, 116 | "metadata": {}, 117 | "outputs": [ 118 | { 119 | "name": "stderr", 120 | "output_type": "stream", 121 | "text": [ 122 | "/home/klein/miniconda3/envs/torch/lib/python3.6/site-packages/torch/nn/functional.py:2479: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. 
See the documentation of nn.Upsample for details.\n", 123 | " \"See the documentation of nn.Upsample for details.\".format(mode))\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "model = BusterNet(256)\n", 129 | "\n", 130 | "model.load_state_dict(torch.load('trained_model/buster_epoch_13.pth'))\n", 131 | "\n", 132 | "model.eval()\n", 133 | "with torch.no_grad():\n", 134 | " preds = model(imgs)\n", 135 | " \n", 136 | "mask_out, mani_output, simi_output = preds\n", 137 | "\n", 138 | "mask_out_np = mask_out.permute(0, 2, 3, 1).numpy()\n", 139 | "mani_out_np = mani_output.permute(0, 2, 3, 1).numpy()\n", 140 | "simi_out_np = simi_output.permute(0, 2, 3, 1).numpy()\n" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 21, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "name": "stderr", 150 | "output_type": "stream", 151 | "text": [ 152 | "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" 153 | ] 154 | }, 155 | { 156 | "data": { 157 | "image/png": "\n", 158 | "text/plain": [ 159 | "
" 160 | ] 161 | }, 162 | "metadata": { 163 | "needs_background": "light" 164 | }, 165 | "output_type": "display_data" 166 | } 167 | ], 168 | "source": [ 169 | "index = 11\n", 170 | "imgs_np = imgs.permute(0, 2, 3, 1).numpy()\n", 171 | "\n", 172 | "viz(imgs_np, simi_out_np, mani_out_np, mask_out_np, index)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | } 182 | ], 183 | "metadata": { 184 | "kernelspec": { 185 | "display_name": "torch", 186 | "language": "python", 187 | "name": "torch" 188 | }, 189 | "language_info": { 190 | "codemirror_mode": { 191 | "name": "ipython", 192 | "version": 3 193 | }, 194 | "file_extension": ".py", 195 | "mimetype": "text/x-python", 196 | "name": "python", 197 | "nbconvert_exporter": "python", 198 | "pygments_lexer": "ipython3", 199 | "version": "3.6.10" 200 | } 201 | }, 202 | "nbformat": 4, 203 | "nbformat_minor": 4 204 | } 205 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | 6 | from utils import Conv2dStaticSamePadding as Conv2d 7 | 8 | class BasicConv2D(nn.Module): 9 | def __init__(self, in_channels, out_channels, **kwargs): 10 | super(BasicConv2D, self).__init__() 11 | self.conv = Conv2d(in_channels, out_channels, bias=True, **kwargs) 12 | self.bn = nn.BatchNorm2d(out_channels, eps=1e-3) 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | x = self.bn(x) 17 | return F.relu(x, inplace=True) 18 | 19 | class Inception(nn.Module): 20 | '''BatchNorm Inception module with batch normalization 21 | Input: 22 | x = tensor4D, (n_samples, n_rows, n_cols, n_feats) 23 | ''' 24 | def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, conv_block=None, is_last=False): 25 | 
super(Inception, self).__init__() 26 | if conv_block is None: 27 | conv_block = BasicConv2D 28 | if is_last: 29 | k_size1, k_size2, k_size3 = 5, 7, 11 30 | else: 31 | k_size1, k_size2, k_size3 = 1, 3, 5 32 | 33 | self.branch1 = conv_block(in_channels, ch1x1, kernel_size=k_size1) 34 | 35 | self.branch2 = nn.Sequential( 36 | conv_block(in_channels, ch3x3red, kernel_size=1), 37 | conv_block(ch3x3red, ch3x3, kernel_size=k_size2, padding=1) 38 | ) 39 | 40 | self.branch3 = nn.Sequential( 41 | conv_block(in_channels, ch5x5red, kernel_size=1), 42 | conv_block(ch5x5red, ch5x5, kernel_size=k_size3, padding=1) 43 | ) 44 | 45 | def forward(self, x): 46 | x1 = self.branch1(x) 47 | x2 = self.branch2(x) 48 | x3 = self.branch3(x) 49 | 50 | outputs = torch.cat((x1, x2, x3), dim=1) 51 | return outputs 52 | 53 | class CorrelationPercPooling(nn.Module): 54 | '''Custom Self-Correlation Percentile Pooling Layer 55 | ''' 56 | def __init__(self, nb_pools=256, **kwargs): 57 | super(CorrelationPercPooling, self).__init__() 58 | self.nb_pools = nb_pools 59 | 60 | n_maps = 16*16 61 | 62 | if self.nb_pools is not None: 63 | self.ranks = torch.floor(torch.linspace(0, n_maps -1, self.nb_pools)).type(torch.long) 64 | else: 65 | self.ranks = torch.arange(n_maps, dtype=torch.long) 66 | 67 | def forward(self, x): 68 | ''' 69 | x_shape: (n, c, h, w) 70 | ''' 71 | n_bsize, n_feats, n_rows, n_cols = x.shape 72 | n_maps = n_rows * n_cols 73 | x_3d = x.reshape(n_bsize, n_feats, n_maps) 74 | 75 | x_corr_3d = torch.matmul(x_3d.transpose(1, 2), x_3d) / n_feats 76 | x_corr = x_corr_3d.reshape(n_bsize, n_maps, n_rows, n_cols) 77 | 78 | # ranks = ranks.to(devices) 79 | x_sort, _ = torch.topk(x_corr, k=n_maps, dim=1, sorted=True) 80 | 81 | x_f1st_sort = x_sort.permute(1, 2, 3, 0) 82 | x_f1st_pool = x_f1st_sort[self.ranks] 83 | x_pool = x_f1st_pool.permute(3, 0, 1, 2) 84 | 85 | return x_pool 86 | 87 | def make_layers(cfg, batch_norm=False): 88 | layers = [] 89 | in_channels = 3 90 | for v in cfg: 91 | if v ==
'M': 92 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 93 | else: 94 | conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1) 95 | if batch_norm: 96 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 97 | else: 98 | layers += [conv2d, nn.ReLU(inplace=True)] 99 | in_channels = v 100 | return layers 101 | 102 | cfgs = { 103 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 104 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 105 | 'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M'], 106 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 107 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 108 | } 109 | 110 | class DeconvBlock(nn.Module): 111 | def __init__(self, in_channels, out_channels): 112 | super(DeconvBlock, self).__init__() 113 | self.upsample = nn.UpsamplingBilinear2d(scale_factor=2) 114 | 115 | h_channels = out_channels // 2 116 | self.inception = Inception(in_channels, out_channels, h_channels, out_channels, h_channels, out_channels) 117 | 118 | def forward(self, x): 119 | x = self.upsample(x) 120 | x = self.inception(x) 121 | return x 122 | 123 | class MaskDecoder(nn.Module): 124 | def __init__(self, in_channels=512): 125 | super(MaskDecoder, self).__init__() 126 | self.f16 = Inception(in_channels, 8, 4, 8, 4, 8) 127 | 128 | self.deconv_0 = DeconvBlock(24, 6) 129 | self.deconv_1 = DeconvBlock(18, 4) 130 | self.deconv_2 = DeconvBlock(12, 2) 131 | self.deconv_3 = DeconvBlock(6, 2) 132 | 133 | self.pred_mask = Inception(6, 2, 1, 2, 1, 2, is_last=True) 134 | 135 | def forward(self, x): 136 | f16 = self.f16(x) 137 | f32 = self.deconv_0(f16) 138 | f64 = self.deconv_1(f32) 139 | f128 = self.deconv_2(f64) 140 | f256 = self.deconv_3(f128) 141 | pred_mask = self.pred_mask(f256) 142 | 143 | return pred_mask 144 | 145 | class ManipulationNet(nn.Module): 146 | def 
__init__(self): 147 | super(ManipulationNet, self).__init__() 148 | # self.features = nn.Sequential(*make_layers(cfgs['C'])) 149 | self.features = models.vgg16_bn().features[:-10] 150 | self.mask_decoder = MaskDecoder(512) 151 | self.classifier = nn.Sequential( 152 | Conv2d(6, 1, kernel_size=3), 153 | nn.Sigmoid() 154 | ) 155 | 156 | def forward(self, x): 157 | x = self.features(x) 158 | x = self.mask_decoder(x) 159 | mask = self.classifier(x) 160 | return x, mask 161 | 162 | class SimilarityNet(nn.Module): 163 | def __init__(self): 164 | super(SimilarityNet, self).__init__() 165 | # self.features = nn.Sequential(*make_layers(cfgs['C'])) 166 | self.features = models.vgg16_bn().features[:-10] 167 | self.correlation_per_pooling = CorrelationPercPooling(nb_pools=256) 168 | self.mask_decoder = MaskDecoder(256) 169 | self.classifier = nn.Sequential( 170 | Conv2d(6, 1, kernel_size=3), 171 | nn.Sigmoid() 172 | ) 173 | 174 | def forward(self, x): 175 | x = self.features(x) 176 | x = self.correlation_per_pooling(x) 177 | x = self.mask_decoder(x) 178 | mask = self.classifier(x) 179 | return x, mask 180 | 181 | class BusterNet(nn.Module): 182 | def __init__(self, image_size): 183 | super(BusterNet, self).__init__() 184 | 185 | self.image_size = image_size 186 | 187 | self.manipulation_net = ManipulationNet() 188 | self.similarity_net = SimilarityNet() 189 | 190 | self.inception = nn.Sequential( 191 | Inception(12, 3, 3, 3, 3, 3), 192 | Conv2d(9, 3, kernel_size=3), 193 | nn.Softmax2d() 194 | ) 195 | 196 | def forward(self, x): 197 | mani_feat, mani_output = self.manipulation_net(x) 198 | simi_feat, simi_output = self.similarity_net(x) 199 | 200 | merged_feat = torch.cat([simi_feat, mani_feat], dim=1) 201 | 202 | x = self.inception(merged_feat) 203 | 204 | mask_out = F.interpolate(x, size=(self.image_size, self.image_size), mode='bilinear') 205 | return mask_out, mani_output, simi_output 206 | 207 | if __name__ == "__main__": 208 | model = BusterNet(256) 209 | print(model) 210 
| num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 211 | print(num_params) 212 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imgaug==0.2.6 2 | lmdb==1.0.0 3 | numpy==1.17.5 4 | opencv-python==4.4.0.46 5 | opencv-python-headless==4.2.0.34 6 | Pillow==7.1.2 7 | tensorboard==2.2.2 8 | tensorboard-plugin-wit==1.7.0 9 | tensorboardX==2.1 10 | torch==1.2.0 11 | torchvision==0.4.0 12 | tqdm==4.46.0 13 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import traceback 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | from tensorboardX import SummaryWriter 12 | from tqdm import tqdm 13 | 14 | from dataset import USCISIDataset 15 | from net import BusterNet 16 | from utils import CustomDataParallel 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser('Buster Net') 20 | parser.add_argument('-n', '--num_workers', type=int, default=16, help='num_workers of dataloader') 21 | parser.add_argument('-b', '--batch_size', type=int, default=4, help='The number of images per batch among all devices') 22 | parser.add_argument('--num_gpus', type=int, default=1, help='The number of gpus') # Multi-GPU training is not supported yet.
23 | parser.add_argument('--freeze_layers', nargs='*', default=None, 24 | help='freeze layers with strategy') 25 | parser.add_argument('--lr', type=float, default=1e-2) 26 | parser.add_argument('--optim', type=str, default='adamw', help='select optimizer for training, ' 27 | 'suggest using \'adamw\' or \'adam\' until the' 28 | ' very final stage then switch to \'sgd\'') 29 | parser.add_argument('--num_epochs', type=int, default=500) 30 | parser.add_argument('--val_interval', type=int, default=1, help='Number of epochs between validation phases') 31 | parser.add_argument('--save_interval', type=int, default=500, help='Number of steps between saving') 32 | parser.add_argument('--es_min_delta', type=float, default=0.0, 33 | help='Early stopping\'s parameter: minimum change in loss to qualify as an improvement') 34 | parser.add_argument('--es_patience', type=int, default=0, 35 | help='Early stopping\'s parameter: number of epochs with no improvement after which training will be stopped. Set to 0 to disable this technique.') 36 | parser.add_argument('--lmdb_dir', type=str, default='./datasets/USCISI-CMFD', help='the root folder of dataset') 37 | parser.add_argument('--log_path', type=str, default='./logs/') 38 | parser.add_argument('-w', '--load_weights', type=str, default=None, 39 | help='whether to load weights from a checkpoint, set None to initialize, set \'last\' to load last checkpoint') 40 | parser.add_argument('--saved_path', type=str, default='logs/') 41 | 42 | args = parser.parse_args() 43 | return args 44 | 45 | 46 | class ModelWithLoss(nn.Module): 47 | def __init__(self, model, train_simi=True, train_mani=True, train_fusion=True, debug=False): 48 | super().__init__() 49 | self.ce_criterion = nn.CrossEntropyLoss() 50 | self.bce_criterion = nn.BCELoss() 51 | self.model = model 52 | self.train_simi = train_simi 53 | self.train_mani = train_mani 54 | self.train_fusion = train_fusion 55 | self.debug = debug 56 | 57 | def forward(self, imgs, gts): 58 | fusion_preds,
mani_preds, simi_preds = self.model(imgs) 59 | simi_gts = (1 - gts[:, 2, :, :]).type(torch.float) 60 | mani_gts = gts[:, 0, :, :].type(torch.float) 61 | _, fusion_gts = gts.max(dim=1) 62 | 63 | loss = torch.zeros(3) 64 | if self.train_fusion: 65 | fusion_loss = self.ce_criterion(fusion_preds, fusion_gts) 66 | loss[0] = fusion_loss 67 | if self.train_mani: 68 | mani_preds = mani_preds.squeeze(1) 69 | mani_loss = self.bce_criterion(mani_preds, mani_gts) 70 | loss[1] = mani_loss 71 | if self.train_simi: 72 | simi_preds = simi_preds.squeeze(1) 73 | simi_loss = self.bce_criterion(simi_preds, simi_gts) 74 | loss[2] = simi_loss 75 | 76 | return loss 77 | 78 | 79 | def train(opt): 80 | train_file = 'train.keys' 81 | val_file = 'valid.keys' 82 | # Train the similarity network, the manipulation network, and the fusion branch independently or all together. 83 | train_simi=True 84 | train_mani=True 85 | train_fusion=True 86 | 87 | # Following the paper, the default input_size is 256. 88 | input_size = 256 89 | 90 | train_transform = transforms.Compose([ 91 | transforms.ToPILImage(), 92 | transforms.Resize((input_size, input_size)), 93 | # transforms.RandomHorizontalFlip(), 94 | transforms.ToTensor(), 95 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 96 | std=[0.229, 0.224, 0.225]), 97 | ]) 98 | val_transform = transforms.Compose([ 99 | transforms.ToPILImage(), 100 | transforms.Resize((input_size, input_size)), 101 | transforms.ToTensor(), 102 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 103 | std=[0.229, 0.224, 0.225]), 104 | ]) 105 | target_transform = transforms.Compose([ 106 | transforms.ToPILImage(), 107 | transforms.Resize((input_size, input_size)), 108 | transforms.ToTensor(), 109 | ]) 110 | train_set = USCISIDataset(opt.lmdb_dir, train_file, train_transform, target_transform) 111 | val_set = USCISIDataset(opt.lmdb_dir, val_file, val_transform, target_transform) 112 | 113 | training_params = {'batch_size': opt.batch_size, 114 | 'shuffle': True, 115 | 'drop_last': True, 116 | #
'collate_fn': collater, 117 | 'num_workers': opt.num_workers} 118 | 119 | val_params = {'batch_size': opt.batch_size, 120 | 'shuffle': False, 121 | 'drop_last': True, 122 | # 'collate_fn': collater, 123 | 'num_workers': opt.num_workers} 124 | 125 | training_generator = DataLoader(train_set, **training_params) 126 | val_generator = DataLoader(val_set, **val_params) 127 | 128 | model = BusterNet(image_size=input_size) 129 | 130 | if opt.load_weights is not None: 131 | try: 132 | # Load pretrain VGG16 in https://download.pytorch.org/models/vgg16-397923af.pth or continuing training 133 | if 'vgg16_bn' in opt.load_weights: 134 | vgg_backbone = torch.load(opt.load_weights) 135 | model.manipulation_net.load_state_dict(vgg_backbone, strict=False) 136 | model.similarity_net.load_state_dict(vgg_backbone, strict=False) 137 | else: 138 | model.load_state_dict(torch.load(opt.load_weights), strict=False) 139 | except RuntimeError as e: 140 | print(f'[Warning] Ignoring {e}') 141 | print( 142 | f'[Info] loaded weights: {os.path.basename(opt.load_weights)}') 143 | else: 144 | print('[Info] initializing weights...') 145 | # init_weights(model) 146 | 147 | if opt.freeze_layers is not None: 148 | assert isinstance(opt.freeze_layers, list), "Required List string" 149 | def freeze_layers(m): 150 | classname = m.__class__.__name__ 151 | for ntl in opt.freeze_layers: 152 | if ntl in classname: 153 | for param in m.parameters(): 154 | param.require_grad = False 155 | 156 | model.apply(freeze_layers) 157 | print('[Info] freeze layers in ', opt.freeze_layers) 158 | 159 | # warp the model with loss function, to reduce the memory usage on gpu0 and speedup 160 | model = ModelWithLoss(model, train_simi=train_simi, train_mani=train_mani, train_fusion=train_fusion) 161 | 162 | if opt.num_gpus > 1 and opt.batch_size // opt.num_gpus < 4: 163 | model.apply(replace_w_sync_bn) 164 | use_sync_bn = True 165 | else: 166 | use_sync_bn = False 167 | 168 | os.makedirs(opt.saved_path, exist_ok=True) 169 | 
    writer = SummaryWriter(opt.log_path + f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    if opt.num_gpus > 0:
        model = model.cuda()
        if opt.num_gpus > 1:
            model = CustomDataParallel(model, opt.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    if opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    elif opt.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), opt.lr, momentum=0.9, nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

    last_step = 0
    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    num_iter_per_epoch = len(training_generator)

    try:
        for epoch in range(opt.num_epochs):

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iter, data in enumerate(progress_bar):
                last_epoch = step // num_iter_per_epoch
                if iter < step - last_epoch * num_iter_per_epoch:
                    progress_bar.update()
                    continue
                try:
                    imgs, gts, _ = data

                    if opt.num_gpus == 1:
                        # if only one gpu, just send it to cuda:0
                        # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                        imgs = imgs.cuda()
                        gts = gts.cuda()

                    optimizer.zero_grad()

                    fusion_loss, mani_loss, simi_loss = model(imgs, gts)
                    fusion_loss = fusion_loss.mean()
                    simi_loss = simi_loss.mean()
                    mani_loss = mani_loss.mean()

                    loss = fusion_loss + mani_loss + simi_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Fusion loss: {:.5f}. Mani loss: {:.5f}. Simi loss: {:.5f}. Total loss: {:.5f}'.format(
                            step, epoch, opt.num_epochs, iter + 1, num_iter_per_epoch, fusion_loss.item(),
                            mani_loss.item(), simi_loss.item(), loss.item()))
                    writer.add_scalar('Loss', loss, step)
                    writer.add_scalar('fusion_loss', fusion_loss, step)
                    writer.add_scalar('simi_loss', simi_loss, step)
                    writer.add_scalar('mani_loss', mani_loss, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(model, f'model_{epoch}_{step}.pth')
                        print('checkpoint...')

                except Exception as e:
                    print('[Error]', traceback.format_exc())
                    print(e)
                    continue
            scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_fusion_ls = []
                loss_simi_ls = []
                loss_mani_ls = []
                for iter, data in enumerate(val_generator):
                    with torch.no_grad():
                        imgs, gts, _ = data

                        if opt.num_gpus == 1:
                            imgs = imgs.cuda()
                            gts = gts.cuda()

                        fusion_loss, mani_loss, simi_loss = model(imgs, gts)
                        fusion_loss = fusion_loss.mean()
                        simi_loss = simi_loss.mean()
                        mani_loss = mani_loss.mean()

                        loss = fusion_loss + mani_loss + simi_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_fusion_ls.append(fusion_loss.item())
                        loss_simi_ls.append(simi_loss.item())
                        loss_mani_ls.append(mani_loss.item())

                fusion_loss = np.mean(loss_fusion_ls)
                simi_loss = np.mean(loss_simi_ls)
                mani_loss = np.mean(loss_mani_ls)
                loss = fusion_loss + simi_loss + mani_loss

                print(
                    'Val. Epoch: {}/{}. Fusion loss: {:1.5f}. Simi loss: {:1.5f}. Mani loss: {:1.5f}. Total loss: {:1.5f}'.format(
                        epoch, opt.num_epochs, fusion_loss, simi_loss, mani_loss, loss))
                writer.add_scalar('Val_Loss', loss, step)
                writer.add_scalar('Val_Fusion_loss', fusion_loss, step)
                writer.add_scalar('Val_Simi_loss', simi_loss, step)
                writer.add_scalar('Val_Mani_loss', mani_loss, step)

                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(model, f'model_{epoch}_{step}.pth')

                model.train()

                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print('[Info] Stop training at epoch {}. The lowest loss achieved is {}'.format(epoch, best_loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(model, f'model_{epoch}_{step}.pth')
        writer.close()
    writer.close()

def save_checkpoint(model, name):
    # Relies on the module-level `opt` set in the __main__ block below.
    if isinstance(model, CustomDataParallel):
        torch.save(model.module.model.state_dict(), os.path.join(opt.saved_path, name))
    else:
        torch.save(model.model.state_dict(), os.path.join(opt.saved_path, name))

if __name__ == '__main__':
    opt = get_args()
    train(opt)

--------------------------------------------------------------------------------
/trained_model/buster_epoch_13.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ntdat017/BusterNet_pytorch/6a03f73eeb648e6effe729958dac765c9e601ee3/trained_model/buster_epoch_13.pth

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# Author: Zylo117

import math

from torch import nn
import torch.nn.functional as F


class Conv2dStaticSamePadding(nn.Module):
    """
    created by Zylo117
    The real keras/tensorflow conv2d with same padding
    """

    def __init__(self,
                 in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs):
        super().__init__()
        # Note: `dilation` is accepted but not forwarded to nn.Conv2d, so the default dilation is used.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                              bias=bias, groups=groups)
        self.stride = self.conv.stride
        self.kernel_size = self.conv.kernel_size
        self.dilation = self.conv.dilation

        if isinstance(self.stride, int):
            self.stride = [self.stride] * 2
        elif len(self.stride) == 1:
            self.stride = [self.stride[0]] * 2

        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size] * 2
        elif len(self.kernel_size) == 1:
            self.kernel_size = [self.kernel_size[0]] * 2

    def forward(self, x):
        h, w = x.shape[-2:]

        # Total padding needed so that the output size equals ceil(input / stride).
        extra_h = (math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1]
        extra_v = (math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0]

        left = extra_h // 2
        right = extra_h - left
        top = extra_v // 2
        bottom = extra_v - top

        x = F.pad(x, [left, right, top, bottom])

        x = self.conv(x)
        return x


class CustomDataParallel(nn.DataParallel):
    """
    force splitting data to all gpus instead of sending all data to cuda:0 and then moving around.
    """

    def __init__(self, module, num_gpus):
        super().__init__(module)
        self.num_gpus = num_gpus

    def scatter(self, inputs, kwargs, device_ids):
        # More like scatter and data prep at the same time. The point is we prep the data in such a way
        # that no scatter is necessary, and there's no need to shuffle stuff around different GPUs.
        devices = ['cuda:' + str(x) for x in range(self.num_gpus)]
        splits = inputs[0].shape[0] // self.num_gpus

        if splits == 0:
            raise Exception('Batchsize must be greater than num_gpus.')

        return [(inputs[0][splits * device_idx: splits * (device_idx + 1)].to(f'cuda:{device_idx}', non_blocking=True),
                 inputs[1][splits * device_idx: splits * (device_idx + 1)].to(f'cuda:{device_idx}', non_blocking=True))
                for device_idx in range(len(devices))], \
               [kwargs] * len(devices)

--------------------------------------------------------------------------------
/visualizes/visualize1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ntdat017/BusterNet_pytorch/6a03f73eeb648e6effe729958dac765c9e601ee3/visualizes/visualize1.png

--------------------------------------------------------------------------------
/visualizes/visualize2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ntdat017/BusterNet_pytorch/6a03f73eeb648e6effe729958dac765c9e601ee3/visualizes/visualize2.png

--------------------------------------------------------------------------------
/visualizes/visualize3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ntdat017/BusterNet_pytorch/6a03f73eeb648e6effe729958dac765c9e601ee3/visualizes/visualize3.png

--------------------------------------------------------------------------------