├── .gitignore ├── API ├── 1.jpg ├── Segmentation_API │ ├── app.py │ ├── detectron_seg.py │ ├── detectron_seg.yml │ ├── requirements.yml │ └── segmentation.py ├── app.py ├── matting.py └── requirements.yml ├── LICENSE ├── README.md ├── Videos ├── frienvengers.avi └── frienvengers_using_detectron.avi ├── dataloader.py ├── demo.py ├── examples ├── example_results.png ├── images │ ├── elephant.png │ └── troll.png ├── predictions │ ├── elephant_alpha.png │ ├── elephant_bg.png │ ├── elephant_fg.png │ ├── troll_alpha.png │ ├── troll_bg.png │ └── troll_fg.png └── trimaps │ ├── elephant.png │ └── troll.png ├── frienvengers.ipynb ├── inference.py ├── model.py └── networks ├── __init__.py ├── layers_WS.py ├── models.py ├── resnet_GN_WS.py ├── resnet_bn.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | ./vscode/* 2 | *.pyc 3 | ./vscode 4 | .vscode/* 5 | 6 | *.pth 7 | .ipynb_checkpoints 8 | API/Segmentation_API/segmentation_orig.py 9 | Experiment.ipynb 10 | FBA_Matting.ipynb 11 | Videos/avengers.avi 12 | Videos/friends.avi 13 | API/app2.py 14 | -------------------------------------------------------------------------------- /API/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/API/1.jpg -------------------------------------------------------------------------------- /API/Segmentation_API/app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, jsonify, request, Response 2 | from segmentation import Model, Preprocessing 3 | # from detectron_seg import Model, Preprocessing 4 | import cv2 5 | import numpy as np 6 | import io 7 | from PIL import Image 8 | 9 | app = Flask(__name__) 10 | 11 | def get_arr(arg): 12 | return np.array(arg).astype('uint8') 13 | 14 | seg_model = Model() 15 | 16 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5)) 17 | 18 | preprocessor = Preprocessing(kernel) 19 | 20 | 21 | @app.route('/',methods=['POST']) 22 | def gen_trimap(): 23 | data = request.get_json() 24 | image = get_arr(data['image']) 25 | output = seg_model.get_seg_output(image) 26 | if len(output) == 0: 27 | return Response(status=406) 28 | masks = np.array([mask.cpu().numpy() for mask, classes in output]) 29 | trimap = preprocessor.get_trimap(masks) 30 | return jsonify({'trimap':trimap.tolist()})#,'masks':masks.tolist()}) 31 | 32 | if __name__=='__main__': 33 | app.run(debug=True,threaded=True) -------------------------------------------------------------------------------- /API/Segmentation_API/detectron_seg.py: -------------------------------------------------------------------------------- 1 | from detectron2 import model_zoo 2 | from detectron2.engine import DefaultPredictor 3 | from detectron2.config import get_cfg 4 | from detectron2.utils.visualizer import Visualizer 5 | from detectron2.data import MetadataCatalog 6 | import torch 7 | import numpy as np 8 | import cv2 9 | 10 | 11 | 12 | 13 | class Model: 14 | def __init__(self,confidence_thresh=0.6): 15 | cfg = get_cfg() 16 | cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")) 17 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_thresh # set threshold for this model 18 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml") 19 |
self.model = DefaultPredictor(cfg) 20 | 21 | 22 | def get_seg_output(self,image:np.array): 23 | out = self.model(image)['instances'] 24 | 25 | outputs = [(out.pred_masks[i],out.pred_classes[i]) for i in range(len(out.pred_classes)) if out.pred_classes[i]==0] 26 | 27 | return outputs 28 | 29 | 30 | 31 | class Preprocessing: 32 | def __init__(self,kernel,dilate_iter=5,erode_iter=1): 33 | self.kernel = kernel 34 | self.dilate_iter = dilate_iter 35 | self.erode_iter = erode_iter 36 | 37 | def get_target_mask(self,masks): 38 | out = np.zeros(masks[0].shape) 39 | for mask in masks: 40 | out += mask 41 | out = np.clip(out,0,1) 42 | return out 43 | 44 | def get_trimap(self,masks): 45 | target_mask = self.get_target_mask(masks) 46 | erode = cv2.erode(target_mask.astype('uint8'),self.kernel,iterations=self.erode_iter) 47 | dilate = cv2.dilate(target_mask.astype('uint8'),self.kernel,iterations=self.dilate_iter) 48 | h, w = target_mask.shape 49 | 50 | trimap = np.zeros((h, w, 2)) 51 | trimap[erode == 1, 1] = 1 52 | trimap[dilate == 0, 0] = 1 53 | 54 | return trimap 55 | 56 | 57 | -------------------------------------------------------------------------------- /API/Segmentation_API/detectron_seg.yml: -------------------------------------------------------------------------------- 1 | name: detectron 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - attrs=19.3.0=py_0 7 | - backcall=0.1.0=py38_0 8 | - bleach=3.1.4=py_0 9 | - ca-certificates=2020.6.24=0 10 | - certifi=2020.6.20=py38_0 11 | - click=7.1.2=py_0 12 | - decorator=4.4.2=py_0 13 | - defusedxml=0.6.0=py_0 14 | - entrypoints=0.3=py38_0 15 | - flask=1.1.2=py_0 16 | - gmp=6.1.2=h6c8ec71_1 17 | - importlib_metadata=1.5.0=py38_0 18 | - ipykernel=5.1.4=py38h39e3cac_0 19 | - ipython=7.13.0=py38h5ca1d4c_0 20 | - ipython_genutils=0.2.0=py38_0 21 | - itsdangerous=1.1.0=py_0 22 | - jedi=0.16.0=py38_1 23 | - jinja2=2.11.1=py_0 24 | - jsonschema=3.2.0=py38_0 25 | - jupyter_client=6.1.2=py_0 26 | - jupyter_core=4.6.3=py38_0 27 | - ld_impl_linux-64=2.33.1=h53a641e_7 28 | - libedit=3.1.20181209=hc058e9b_0 29 | - libffi=3.2.1=hd88cf55_4 30 | - libgcc-ng=9.1.0=hdf63c60_0 31 | - libsodium=1.0.16=h1bed415_0 32 | - libstdcxx-ng=9.1.0=hdf63c60_0 33 | - markupsafe=1.1.1=py38h7b6447c_0 34 | - mistune=0.8.4=py38h7b6447c_1000 35 | - nb_conda_kernels=2.2.3=py38_0 36 | - nbconvert=5.6.1=py38_0 37 | - nbformat=5.0.4=py_0 38 | - ncurses=6.2=he6710b0_0 39 | - notebook=6.0.3=py38_0 40 | - openssl=1.1.1g=h7b6447c_0 41 | - pandoc=2.2.3.2=0 42 | - pandocfilters=1.4.2=py38_1 43 | - parso=0.6.2=py_0 44 | - pexpect=4.8.0=py38_0 45 | - pickleshare=0.7.5=py38_1000 46 | - pip=20.0.2=py38_1 47 | - prometheus_client=0.7.1=py_0 48 | - prompt-toolkit=3.0.4=py_0 49 | - prompt_toolkit=3.0.4=0 50 | - ptyprocess=0.6.0=py38_0 51 | - pygments=2.6.1=py_0 52 | - pyrsistent=0.16.0=py38h7b6447c_0 53 | - python=3.8.2=hcf32534_0 54 | - python-dateutil=2.8.1=py_0 55 | - pyzmq=18.1.1=py38he6710b0_0 56 | - readline=8.0=h7b6447c_0 57 | - send2trash=1.5.0=py38_0 58 | - setuptools=46.1.3=py38_0 59 | - six=1.14.0=py38_0 60 | - sqlite=3.31.1=h7b6447c_0 61 | - terminado=0.8.3=py38_0 62 | - testpath=0.4.4=py_0 63 | - tk=8.6.8=hbc83047_0 64 | - tornado=6.0.4=py38h7b6447c_1 65 | - traitlets=4.3.3=py38_0 66 | - wcwidth=0.1.9=py_0 67 | - webencodings=0.5.1=py38_1 68 | - werkzeug=1.0.1=py_0 69 | - wheel=0.34.2=py38_0 70 | - xz=5.2.5=h7b6447c_0 71 | - zeromq=4.3.1=he6710b0_3 72 | - zipp=2.2.0=py_0 73 | - zlib=1.2.11=h7b6447c_3 74 | - pip: 75 | - absl-py==0.9.0 76 | - cachetools==4.1.0 
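# note (added annotation): detectron2, torch, and torchvision below are pinned to matching +cu92 (CUDA 9.2) builds and must stay in sync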
77 | - chardet==3.0.4 78 | - cloudpickle==1.3.0 79 | - cycler==0.10.0 80 | - cython==0.29.16 81 | - detectron2==0.1.1+cu92 82 | - future==0.18.2 83 | - fvcore==0.1.dev200420 84 | - google-auth==1.14.0 85 | - google-auth-oauthlib==0.4.1 86 | - grpcio==1.28.1 87 | - idna==2.9 88 | - kiwisolver==1.2.0 89 | - markdown==3.2.1 90 | - matplotlib==3.2.1 91 | - numpy==1.18.3 92 | - oauthlib==3.1.0 93 | - opencv-python==4.2.0.34 94 | - pillow==7.1.1 95 | - portalocker==1.7.0 96 | - protobuf==3.11.3 97 | - pyasn1==0.4.8 98 | - pyasn1-modules==0.2.8 99 | - pycocotools==2.0 100 | - pydot==1.4.1 101 | - pyparsing==2.4.7 102 | - pyyaml==5.1 103 | - requests==2.23.0 104 | - requests-oauthlib==1.3.0 105 | - rsa==4.0 106 | - tabulate==0.8.7 107 | - tensorboard==2.2.1 108 | - tensorboard-plugin-wit==1.6.0.post3 109 | - termcolor==1.1.0 110 | - torch==1.4.0+cu92 111 | - torchvision==0.5.0+cu92 112 | - tqdm==4.45.0 113 | - urllib3==1.25.9 114 | - yacs==0.1.6 115 | prefix: /home/kakarot/anaconda3/envs/detectron 116 | 117 | -------------------------------------------------------------------------------- /API/Segmentation_API/requirements.yml: -------------------------------------------------------------------------------- 1 | name: Fba_matting 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - attrs=19.3.0=py_0 7 | - backcall=0.2.0=py_0 8 | - blas=1.0=mkl 9 | - bleach=3.1.5=py_0 10 | - ca-certificates=2020.6.24=0 11 | - certifi=2020.6.20=py37_0 12 | - click=7.1.2=py_0 13 | - dbus=1.13.16=hb2f20db_0 14 | - decorator=4.4.2=py_0 15 | - defusedxml=0.6.0=py_0 16 | - entrypoints=0.3=py37_0 17 | - expat=2.2.9=he6710b0_2 18 | - flask=1.1.2=py_0 19 | - fontconfig=2.13.0=h9420a91_0 20 | - freetype=2.10.2=h5ab3b9f_0 21 | - glib=2.65.0=h3eb4bd4_0 22 | - gst-plugins-base=1.14.0=hbbd80ab_1 23 | - gstreamer=1.14.0=hb31296c_0 24 | - icu=58.2=he6710b0_3 25 | - importlib-metadata=1.7.0=py37_0 26 | - importlib_metadata=1.7.0=0 27 | - intel-openmp=2020.1=217 28 | - ipykernel=5.3.0=py37h5ca1d4c_0 29 | - ipython=7.16.1=py37h5ca1d4c_0 30 | - ipython_genutils=0.2.0=py37_0 31 | - itsdangerous=1.1.0=py37_0 32 | - jedi=0.17.1=py37_0 33 | - jinja2=2.11.2=py_0 34 | - jpeg=9b=h024ee3a_2 35 | - jsonschema=3.2.0=py37_0 36 | - jupyter_client=6.1.3=py_0 37 | - jupyter_core=4.6.3=py37_0 38 | - kiwisolver=1.2.0=py37hfd86e86_0 39 | - lcms2=2.11=h396b838_0 40 | - ld_impl_linux-64=2.33.1=h53a641e_7 41 | - libedit=3.1.20191231=h7b6447c_0 42 | - libffi=3.3=he6710b0_1 43 | - libgcc-ng=9.1.0=hdf63c60_0 44 | - libgfortran-ng=7.3.0=hdf63c60_0 45 | - libpng=1.6.37=hbc83047_0 46 | - libsodium=1.0.18=h7b6447c_0 47 | - libstdcxx-ng=9.1.0=hdf63c60_0 48 | - libtiff=4.1.0=h2733197_1 49 | - libuuid=1.0.3=h1bed415_2 50 | - libxcb=1.14=h7b6447c_0 51 | - libxml2=2.9.10=he19cac6_1 52 | - lz4-c=1.9.2=he6710b0_0 53 | - markupsafe=1.1.1=py37h7b6447c_0 54 | - matplotlib=3.2.2=0 55 | - matplotlib-base=3.2.2=py37hef1b27d_0 56 | - mistune=0.8.4=py37h7b6447c_0 57 | - mkl=2020.1=217 58 | - mkl-service=2.3.0=py37he904b0f_0 59 | - mkl_fft=1.1.0=py37h23d657b_0 60 | - mkl_random=1.1.1=py37h0573a6f_0 61 | - nb_conda=2.2.1=py37_0 62 | - nb_conda_kernels=2.2.3=py37_0 63 | - nbconvert=5.6.1=py37_0 64 | - nbformat=5.0.7=py_0 65 | - ncurses=6.2=he6710b0_1 66 | - notebook=6.0.3=py37_0 67 | - numpy=1.18.5=py37ha1c710e_0 68 | - numpy-base=1.18.5=py37hde5b4d6_0 69 | - olefile=0.46=py37_0 70 | - openssl=1.1.1g=h7b6447c_0 71 | - packaging=20.4=py_0 72 | - pandoc=2.9.2.1=0 73 | - pandocfilters=1.4.2=py37_1 74 | - parso=0.7.0=py_0 75 | - pcre=8.44=he6710b0_0 76 | 
- pexpect=4.8.0=py37_0 77 | - pickleshare=0.7.5=py37_0 78 | - pillow=7.2.0=py37hb39fc2d_0 79 | - pip=20.1.1=py37_1 80 | - prometheus_client=0.8.0=py_0 81 | - prompt-toolkit=3.0.5=py_0 82 | - ptyprocess=0.6.0=py37_0 83 | - pygments=2.6.1=py_0 84 | - pyparsing=2.4.7=py_0 85 | - pyqt=5.9.2=py37h05f1152_2 86 | - pyrsistent=0.16.0=py37h7b6447c_0 87 | - python=3.7.7=hcff3b4d_5 88 | - python-dateutil=2.8.1=py_0 89 | - pyzmq=19.0.1=py37he6710b0_1 90 | - qt=5.9.7=h5867ecd_1 91 | - readline=8.0=h7b6447c_0 92 | - send2trash=1.5.0=py37_0 93 | - setuptools=47.3.1=py37_0 94 | - sip=4.19.8=py37hf484d3e_0 95 | - six=1.15.0=py_0 96 | - sqlite=3.32.3=h62c20be_0 97 | - terminado=0.8.3=py37_0 98 | - testpath=0.4.4=py_0 99 | - tk=8.6.10=hbc83047_0 100 | - tornado=6.0.4=py37h7b6447c_1 101 | - traitlets=4.3.3=py37_0 102 | - wcwidth=0.2.5=py_0 103 | - webencodings=0.5.1=py37_1 104 | - werkzeug=1.0.1=py_0 105 | - wheel=0.34.2=py37_0 106 | - xz=5.2.5=h7b6447c_0 107 | - zeromq=4.3.2=he6710b0_2 108 | - zipp=3.1.0=py_0 109 | - zlib=1.2.11=h7b6447c_3 110 | - zstd=1.4.5=h0b5b093_0 111 | - pip: 112 | - chardet==3.0.4 113 | - cycler==0.10.0 114 | - filelock==3.0.12 115 | - future==0.18.2 116 | - gdown==3.11.1 117 | - idna==2.10 118 | - opencv-python==4.2.0.34 119 | - pysocks==1.7.1 120 | - requests==2.24.0 121 | - torch==1.5.1 122 | - tqdm==4.47.0 123 | - urllib3==1.25.9 124 | prefix: /home/kakarot/anaconda3/envs/Fba_matting 125 | 126 | -------------------------------------------------------------------------------- /API/Segmentation_API/segmentation.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torchvision.transforms as T 3 | import torch 4 | import numpy as np 5 | import cv2 6 | import requests 7 | 8 | 9 | 10 | model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) 11 | 12 | 13 | 14 | class Model: 15 | def __init__(self,confidence_thresh=0.6): 16 | self.model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) 17 | self.model.eval(); 18 | self.transform = T.Compose([T.ToTensor()]) 19 | self.conf_thresh = confidence_thresh 20 | 21 | 22 | def get_seg_output(self,image:np.array): 23 | image = self.transform(image.copy()) 24 | print(image.shape) 25 | with torch.no_grad(): 26 | pred = self.model([image]) 27 | 28 | outputs = [(pred[0]['masks'][i][0],pred[0]['labels'][i]) for i in range(len(pred[0]['boxes'])) if pred[0]['scores'][i]>self.conf_thresh and pred[0]['labels'][i]==1] 29 | # outputs = [(pred[0]['masks'][i][0],pred[0]['labels'][i]) for i in range(len(pred[0]['boxes'])) if pred[0]['scores'][i]>self.conf_thresh] 30 | 31 | return outputs 32 | 33 | 34 | 35 | class Preprocessing: 36 | def __init__(self,kernel,lower_bound=0.1,upper_bound=0.9,dilate_iter=10,erode_iter=10): 37 | self.kernel = kernel 38 | self.low_thresh = lower_bound 39 | self.high_thresh = upper_bound 40 | self.dilate_iter = dilate_iter 41 | self.erode_iter = erode_iter 42 | 43 | def get_target_mask(self,masks): 44 | out = np.zeros(masks[0].shape) 45 | for mask in masks: 46 | out += mask 47 | out = np.clip(out,0,1) 48 | return out 49 | 50 | def get_trimap(self,masks): 51 | target_mask = self.get_target_mask(masks) 52 | foreground = target_mask >= self.high_thresh 53 | ambiguous = (target_mask < self.high_thresh)*(target_mask>=self.low_thresh) 54 | print(self.erode_iter) 55 | erode = cv2.erode(foreground.astype('uint8'),self.kernel,iterations=self.erode_iter) 56 | dilate = cv2.dilate(ambiguous.astype('uint8'),self.kernel,iterations=self.dilate_iter) 
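# erode shrinks the confident foreground core, while dilate grows the ambiguous band; below, pixels inside the eroded core become definite foreground (channel 1) and pixels outside both regions become definite background (channel 0)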
57 | h, w = target_mask.shape 58 | 59 | bg_giver = np.clip((erode + dilate),0,1 ) 60 | trimap = np.zeros((h, w, 2)) 61 | trimap[erode == 1, 1] = 1 62 | trimap[bg_giver == 0, 0] = 1 63 | 64 | return trimap 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /API/app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, jsonify,request,url_for,redirect 2 | import sys 3 | sys.path.append('..') 4 | from networks.models import build_model 5 | from matting import pred 6 | import requests 7 | import numpy as np 8 | import cv2 9 | import base64 10 | from PIL import Image 11 | import io 12 | 13 | app = Flask(__name__) 14 | 15 | def get_array(arg): 16 | return np.array(arg).astype('uint8') 17 | 18 | class Matting_Args: 19 | def __init__(self): 20 | self.encoder = 'resnet50_GN_WS' 21 | self.decoder = 'fba_decoder' 22 | self.weights = '../models/FBA.pth' 23 | 24 | args = Matting_Args() 25 | 26 | matting_model = build_model(args) 27 | matting_model.eval(); 28 | 29 | def get_response(new_bg,data): 30 | image = get_array(data.get('image')) 31 | response = requests.post('http://127.0.0.1:3000/',json = data) 32 | if response.status_code == 406: 33 | return jsonify({'output':image.tolist()}) 34 | h,w,_ = image.shape 35 | trimap = get_array(response.json()['trimap']) 36 | fg, bg, alpha = pred(image/255.0,trimap,matting_model) 37 | combined = ((alpha[...,None]*image)).astype('uint8') + ((1-alpha)[...,None]*cv2.resize(new_bg,(w,h))).astype('uint8') 38 | return jsonify({'output':combined.tolist()}) 39 | 40 | @app.route('/with_bg',methods=['POST']) 41 | def extraction(): 42 | data = request.get_json() 43 | new_bg = get_array(data.get('bg')) 44 | return get_response(new_bg,data) 45 | 46 | @app.route('/',methods=["POST"]) 47 | def extraction_without_bg(): 48 | data = request.get_json() 49 | new_bg = cv2.imread('1.jpg')[:,:,::-1] 50 | return get_response(new_bg,data) 51 | 52 | 53 | if __name__ == '__main__': 54 | app.run(debug=True,threaded=True) 55 | -------------------------------------------------------------------------------- /API/matting.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | 4 | 5 | from networks.transforms import trimap_transform, groupnorm_normalise_image 6 | from networks.models import build_model 7 | from dataloader import PredDataset 8 | 9 | # System libs 10 | import os 11 | import argparse 12 | 13 | # External libs 14 | import cv2 15 | import numpy as np 16 | import torch 17 | 18 | 19 | def np_to_torch(x): 20 | return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float().cuda() 21 | 22 | 23 | def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray: 24 | ''' Scales inputs to multiple of 8. ''' 25 | h, w = x.shape[:2] 26 | h1 = int(np.ceil(scale * h / 8) * 8) 27 | w1 = int(np.ceil(scale * w / 8) * 8) 28 | x_scale = cv2.resize(x, (w1, h1), interpolation=scale_type) 29 | return x_scale 30 | 31 | 32 | 33 | def pred(image_np: np.ndarray, trimap_np: np.ndarray, model) -> np.ndarray: 34 | ''' Predict alpha, foreground and background. 35 | Parameters: 36 | image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3) 37 | trimap_np -- two channel trimap, first background then foreground. Dimensions: (h, w, 2) 38 | Returns: 39 | fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3) 40 | bg: background image in rgb format between 0 and 1. 
Dimensions: (h, w, 3) 41 | alpha: alpha matte image between 0 and 1. Dimensions: (h, w) 42 | ''' 43 | h, w = trimap_np.shape[:2] 44 | 45 | image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4) 46 | trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4) 47 | 48 | with torch.no_grad(): 49 | 50 | image_torch = np_to_torch(image_scale_np) 51 | trimap_torch = np_to_torch(trimap_scale_np) 52 | 53 | trimap_transformed_torch = np_to_torch(trimap_transform(trimap_scale_np)) 54 | image_transformed_torch = groupnorm_normalise_image(image_torch.clone(), format='nchw') 55 | 56 | output = model(image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch) 57 | 58 | output = cv2.resize(output[0].cpu().numpy().transpose((1, 2, 0)), (w, h), cv2.INTER_LANCZOS4) 59 | alpha = output[:, :, 0] 60 | fg = output[:, :, 1:4] 61 | bg = output[:, :, 4:7] 62 | 63 | alpha[trimap_np[:, :, 0] == 1] = 0 64 | alpha[trimap_np[:, :, 1] == 1] = 1 65 | fg[alpha == 1] = image_np[alpha == 1] 66 | bg[alpha == 0] = image_np[alpha == 0] 67 | return fg, bg, alpha 68 | 69 | -------------------------------------------------------------------------------- /API/requirements.yml: -------------------------------------------------------------------------------- 1 | name: fastai 2 | channels: 3 | - pytorch 4 | - fastai 5 | - anaconda 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - attrs=19.3.0=py_0 11 | - backcall=0.1.0=py37_0 12 | - beautifulsoup4=4.9.0=py37hc8dfbb8_0 13 | - blas=1.0=openblas 14 | - bleach=3.1.4=py_0 15 | - bottleneck=1.3.2=py37h03ebfcd_1 16 | - brotlipy=0.7.0=py37h8f50634_1000 17 | - bzip2=1.0.8=h7b6447c_0 18 | - c-ares=1.16.1=h516909a_0 19 | - ca-certificates=2020.6.20=hecda079_0 20 | - cachetools=4.1.1=py_0 21 | - cairo=1.14.12=h8948797_3 22 | - certifi=2020.6.20=py37hc8dfbb8_0 23 | - cffi=1.14.0=py37hd463f26_0 24 | - chardet=3.0.4=py37hc8dfbb8_1006 25 | - click=7.1.1=py_0 26 | - cryptography=2.8=py37hb09aad4_2 27 | - cudatoolkit=10.1.243=h6bb024c_0 28 | - cycler=0.10.0=py_2 29 | - cymem=2.0.2=py37he1b5a44_0 30 | - cython-blis=0.2.4=py37h516909a_1 31 | - dataclasses=0.6=py_0 32 | - dbus=1.13.6=he372182_0 33 | - decorator=4.4.2=py_0 34 | - defusedxml=0.6.0=py_0 35 | - entrypoints=0.3=py37_0 36 | - expat=2.2.9=he1b5a44_2 37 | - fastai=1.0.60=1 38 | - fastprogress=0.2.2=py_0 39 | - ffmpeg=4.0=hcdf2ecd_0 40 | - flask=1.1.1=py_1 41 | - fontconfig=2.13.1=he4413a7_1000 42 | - freeglut=3.0.0=hf484d3e_5 43 | - freetype=2.10.1=he06d7ca_0 44 | - gettext=0.19.8.1=hc5be6a0_1002 45 | - glib=2.64.2=h6f030ca_0 46 | - gmp=6.1.2=h6c8ec71_1 47 | - google-api-core-grpc=1.20.1=hc8dfbb8_0 48 | - google-cloud-vision=0.42.0=py37_0 49 | - graphite2=1.3.13=h23475e2_0 50 | - grpcio=1.30.0=py37hb0870dc_0 51 | - gst-plugins-base=1.14.5=h0935bb2_2 52 | - gstreamer=1.14.5=h36ae1b5_2 53 | - harfbuzz=1.8.8=hffaf4a1_0 54 | - hdf5=1.10.2=hba1933b_1 55 | - icu=58.2=hf484d3e_1000 56 | - importlib_metadata=1.5.0=py37_0 57 | - intel-openmp=2020.0=166 58 | - ipykernel=5.1.4=py37h39e3cac_0 59 | - ipython=7.13.0=py37h5ca1d4c_0 60 | - ipython_genutils=0.2.0=py37_0 61 | - ipywidgets=7.5.1=py_0 62 | - itsdangerous=1.1.0=py37_0 63 | - jasper=2.0.14=h07fcdf6_1 64 | - jedi=0.16.0=py37_1 65 | - jinja2=2.11.1=py_0 66 | - joblib=0.14.1=py_0 67 | - jpeg=9c=h14c3975_1001 68 | - jsonschema=3.2.0=py37_0 69 | - jupyter_client=6.1.2=py_0 70 | - jupyter_core=4.6.3=py37_0 71 | - kiwisolver=1.2.0=py37h99015e2_0 72 | - ld_impl_linux-64=2.33.1=h53a641e_7 73 | - libblas=3.8.0=14_openblas 74 | - 
libcblas=3.8.0=14_openblas 75 | - libedit=3.1.20181209=hc058e9b_0 76 | - libffi=3.2.1=hd88cf55_4 77 | - libgcc=7.2.0=h69d50b8_2 78 | - libgcc-ng=9.1.0=hdf63c60_0 79 | - libgfortran-ng=7.3.0=hdf63c60_5 80 | - libglu=9.0.0=hf484d3e_1 81 | - libiconv=1.15=h516909a_1006 82 | - liblapack=3.8.0=14_openblas 83 | - libopenblas=0.3.7=h5ec1e0e_6 84 | - libopencv=3.4.2=hb342d67_1 85 | - libopus=1.3=h7b6447c_0 86 | - libpng=1.6.37=hed695b0_1 87 | - libprotobuf=3.12.3=h8b12597_2 88 | - libsodium=1.0.16=h1bed415_0 89 | - libstdcxx-ng=9.1.0=hdf63c60_0 90 | - libtiff=4.1.0=hc7e4089_6 91 | - libuuid=2.32.1=h14c3975_1000 92 | - libvpx=1.7.0=h439df22_0 93 | - libwebp-base=1.1.0=h516909a_3 94 | - libxcb=1.13=h14c3975_1002 95 | - libxml2=2.9.9=hea5a465_1 96 | - lz4-c=1.9.2=he1b5a44_0 97 | - markupsafe=1.1.1=py37h7b6447c_0 98 | - matplotlib=3.1.3=py37_0 99 | - matplotlib-base=3.1.3=py37hef1b27d_0 100 | - mistune=0.8.4=py37h7b6447c_0 101 | - mkl=2020.0=166 102 | - murmurhash=1.0.0=py37h3340039_0 103 | - nb_conda_kernels=2.2.3=py37_0 104 | - nbconvert=5.6.1=py37_0 105 | - nbformat=5.0.4=py_0 106 | - ncurses=6.2=he6710b0_0 107 | - ninja=1.10.0=hc9558a2_0 108 | - nodejs=6.13.1=0 109 | - notebook=6.0.3=py37_0 110 | - numexpr=2.7.1=py37h0da4684_1 111 | - numpy=1.18.1=py37h8960a57_1 112 | - olefile=0.46=py_0 113 | - opencv=3.4.2=py37h6fd60c2_1 114 | - openssl=1.1.1g=h516909a_0 115 | - packaging=20.1=py_0 116 | - pandas=1.0.3=py37h0da4684_1 117 | - pandoc=2.2.3.2=0 118 | - pandocfilters=1.4.2=py37_1 119 | - parso=0.6.2=py_0 120 | - patsy=0.5.1=py_0 121 | - pcre=8.44=he1b5a44_0 122 | - pexpect=4.8.0=py37_0 123 | - pickleshare=0.7.5=py37_0 124 | - pillow=7.0.0=py37hb39fc2d_0 125 | - pip=20.0.2=py37_1 126 | - pixman=0.38.0=h7b6447c_0 127 | - plac=0.9.6=py37_0 128 | - preshed=2.0.1=py37he1b5a44_0 129 | - prometheus_client=0.7.1=py_0 130 | - prompt-toolkit=3.0.4=py_0 131 | - prompt_toolkit=3.0.4=0 132 | - pthread-stubs=0.4=h14c3975_1001 133 | - ptyprocess=0.6.0=py37_0 134 | - py-opencv=3.4.2=py37hb342d67_1 135 | - pyasn1=0.4.8=py_0 136 | - pycparser=2.20=py_0 137 | - pygments=2.6.1=py_0 138 | - pyopenssl=19.1.0=py_1 139 | - pyparsing=2.4.7=pyh9f0ad1d_0 140 | - pyqt=5.9.2=py37hcca6a23_4 141 | - pyrsistent=0.16.0=py37h7b6447c_0 142 | - pysocks=1.7.1=py37hc8dfbb8_1 143 | - python=3.7.7=hcf32534_0_cpython 144 | - python-dateutil=2.8.1=py_0 145 | - python_abi=3.7=1_cp37m 146 | - pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0 147 | - pyyaml=5.3.1=py37h8f50634_0 148 | - pyzmq=18.1.1=py37he6710b0_0 149 | - qt=5.9.7=h52cfd70_2 150 | - readline=8.0=h7b6447c_0 151 | - rsa=4.6=pyh9f0ad1d_0 152 | - scikit-learn=0.22.1=py37h22eb022_0 153 | - scipy=1.4.1=py37ha3d9a3c_3 154 | - seaborn=0.10.0=py_1 155 | - send2trash=1.5.0=py37_0 156 | - setuptools=46.1.3=py37_0 157 | - sip=4.19.8=py37hf484d3e_0 158 | - soupsieve=1.9.4=py37hc8dfbb8_1 159 | - spacy=2.1.8=py37hc9558a2_0 160 | - sqlite=3.31.1=h7b6447c_0 161 | - srsly=0.1.0=py37he1b5a44_0 162 | - statsmodels=0.11.1=py37h8f50634_1 163 | - terminado=0.8.3=py37_0 164 | - testpath=0.4.4=py_0 165 | - thinc=7.0.8=py37hc9558a2_0 166 | - tk=8.6.8=hbc83047_0 167 | - torchvision=0.5.0=py37_cu101 168 | - tornado=6.0.4=py37h7b6447c_1 169 | - tqdm=4.47.0=py_0 170 | - traitlets=4.3.3=py37_0 171 | - urllib3=1.25.9=py_0 172 | - wasabi=0.2.2=py_0 173 | - wcwidth=0.1.9=py_0 174 | - webencodings=0.5.1=py37_1 175 | - werkzeug=1.0.1=py_0 176 | - wheel=0.34.2=py37_0 177 | - widgetsnbextension=3.5.1=py37_0 178 | - xlrd=1.2.0=py37_0 179 | - xorg-libxau=1.0.9=h14c3975_0 180 | - xorg-libxdmcp=1.1.3=h516909a_0 181 | - 
xz=5.2.5=h7b6447c_0 182 | - yaml=0.2.3=h516909a_0 183 | - zeromq=4.3.1=he6710b0_3 184 | - zipp=2.2.0=py_0 185 | - zlib=1.2.11=h7b6447c_3 186 | - zstd=1.4.4=h6597ccf_3 187 | - pip: 188 | - blessings==1.7 189 | - colorthief==0.2.1 190 | - easyprocess==0.2.10 191 | - flask-cors==3.0.8 192 | - google-api-core==1.20.1 193 | - google-auth==1.19.1 194 | - googleapis-common-protos==1.52.0 195 | - gpustat==0.6.0 196 | - idna==2.10 197 | - nvidia-ml-py3==7.352.0 198 | - protobuf==3.12.2 199 | - psutil==5.7.0 200 | - pyasn1-modules==0.2.8 201 | - pytz==2020.1 202 | - pyunpack==0.1.2 203 | - requests==2.24.0 204 | - six==1.15.0 205 | - torch-summary==1.3.3 206 | prefix: /home/kakarot/anaconda3/envs/fastai 207 | 208 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Marco Forte 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Background Removal 3 | 4 | 5 |
6 | 7 |
8 | 9 | ## Requirements 10 | GPU memory >= 11GB is required for inference on the Adobe Composition-1K testing set, and more generally for resolutions above 1920x1080. 11 | 12 | #### Packages: 13 | - torch >= 1.4 14 | - numpy 15 | - opencv-python 16 | #### Additional packages for the Jupyter notebook 17 | - matplotlib 18 | 19 | 20 | 21 | ## Prediction 22 | There is a script, `inference.py`, which posts the provided image to the running API and saves the returned image with its background removed. 23 | 24 | ## Citation 25 | 26 | Original GitHub repository of FBA Matting: [link](https://github.com/MarcoForte/FBA_Matting) 27 | ``` 28 | @article{forte2020fbamatting, 29 | title = {F, B, Alpha Matting}, 30 | author = {Marco Forte and François Pitié}, 31 | journal = {CoRR}, 32 | volume = {abs/2003.07711}, 33 | year = {2020}, 34 | } 35 | ``` 36 | 37 | -------------------------------------------------------------------------------- /Videos/frienvengers.avi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/Videos/frienvengers.avi -------------------------------------------------------------------------------- /Videos/frienvengers_using_detectron.avi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/Videos/frienvengers_using_detectron.avi -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import cv2 4 | import os 5 | 6 | 7 | class PredDataset(Dataset): 8 | ''' Reads image and trimap pairs from folder. 9 | 10 | ''' 11 | 12 | def __init__(self, img_dir, trimap_dir): 13 | self.img_dir, self.trimap_dir = img_dir, trimap_dir 14 | self.img_names = [x for x in os.listdir(self.img_dir) if 'png' in x] 15 | 16 | def __len__(self): 17 | return len(self.img_names) 18 | 19 | def __getitem__(self, idx): 20 | img_name = self.img_names[idx] 21 | 22 | image = read_image(os.path.join(self.img_dir, img_name)) 23 | trimap = read_trimap(os.path.join(self.trimap_dir, img_name)) 24 | pred_dict = {'image': image, 'trimap': trimap, 'name': img_name} 25 | 26 | return pred_dict 27 | 28 | 29 | def read_image(name): 30 | return (cv2.imread(name) / 255.0)[:, :, ::-1] 31 | 32 | 33 | def read_trimap(name): 34 | trimap_im = cv2.imread(name, 0) / 255.0 35 | h, w = trimap_im.shape 36 | trimap = np.zeros((h, w, 2)) 37 | trimap[trimap_im == 1, 1] = 1 38 | trimap[trimap_im == 0, 0] = 1 39 | return trimap 40 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # Our libs 2 | from networks.transforms import trimap_transform, groupnorm_normalise_image 3 | from networks.models import build_model 4 | from dataloader import PredDataset 5 | 6 | # System libs 7 | import os 8 | import argparse 9 | 10 | # External libs 11 | import cv2 12 | import numpy as np 13 | import torch 14 | 15 | 16 | def np_to_torch(x): 17 | return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float().cuda() 18 | 19 | 20 | def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray: 21 | ''' Scales inputs to multiple of 8.
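(The encoder is a dilated ResNet with output stride 8, so both dimensions are rounded up until divisible by 8.)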
''' 22 | h, w = x.shape[:2] 23 | h1 = int(np.ceil(scale * h / 8) * 8) 24 | w1 = int(np.ceil(scale * w / 8) * 8) 25 | x_scale = cv2.resize(x, (w1, h1), interpolation=scale_type) 26 | return x_scale 27 | 28 | 29 | def predict_fba_folder(model, args): 30 | save_dir = args.output_dir 31 | 32 | dataset_test = PredDataset(args.image_dir, args.trimap_dir) 33 | 34 | gen = iter(dataset_test) 35 | for item_dict in gen: 36 | image_np = item_dict['image'] 37 | trimap_np = item_dict['trimap'] 38 | 39 | fg, bg, alpha = pred(image_np, trimap_np, model) 40 | 41 | cv2.imwrite(os.path.join(save_dir, item_dict['name'][:-4] + '_fg.png'), fg[:, :, ::-1] * 255) 42 | cv2.imwrite(os.path.join(save_dir, item_dict['name'][:-4] + '_bg.png'), bg[:, :, ::-1] * 255) 43 | cv2.imwrite(os.path.join(save_dir, item_dict['name'][:-4] + '_alpha.png'), alpha * 255) 44 | 45 | 46 | def pred(image_np: np.ndarray, trimap_np: np.ndarray, model) -> np.ndarray: 47 | ''' Predict alpha, foreground and background. 48 | Parameters: 49 | image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3) 50 | trimap_np -- two channel trimap, first background then foreground. Dimensions: (h, w, 2) 51 | Returns: 52 | fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3) 53 | bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3) 54 | alpha: alpha matte image between 0 and 1. Dimensions: (h, w) 55 | ''' 56 | h, w = trimap_np.shape[:2] 57 | 58 | image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4) 59 | trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4) 60 | 61 | with torch.no_grad(): 62 | 63 | image_torch = np_to_torch(image_scale_np) 64 | trimap_torch = np_to_torch(trimap_scale_np) 65 | 66 | trimap_transformed_torch = np_to_torch(trimap_transform(trimap_scale_np)) 67 | image_transformed_torch = groupnorm_normalise_image(image_torch.clone(), format='nchw') 68 | 69 | output = model(image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch) 70 | 71 | output = cv2.resize(output[0].cpu().numpy().transpose((1, 2, 0)), (w, h), cv2.INTER_LANCZOS4) 72 | alpha = output[:, :, 0] 73 | fg = output[:, :, 1:4] 74 | bg = output[:, :, 4:7] 75 | 76 | alpha[trimap_np[:, :, 0] == 1] = 0 77 | alpha[trimap_np[:, :, 1] == 1] = 1 78 | fg[alpha == 1] = image_np[alpha == 1] 79 | bg[alpha == 0] = image_np[alpha == 0] 80 | return fg, bg, alpha 81 | 82 | 83 | if __name__ == '__main__': 84 | 85 | parser = argparse.ArgumentParser() 86 | # Model related arguments 87 | parser.add_argument('--encoder', default='resnet50_GN_WS', help="encoder model") 88 | parser.add_argument('--decoder', default='fba_decoder', help="Decoder model") 89 | parser.add_argument('--weights', default='FBA.pth') 90 | parser.add_argument('--image_dir', default='./examples/images', help="") 91 | parser.add_argument('--trimap_dir', default='./examples/trimaps', help="") 92 | parser.add_argument('--output_dir', default='./examples/predictions', help="") 93 | 94 | args = parser.parse_args() 95 | model = build_model(args) 96 | model.eval() 97 | predict_fba_folder(model, args) 98 | -------------------------------------------------------------------------------- /examples/example_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/example_results.png -------------------------------------------------------------------------------- /examples/images/elephant.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/images/elephant.png -------------------------------------------------------------------------------- /examples/images/troll.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/images/troll.png -------------------------------------------------------------------------------- /examples/predictions/elephant_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/elephant_alpha.png -------------------------------------------------------------------------------- /examples/predictions/elephant_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/elephant_bg.png -------------------------------------------------------------------------------- /examples/predictions/elephant_fg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/elephant_fg.png -------------------------------------------------------------------------------- /examples/predictions/troll_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/troll_alpha.png -------------------------------------------------------------------------------- /examples/predictions/troll_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/troll_bg.png -------------------------------------------------------------------------------- /examples/predictions/troll_fg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/troll_fg.png -------------------------------------------------------------------------------- /examples/trimaps/elephant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/trimaps/elephant.png -------------------------------------------------------------------------------- /examples/trimaps/troll.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/trimaps/troll.png -------------------------------------------------------------------------------- /frienvengers.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | 
"import requests\n", 10 | "import pickle\n", 11 | "import numpy as np\n", 12 | "import cv2\n", 13 | "import matplotlib.pyplot as plt" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "cap1 = cv2.VideoCapture('Videos/friends.avi') #friends_video " 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "cap2 = cv2.VideoCapture('Videos/avengers.avi')#avengers_video" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 4, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "fourcc = cv2.VideoWriter_fourcc(*'XVID')\n", 41 | "writer = cv2.VideoWriter(f'Videos/frienvengers_using_detectron.avi', fourcc, 25.0, (360,240)) #output video" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "scrolled": true 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "ret1,org_fr = cap1.read()\n", 53 | "ret2,new_bg = cap2.read()\n", 54 | "\n", 55 | "count = 0\n", 56 | "while(ret1 and ret2):\n", 57 | " print(count)\n", 58 | " \n", 59 | " org_fr = cv2.resize(org_fr,(360,240))[:,:,::-1]\n", 60 | " new_bg = cv2.resize(new_bg,(360,240))[:,:,::-1]\n", 61 | " response = requests.post('http://127.0.0.1:5000/with_bg',json={'image':org_fr.tolist(),'bg':new_bg.tolist()})\n", 62 | " combined = np.array(response.json()['output'])\n", 63 | " count+=1\n", 64 | " ret1,org_fr = cap1.read()\n", 65 | " ret2,new_bg = cap2.read()\n", 66 | " writer.write(combined[:,:,::-1].astype('uint8'))" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [] 75 | } 76 | ], 77 | "metadata": { 78 | "kernelspec": { 79 | "display_name": "Python [conda env:Fba_matting]", 80 | "language": "python", 81 | "name": "conda-env-Fba_matting-py" 82 | }, 83 | "language_info": { 84 | "codemirror_mode": { 85 | "name": "ipython", 86 | "version": 3 87 | }, 88 | "file_extension": ".py", 89 | "mimetype": "text/x-python", 90 | "name": "python", 91 | "nbconvert_exporter": "python", 92 | "pygments_lexer": "ipython3", 93 | "version": "3.7.7" 94 | } 95 | }, 96 | "nbformat": 4, 97 | "nbformat_minor": 4 98 | } 99 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import argparse 4 | import requests 5 | import matplotlib.pyplot as plt 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument('--image_path',type=str) 11 | parser.add_argument('--output_path', type=str) 12 | args = parser.parse_args() 13 | img_path = args.image_path 14 | 15 | out_path = args.output_path 16 | img = cv2.imread(img_path)[:,:,::-1] 17 | # print(img.shape) 18 | response = requests.post('http://127.0.0.1:5000/',json={'image':img.tolist()}) 19 | out = np.array(response.json()['output']) 20 | plt.imsave(out_path+'/out.jpg',out.astype('uint8')) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from networks.models import build_model 2 | 3 | 4 | class Args: 5 | def __init__(self): 6 | self.encoder = 'resnet50_GN_WS' 7 | self.decoder = 'fba_decoder' 8 | self.weights = 'FBA.pth' 9 | 10 | args = Args() 11 | Model = build_model(args) 
-------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/networks/__init__.py -------------------------------------------------------------------------------- /networks/layers_WS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class Conv2d(nn.Conv2d): 7 | 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 9 | padding=0, dilation=1, groups=1, bias=True): 10 | super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, 11 | padding, dilation, groups, bias) 12 | 13 | def forward(self, x): 14 | # return super(Conv2d, self).forward(x) 15 | weight = self.weight 16 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, 17 | keepdim=True).mean(dim=3, keepdim=True) 18 | weight = weight - weight_mean 19 | # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 20 | std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1) + 1e-5 21 | weight = weight / std.expand_as(weight) 22 | return F.conv2d(x, weight, self.bias, self.stride, 23 | self.padding, self.dilation, self.groups) 24 | 25 | 26 | def BatchNorm2d(num_features): 27 | return nn.GroupNorm(num_channels=num_features, num_groups=32) 28 | -------------------------------------------------------------------------------- /networks/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import networks.resnet_GN_WS as resnet_GN_WS 4 | import networks.layers_WS as L 5 | import networks.resnet_bn as resnet_bn 6 | 7 | 8 | def build_model(args): 9 | builder = ModelBuilder() 10 | net_encoder = builder.build_encoder(arch=args.encoder) 11 | 12 | if('BN' in args.encoder): 13 | batch_norm = True 14 | else: 15 | batch_norm = False 16 | net_decoder = builder.build_decoder(arch=args.decoder, batch_norm=batch_norm) 17 | 18 | model = MattingModule(net_encoder, net_decoder) 19 | 20 | model.cuda() 21 | 22 | if(args.weights != 'default'): 23 | sd = torch.load(args.weights) 24 | model.load_state_dict(sd, strict=True) 25 | 26 | return model 27 | 28 | 29 | class MattingModule(nn.Module): 30 | def __init__(self, net_enc, net_dec): 31 | super(MattingModule, self).__init__() 32 | self.encoder = net_enc 33 | self.decoder = net_dec 34 | 35 | def forward(self, image, two_chan_trimap, image_n, trimap_transformed): 36 | resnet_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1) 37 | conv_out, indices = self.encoder(resnet_input, return_feature_maps=True) 38 | return self.decoder(conv_out, image, indices, two_chan_trimap) 39 | 40 | 41 | class ModelBuilder(): 42 | def build_encoder(self, arch='resnet50_GN'): 43 | if arch == 'resnet50_GN_WS': 44 | orig_resnet = resnet_GN_WS.__dict__['l_resnet50']() 45 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) 46 | elif arch == 'resnet50_BN': 47 | orig_resnet = resnet_bn.__dict__['l_resnet50']() 48 | net_encoder = ResnetDilatedBN(orig_resnet, dilate_scale=8) 49 | 50 | else: 51 | raise Exception('Architecture undefined!') 52 | 53 | num_channels = 3 + 6 + 2 54 | 55 | if(num_channels > 3): 56 | print(f'modifying input layer to accept {num_channels} channels') 57 | net_encoder_sd 
= net_encoder.state_dict() 58 | conv1_weights = net_encoder_sd['conv1.weight'] 59 | 60 | c_out, c_in, h, w = conv1_weights.size() 61 | conv1_mod = torch.zeros(c_out, num_channels, h, w) 62 | conv1_mod[:, :3, :, :] = conv1_weights 63 | 64 | conv1 = net_encoder.conv1 65 | conv1.in_channels = num_channels 66 | conv1.weight = torch.nn.Parameter(conv1_mod) 67 | 68 | net_encoder.conv1 = conv1 69 | 70 | net_encoder_sd['conv1.weight'] = conv1_mod 71 | 72 | net_encoder.load_state_dict(net_encoder_sd) 73 | return net_encoder 74 | 75 | def build_decoder(self, arch='fba_decoder', batch_norm=False): 76 | if arch == 'fba_decoder': 77 | net_decoder = fba_decoder(batch_norm=batch_norm) 78 | 79 | return net_decoder 80 | 81 | 82 | class ResnetDilatedBN(nn.Module): 83 | def __init__(self, orig_resnet, dilate_scale=8): 84 | super(ResnetDilatedBN, self).__init__() 85 | from functools import partial 86 | 87 | if dilate_scale == 8: 88 | orig_resnet.layer3.apply( 89 | partial(self._nostride_dilate, dilate=2)) 90 | orig_resnet.layer4.apply( 91 | partial(self._nostride_dilate, dilate=4)) 92 | elif dilate_scale == 16: 93 | orig_resnet.layer4.apply( 94 | partial(self._nostride_dilate, dilate=2)) 95 | 96 | # take pretrained resnet, except AvgPool and FC 97 | self.conv1 = orig_resnet.conv1 98 | self.bn1 = orig_resnet.bn1 99 | self.relu1 = orig_resnet.relu1 100 | self.conv2 = orig_resnet.conv2 101 | self.bn2 = orig_resnet.bn2 102 | self.relu2 = orig_resnet.relu2 103 | self.conv3 = orig_resnet.conv3 104 | self.bn3 = orig_resnet.bn3 105 | self.relu3 = orig_resnet.relu3 106 | self.maxpool = orig_resnet.maxpool 107 | self.layer1 = orig_resnet.layer1 108 | self.layer2 = orig_resnet.layer2 109 | self.layer3 = orig_resnet.layer3 110 | self.layer4 = orig_resnet.layer4 111 | 112 | def _nostride_dilate(self, m, dilate): 113 | classname = m.__class__.__name__ 114 | if classname.find('Conv') != -1: 115 | # the convolution with stride 116 | if m.stride == (2, 2): 117 | m.stride = (1, 1) 118 | if m.kernel_size == (3, 3): 119 | m.dilation = (dilate // 2, dilate // 2) 120 | m.padding = (dilate // 2, dilate // 2) 121 | # other convoluions 122 | else: 123 | if m.kernel_size == (3, 3): 124 | m.dilation = (dilate, dilate) 125 | m.padding = (dilate, dilate) 126 | 127 | def forward(self, x, return_feature_maps=False): 128 | conv_out = [x] 129 | x = self.relu1(self.bn1(self.conv1(x))) 130 | x = self.relu2(self.bn2(self.conv2(x))) 131 | x = self.relu3(self.bn3(self.conv3(x))) 132 | conv_out.append(x) 133 | x, indices = self.maxpool(x) 134 | x = self.layer1(x) 135 | conv_out.append(x) 136 | x = self.layer2(x) 137 | conv_out.append(x) 138 | x = self.layer3(x) 139 | conv_out.append(x) 140 | x = self.layer4(x) 141 | conv_out.append(x) 142 | 143 | if return_feature_maps: 144 | return conv_out, indices 145 | return [x] 146 | 147 | 148 | class Resnet(nn.Module): 149 | def __init__(self, orig_resnet): 150 | super(Resnet, self).__init__() 151 | 152 | # take pretrained resnet, except AvgPool and FC 153 | self.conv1 = orig_resnet.conv1 154 | self.bn1 = orig_resnet.bn1 155 | self.relu1 = orig_resnet.relu1 156 | self.conv2 = orig_resnet.conv2 157 | self.bn2 = orig_resnet.bn2 158 | self.relu2 = orig_resnet.relu2 159 | self.conv3 = orig_resnet.conv3 160 | self.bn3 = orig_resnet.bn3 161 | self.relu3 = orig_resnet.relu3 162 | self.maxpool = orig_resnet.maxpool 163 | self.layer1 = orig_resnet.layer1 164 | self.layer2 = orig_resnet.layer2 165 | self.layer3 = orig_resnet.layer3 166 | self.layer4 = orig_resnet.layer4 167 | 168 | def forward(self, x, 
return_feature_maps=False): 169 | conv_out = [] 170 | 171 | x = self.relu1(self.bn1(self.conv1(x))) 172 | x = self.relu2(self.bn2(self.conv2(x))) 173 | x = self.relu3(self.bn3(self.conv3(x))) 174 | conv_out.append(x) 175 | x, indices = self.maxpool(x) 176 | 177 | x = self.layer1(x) 178 | conv_out.append(x) 179 | x = self.layer2(x) 180 | conv_out.append(x) 181 | x = self.layer3(x) 182 | conv_out.append(x) 183 | x = self.layer4(x) 184 | conv_out.append(x) 185 | 186 | if return_feature_maps: 187 | return conv_out 188 | return [x] 189 | 190 | 191 | class ResnetDilated(nn.Module): 192 | def __init__(self, orig_resnet, dilate_scale=8): 193 | super(ResnetDilated, self).__init__() 194 | from functools import partial 195 | 196 | if dilate_scale == 8: 197 | orig_resnet.layer3.apply( 198 | partial(self._nostride_dilate, dilate=2)) 199 | orig_resnet.layer4.apply( 200 | partial(self._nostride_dilate, dilate=4)) 201 | elif dilate_scale == 16: 202 | orig_resnet.layer4.apply( 203 | partial(self._nostride_dilate, dilate=2)) 204 | 205 | # take pretrained resnet, except AvgPool and FC 206 | self.conv1 = orig_resnet.conv1 207 | self.bn1 = orig_resnet.bn1 208 | self.relu = orig_resnet.relu 209 | self.maxpool = orig_resnet.maxpool 210 | self.layer1 = orig_resnet.layer1 211 | self.layer2 = orig_resnet.layer2 212 | self.layer3 = orig_resnet.layer3 213 | self.layer4 = orig_resnet.layer4 214 | 215 | def _nostride_dilate(self, m, dilate): 216 | classname = m.__class__.__name__ 217 | if classname.find('Conv') != -1: 218 | # the convolution with stride 219 | if m.stride == (2, 2): 220 | m.stride = (1, 1) 221 | if m.kernel_size == (3, 3): 222 | m.dilation = (dilate // 2, dilate // 2) 223 | m.padding = (dilate // 2, dilate // 2) 224 | # other convoluions 225 | else: 226 | if m.kernel_size == (3, 3): 227 | m.dilation = (dilate, dilate) 228 | m.padding = (dilate, dilate) 229 | 230 | def forward(self, x, return_feature_maps=False): 231 | conv_out = [x] 232 | x = self.relu(self.bn1(self.conv1(x))) 233 | conv_out.append(x) 234 | x, indices = self.maxpool(x) 235 | x = self.layer1(x) 236 | conv_out.append(x) 237 | x = self.layer2(x) 238 | conv_out.append(x) 239 | x = self.layer3(x) 240 | conv_out.append(x) 241 | x = self.layer4(x) 242 | conv_out.append(x) 243 | 244 | if return_feature_maps: 245 | return conv_out, indices 246 | return [x] 247 | 248 | 249 | def norm(dim, bn=False): 250 | if(bn is False): 251 | return nn.GroupNorm(32, dim) 252 | else: 253 | return nn.BatchNorm2d(dim) 254 | 255 | 256 | def fba_fusion(alpha, img, F, B): 257 | F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B)) 258 | B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F) 259 | 260 | F = torch.clamp(F, 0, 1) 261 | B = torch.clamp(B, 0, 1) 262 | la = 0.1 263 | alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la) 264 | alpha = torch.clamp(alpha, 0, 1) 265 | return alpha, F, B 266 | 267 | 268 | class fba_decoder(nn.Module): 269 | def __init__(self, batch_norm=False): 270 | super(fba_decoder, self).__init__() 271 | pool_scales = (1, 2, 3, 6) 272 | self.batch_norm = batch_norm 273 | 274 | self.ppm = [] 275 | 276 | for scale in pool_scales: 277 | self.ppm.append(nn.Sequential( 278 | nn.AdaptiveAvgPool2d(scale), 279 | L.Conv2d(2048, 256, kernel_size=1, bias=True), 280 | norm(256, self.batch_norm), 281 | nn.LeakyReLU() 282 | )) 283 | self.ppm = nn.ModuleList(self.ppm) 284 | 285 | self.conv_up1 = nn.Sequential( 286 | L.Conv2d(2048 + len(pool_scales) 
* 256, 256, 287 | kernel_size=3, padding=1, bias=True), 288 | 289 | norm(256, self.batch_norm), 290 | nn.LeakyReLU(), 291 | L.Conv2d(256, 256, kernel_size=3, padding=1), 292 | norm(256, self.batch_norm), 293 | nn.LeakyReLU() 294 | ) 295 | 296 | self.conv_up2 = nn.Sequential( 297 | L.Conv2d(256 + 256, 256, 298 | kernel_size=3, padding=1, bias=True), 299 | norm(256, self.batch_norm), 300 | nn.LeakyReLU() 301 | ) 302 | if(self.batch_norm): 303 | d_up3 = 128 304 | else: 305 | d_up3 = 64 306 | self.conv_up3 = nn.Sequential( 307 | L.Conv2d(256 + d_up3, 64, 308 | kernel_size=3, padding=1, bias=True), 309 | norm(64, self.batch_norm), 310 | nn.LeakyReLU() 311 | ) 312 | 313 | self.unpool = nn.MaxUnpool2d(2, stride=2) 314 | 315 | self.conv_up4 = nn.Sequential( 316 | nn.Conv2d(64 + 3 + 3 + 2, 32, 317 | kernel_size=3, padding=1, bias=True), 318 | nn.LeakyReLU(), 319 | nn.Conv2d(32, 16, 320 | kernel_size=3, padding=1, bias=True), 321 | 322 | nn.LeakyReLU(), 323 | nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True) 324 | ) 325 | 326 | def forward(self, conv_out, img, indices, two_chan_trimap): 327 | conv5 = conv_out[-1] 328 | 329 | input_size = conv5.size() 330 | ppm_out = [conv5] 331 | for pool_scale in self.ppm: 332 | ppm_out.append(nn.functional.interpolate( 333 | pool_scale(conv5), 334 | (input_size[2], input_size[3]), 335 | mode='bilinear', align_corners=False)) 336 | ppm_out = torch.cat(ppm_out, 1) 337 | x = self.conv_up1(ppm_out) 338 | 339 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 340 | 341 | x = torch.cat((x, conv_out[-4]), 1) 342 | 343 | x = self.conv_up2(x) 344 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 345 | 346 | x = torch.cat((x, conv_out[-5]), 1) 347 | x = self.conv_up3(x) 348 | 349 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) 350 | x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1) 351 | 352 | output = self.conv_up4(x) 353 | 354 | alpha = torch.clamp(output[:, 0][:, None], 0, 1) 355 | F = torch.sigmoid(output[:, 1:4]) 356 | B = torch.sigmoid(output[:, 4:7]) 357 | 358 | # FBA Fusion 359 | alpha, F, B = fba_fusion(alpha, img, F, B) 360 | 361 | output = torch.cat((alpha, F, B), 1) 362 | 363 | return output 364 | -------------------------------------------------------------------------------- /networks/resnet_GN_WS.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import networks.layers_WS as L 3 | 4 | __all__ = ['ResNet', 'l_resnet50'] 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return L.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | def conv1x1(in_planes, out_planes, stride=1): 14 | """1x1 convolution""" 15 | return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(inplanes, planes, stride) 24 | self.bn1 = L.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes) 27 | self.bn2 = L.BatchNorm2d(planes) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | identity = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = 
self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | identity = self.downsample(x) 43 | 44 | out += identity 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | 50 | class Bottleneck(nn.Module): 51 | expansion = 4 52 | 53 | def __init__(self, inplanes, planes, stride=1, downsample=None): 54 | super(Bottleneck, self).__init__() 55 | self.conv1 = conv1x1(inplanes, planes) 56 | self.bn1 = L.BatchNorm2d(planes) 57 | self.conv2 = conv3x3(planes, planes, stride) 58 | self.bn2 = L.BatchNorm2d(planes) 59 | self.conv3 = conv1x1(planes, planes * self.expansion) 60 | self.bn3 = L.BatchNorm2d(planes * self.expansion) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.downsample = downsample 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | identity = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv3(out) 77 | out = self.bn3(out) 78 | 79 | if self.downsample is not None: 80 | identity = self.downsample(x) 81 | 82 | out += identity 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | 88 | class ResNet(nn.Module): 89 | 90 | def __init__(self, block, layers, num_classes=1000): 91 | super(ResNet, self).__init__() 92 | self.inplanes = 64 93 | self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 94 | bias=False) 95 | self.bn1 = L.BatchNorm2d(64) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True) 98 | self.layer1 = self._make_layer(block, 64, layers[0]) 99 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 100 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 101 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 102 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 103 | self.fc = nn.Linear(512 * block.expansion, num_classes) 104 | 105 | def _make_layer(self, block, planes, blocks, stride=1): 106 | downsample = None 107 | if stride != 1 or self.inplanes != planes * block.expansion: 108 | downsample = nn.Sequential( 109 | conv1x1(self.inplanes, planes * block.expansion, stride), 110 | L.BatchNorm2d(planes * block.expansion), 111 | ) 112 | 113 | layers = [] 114 | layers.append(block(self.inplanes, planes, stride, downsample)) 115 | self.inplanes = planes * block.expansion 116 | for _ in range(1, blocks): 117 | layers.append(block(self.inplanes, planes)) 118 | 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | x = self.conv1(x) 123 | x = self.bn1(x) 124 | x = self.relu(x) 125 | x = self.maxpool(x) 126 | 127 | x = self.layer1(x) 128 | x = self.layer2(x) 129 | x = self.layer3(x) 130 | x = self.layer4(x) 131 | 132 | x = self.avgpool(x) 133 | x = x.view(x.size(0), -1) 134 | x = self.fc(x) 135 | 136 | return x 137 | 138 | 139 | def l_resnet50(pretrained=False, **kwargs): 140 | """Constructs a ResNet-50 model. 
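Uses Weight-Standardised convolutions and GroupNorm (via layers_WS.py) in place of plain Conv2d + BatchNorm.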
/networks/resnet_bn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | from torch.nn import BatchNorm2d
4 | 
5 | __all__ = ['ResNet']
6 | 
7 | 
8 | def conv3x3(in_planes, out_planes, stride=1):
9 |     "3x3 convolution with padding"
10 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 |                      padding=1, bias=False)
12 | 
13 | 
14 | class BasicBlock(nn.Module):
15 |     expansion = 1
16 | 
17 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
18 |         super(BasicBlock, self).__init__()
19 |         self.conv1 = conv3x3(inplanes, planes, stride)
20 |         self.bn1 = BatchNorm2d(planes)
21 |         self.relu = nn.ReLU(inplace=True)
22 |         self.conv2 = conv3x3(planes, planes)
23 |         self.bn2 = BatchNorm2d(planes)
24 |         self.downsample = downsample
25 |         self.stride = stride
26 | 
27 |     def forward(self, x):
28 |         residual = x
29 | 
30 |         out = self.conv1(x)
31 |         out = self.bn1(out)
32 |         out = self.relu(out)
33 | 
34 |         out = self.conv2(out)
35 |         out = self.bn2(out)
36 | 
37 |         if self.downsample is not None:
38 |             residual = self.downsample(x)
39 | 
40 |         out += residual
41 |         out = self.relu(out)
42 | 
43 |         return out
44 | 
45 | 
46 | class Bottleneck(nn.Module):
47 |     expansion = 4
48 | 
49 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
50 |         super(Bottleneck, self).__init__()
51 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
52 |         self.bn1 = BatchNorm2d(planes)
53 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
54 |                                padding=1, bias=False)
55 |         self.bn2 = BatchNorm2d(planes, momentum=0.01)
56 |         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
57 |         self.bn3 = BatchNorm2d(planes * 4)
58 |         self.relu = nn.ReLU(inplace=True)
59 |         self.downsample = downsample
60 |         self.stride = stride
61 | 
62 |     def forward(self, x):
63 |         residual = x
64 | 
65 |         out = self.conv1(x)
66 |         out = self.bn1(out)
67 |         out = self.relu(out)
68 | 
69 |         out = self.conv2(out)
70 |         out = self.bn2(out)
71 |         out = self.relu(out)
72 | 
73 |         out = self.conv3(out)
74 |         out = self.bn3(out)
75 | 
76 |         if self.downsample is not None:
77 |             residual = self.downsample(x)
78 | 
79 |         out += residual
80 |         out = self.relu(out)
81 | 
82 |         return out
83 | 
84 | 
85 | class ResNet(nn.Module):
86 | 
87 |     def __init__(self, block, layers, num_classes=1000):
88 |         self.inplanes = 128
89 |         super(ResNet, self).__init__()
90 |         self.conv1 = conv3x3(3, 64, stride=2)  # deep stem: three 3x3 convs in place of the usual single 7x7
91 |         self.bn1 = BatchNorm2d(64)
92 |         self.relu1 = nn.ReLU(inplace=True)
93 |         self.conv2 = conv3x3(64, 64)
94 |         self.bn2 = BatchNorm2d(64)
95 |         self.relu2 = nn.ReLU(inplace=True)
96 |         self.conv3 = conv3x3(64, 128)
97 |         self.bn3 = BatchNorm2d(128)
98 |         self.relu3 = nn.ReLU(inplace=True)
99 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)  # returns (output, indices)
100 | 
101 |         self.layer1 = self._make_layer(block, 64, layers[0])
102 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
103 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
104 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
105 |         self.avgpool = nn.AvgPool2d(7, stride=1)
106 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
107 | 
108 |         for m in self.modules():
109 |             if isinstance(m, nn.Conv2d):
110 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
111 |                 m.weight.data.normal_(0, math.sqrt(2. / n))  # He initialisation, fan_out mode
112 |             elif isinstance(m, BatchNorm2d):
113 |                 m.weight.data.fill_(1)
114 |                 m.bias.data.zero_()
115 | 
116 |     def _make_layer(self, block, planes, blocks, stride=1):
117 |         downsample = None
118 |         if stride != 1 or self.inplanes != planes * block.expansion:
119 |             downsample = nn.Sequential(
120 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
121 |                           kernel_size=1, stride=stride, bias=False),
122 |                 BatchNorm2d(planes * block.expansion),
123 |             )
124 | 
125 |         layers = []
126 |         layers.append(block(self.inplanes, planes, stride, downsample))
127 |         self.inplanes = planes * block.expansion
128 |         for i in range(1, blocks):
129 |             layers.append(block(self.inplanes, planes))
130 | 
131 |         return nn.Sequential(*layers)
132 | 
133 |     def forward(self, x):
134 |         x = self.relu1(self.bn1(self.conv1(x)))
135 |         x = self.relu2(self.bn2(self.conv2(x)))
136 |         x = self.relu3(self.bn3(self.conv3(x)))
137 |         x, indices = self.maxpool(x)
138 | 
139 |         x = self.layer1(x)
140 |         x = self.layer2(x)
141 |         x = self.layer3(x)
142 |         x = self.layer4(x)
143 | 
144 |         x = self.avgpool(x)
145 |         x = x.view(x.size(0), -1)
146 |         x = self.fc(x)
147 |         return x
148 | 
149 | 
150 | def l_resnet50():
151 |     """Constructs a ResNet-50 model with the deeper three-conv stem.
152 | 
153 |     Unlike the GN/WS variant, this constructor takes no arguments.
154 |     """
155 |     model = ResNet(Bottleneck, [3, 4, 6, 3])
156 |     return model
157 | 
--------------------------------------------------------------------------------
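The initialisation loop in ResNet.__init__ above samples each conv weight from N(0, sqrt(2/n)) with n = k_h * k_w * out_channels, which is He (Kaiming) initialisation in fan_out mode. An equivalent spelling with torch's built-in initialisers, as a sketch rather than code from this repo:

    import torch.nn as nn

    def init_resnet_weights(module: nn.Module) -> None:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                # fan_out mode: std = sqrt(2 / (k_h * k_w * out_channels))
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)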
/networks/transforms.py:
--------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | import torch
4 | import cv2
5 | 
6 | 
7 | def dt(a):
8 |     return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0)  # L2 distance to the nearest zero pixel
9 | 
10 | 
11 | def trimap_transform(trimap):
12 |     h, w = trimap.shape[0], trimap.shape[1]
13 | 
14 |     clicks = np.zeros((h, w, 6))
15 |     for k in range(2):
16 |         if(np.count_nonzero(trimap[:, :, k]) > 0):
17 |             dt_mask = -dt(1 - trimap[:, :, k])**2
18 |             L = 320  # reference scale for the Gaussian widths below
19 |             clicks[:, :, 3*k] = np.exp(dt_mask / (2 * ((0.02 * L)**2)))
20 |             clicks[:, :, 3*k+1] = np.exp(dt_mask / (2 * ((0.08 * L)**2)))
21 |             clicks[:, :, 3*k+2] = np.exp(dt_mask / (2 * ((0.16 * L)**2)))
22 | 
23 |     return clicks
24 | 
25 | 
26 | # ImageNet normalisation statistics (RGB channel order)
27 | group_norm_std = [0.229, 0.224, 0.225]
28 | group_norm_mean = [0.485, 0.456, 0.406]
29 | 
30 | 
31 | def groupnorm_normalise_image(img, format='nhwc'):
32 |     '''
33 |     Accepts an RGB image with values in [0, 1]; normalises in place with ImageNet statistics.
34 |     '''
35 |     if(format == 'nhwc'):
36 |         for i in range(3):
37 |             img[..., i] = (img[..., i] - group_norm_mean[i]) / group_norm_std[i]
38 |     else:
39 |         for i in range(3):
40 |             img[..., i, :, :] = (img[..., i, :, :] - group_norm_mean[i]) / group_norm_std[i]
41 | 
42 |     return img
43 | 
44 | 
45 | def groupnorm_denormalise_image(img, format='nhwc'):
46 |     '''
47 |     Accepts a normalised RGB image; returns values back in the range [0, 1].
48 |     '''
49 |     if(format == 'nhwc'):
50 |         for i in range(3):
51 |             img[:, :, :, i] = img[:, :, :, i] * group_norm_std[i] + group_norm_mean[i]
52 |     else:
53 |         img1 = torch.zeros_like(img).cuda()  # NOTE: pins the result to GPU; zeros_like(img) alone would stay on img's device
54 |         for i in range(3):
55 |             img1[:, i, :, :] = img[:, i, :, :] * group_norm_std[i] + group_norm_mean[i]
56 |         return img1
57 |     return img
58 | 
--------------------------------------------------------------------------------
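For orientation, here is how the transforms.py helpers fit together when preparing network inputs. This is a self-contained sketch with stand-in data; the real wiring presumably lives in dataloader.py and inference.py:

    import numpy as np
    import torch
    from networks.transforms import trimap_transform, groupnorm_normalise_image

    h, w = 320, 320
    image = np.random.rand(h, w, 3).astype(np.float32)  # stand-in RGB image in [0, 1]
    trimap = np.zeros((h, w, 2), dtype=np.float32)      # ch 0: definite background, ch 1: definite foreground
    trimap[:40, :, 0] = 1
    trimap[140:180, 140:180, 1] = 1

    clicks = trimap_transform(trimap)                   # (h, w, 6) Gaussian-of-distance encodings
    image_t = torch.from_numpy(image)[None]             # (1, h, w, 3)
    image_t = groupnorm_normalise_image(image_t, format='nhwc')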