├── .gitignore
├── API
│   ├── 1.jpg
│   ├── Segmentation_API
│   │   ├── app.py
│   │   ├── detectron_seg.py
│   │   ├── detectron_seg.yml
│   │   ├── requirements.yml
│   │   └── segmentation.py
│   ├── app.py
│   ├── matting.py
│   └── requirements.yml
├── LICENSE
├── README.md
├── Videos
│   ├── frienvengers.avi
│   └── frienvengers_using_detectron.avi
├── dataloader.py
├── demo.py
├── examples
│   ├── example_results.png
│   ├── images
│   │   ├── elephant.png
│   │   └── troll.png
│   ├── predictions
│   │   ├── elephant_alpha.png
│   │   ├── elephant_bg.png
│   │   ├── elephant_fg.png
│   │   ├── troll_alpha.png
│   │   ├── troll_bg.png
│   │   └── troll_fg.png
│   └── trimaps
│       ├── elephant.png
│       └── troll.png
├── frienvengers.ipynb
├── inference.py
├── model.py
└── networks
    ├── __init__.py
    ├── layers_WS.py
    ├── models.py
    ├── resnet_GN_WS.py
    ├── resnet_bn.py
    └── transforms.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ./vscode/*
2 | *.pyc
3 | ./vscode
4 | .vscode/*
5 |
6 | *.pth
7 | .ipynb_checkpoints
8 | API/Segmentation_API/segmentation_orig.py
9 | Experiment.ipynb
10 | FBA_Matting.ipynb
11 | Videos/avengers.avi
12 | Videos/friends.avi
13 | API/app2.py
14 | API/Segmentation_API/segmentation_orig.py
15 |
--------------------------------------------------------------------------------
/API/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/API/1.jpg
--------------------------------------------------------------------------------
/API/Segmentation_API/app.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, jsonify,request,Response
2 | from segmentation import Model,Preprocessing
3 | # from detectron_seg import Model,Preprocessing
4 | import cv2
5 | import numpy as np
6 | import io
7 | from PIL import Image
8 |
9 | app = Flask(__name__)
10 |
11 | def get_arr(arg):
12 | return np.array(arg).astype('uint8')
13 |
14 | seg_model = Model()
15 |
16 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5))
17 |
18 | preprocessor = Preprocessing(kernel)
19 |
20 |
21 | @app.route('/',methods=['POST'])
22 | def gen_trimap():
23 | data = request.get_json()
24 | image = get_arr(data['image'])
25 | output = seg_model.get_seg_output(image)
26 | if len(output) == 0:
27 | return(Response(status=406))
28 | masks = np.array([mask.cpu().numpy() for mask,classes in output])
29 | trimap = preprocessor.get_trimap(masks)
30 | return jsonify({'trimap':trimap.tolist()})#,'masks':masks.tolist()})
31 |
32 | if __name__=='__main__':
33 |     app.run(debug=True, threaded=True, port=3000)  # /API/app.py posts to this service on port 3000
--------------------------------------------------------------------------------
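
A minimal client sketch for the trimap service above (the image path is hypothetical; the service is assumed to listen on port 3000, which is where `/API/app.py` below posts):

```python
# Hedged sketch: post an RGB frame to the segmentation service and read back the trimap.
import cv2
import numpy as np
import requests

image = cv2.imread('person.jpg')[:, :, ::-1]  # BGR -> RGB, matching the other clients in this repo
response = requests.post('http://127.0.0.1:3000/', json={'image': image.tolist()})

if response.status_code == 406:
    print('no person detected; the matting step would be skipped')
else:
    trimap = np.array(response.json()['trimap'])  # (h, w, 2): channel 0 = bg, channel 1 = fg
    print(trimap.shape)
```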
/API/Segmentation_API/detectron_seg.py:
--------------------------------------------------------------------------------
1 | from detectron2 import model_zoo
2 | from detectron2.engine import DefaultPredictor
3 | from detectron2.config import get_cfg
4 | from detectron2.utils.visualizer import Visualizer
5 | from detectron2.data import MetadataCatalog
6 | import torch
7 | import numpy as np
8 | import cv2
9 |
10 |
11 |
12 |
13 | class Model:
14 | def __init__(self,confidence_thresh=0.6):
15 | cfg = get_cfg()
16 | cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
17 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_thresh # set threshold for this model
18 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
19 | self.model = DefaultPredictor(cfg)
20 |
21 |
22 |     def get_seg_output(self, image: np.ndarray):
23 | out = self.model(image)['instances']
24 |
25 | outputs = [(out.pred_masks[i],out.pred_classes[i]) for i in range(len(out.pred_classes)) if out.pred_classes[i]==0]
26 |
27 | return outputs
28 |
29 |
30 |
31 | class Preprocessing:
32 | def __init__(self,kernel,dilate_iter=5,erode_iter=1):
33 | self.kernel = kernel
34 | self.dilate_iter = dilate_iter
35 | self.erode_iter = erode_iter
36 |
37 | def get_target_mask(self,masks):
38 | out = np.zeros(masks[0].shape)
39 | for mask in masks:
40 | out += mask
41 | out = np.clip(out,0,1)
42 | return out
43 |
44 | def get_trimap(self,masks):
45 | target_mask = self.get_target_mask(masks)
46 | erode = cv2.erode(target_mask.astype('uint8'),self.kernel,iterations=self.erode_iter)
47 | dilate = cv2.dilate(target_mask.astype('uint8'),self.kernel,iterations=self.dilate_iter)
48 | h, w = target_mask.shape
49 |
50 | trimap = np.zeros((h, w, 2))
51 | trimap[erode == 1, 1] = 1
52 | trimap[dilate == 0, 0] = 1
53 |
54 | return trimap
55 |
56 |
57 |
--------------------------------------------------------------------------------
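
The detectron2-based `Model` above exposes the same `get_seg_output(image)` interface as the torchvision-based one in `segmentation.py`, which is why `app.py` can swap backends by changing a single import (a sketch, assuming detectron2 is installed per `detectron_seg.yml`):

```python
# Backend-swap sketch: both Model classes return a list of (mask, class) pairs
# for detected people, so only the import line differs.
# from segmentation import Model, Preprocessing   # torchvision Mask R-CNN backend
from detectron_seg import Model, Preprocessing    # detectron2 Mask R-CNN backend
import cv2

kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
seg_model = Model(confidence_thresh=0.6)
preprocessor = Preprocessing(kernel)
```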
/API/Segmentation_API/detectron_seg.yml:
--------------------------------------------------------------------------------
1 | name: detectron
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - attrs=19.3.0=py_0
7 | - backcall=0.1.0=py38_0
8 | - bleach=3.1.4=py_0
9 | - ca-certificates=2020.6.24=0
10 | - certifi=2020.6.20=py38_0
11 | - click=7.1.2=py_0
12 | - decorator=4.4.2=py_0
13 | - defusedxml=0.6.0=py_0
14 | - entrypoints=0.3=py38_0
15 | - flask=1.1.2=py_0
16 | - gmp=6.1.2=h6c8ec71_1
17 | - importlib_metadata=1.5.0=py38_0
18 | - ipykernel=5.1.4=py38h39e3cac_0
19 | - ipython=7.13.0=py38h5ca1d4c_0
20 | - ipython_genutils=0.2.0=py38_0
21 | - itsdangerous=1.1.0=py_0
22 | - jedi=0.16.0=py38_1
23 | - jinja2=2.11.1=py_0
24 | - jsonschema=3.2.0=py38_0
25 | - jupyter_client=6.1.2=py_0
26 | - jupyter_core=4.6.3=py38_0
27 | - ld_impl_linux-64=2.33.1=h53a641e_7
28 | - libedit=3.1.20181209=hc058e9b_0
29 | - libffi=3.2.1=hd88cf55_4
30 | - libgcc-ng=9.1.0=hdf63c60_0
31 | - libsodium=1.0.16=h1bed415_0
32 | - libstdcxx-ng=9.1.0=hdf63c60_0
33 | - markupsafe=1.1.1=py38h7b6447c_0
34 | - mistune=0.8.4=py38h7b6447c_1000
35 | - nb_conda_kernels=2.2.3=py38_0
36 | - nbconvert=5.6.1=py38_0
37 | - nbformat=5.0.4=py_0
38 | - ncurses=6.2=he6710b0_0
39 | - notebook=6.0.3=py38_0
40 | - openssl=1.1.1g=h7b6447c_0
41 | - pandoc=2.2.3.2=0
42 | - pandocfilters=1.4.2=py38_1
43 | - parso=0.6.2=py_0
44 | - pexpect=4.8.0=py38_0
45 | - pickleshare=0.7.5=py38_1000
46 | - pip=20.0.2=py38_1
47 | - prometheus_client=0.7.1=py_0
48 | - prompt-toolkit=3.0.4=py_0
49 | - prompt_toolkit=3.0.4=0
50 | - ptyprocess=0.6.0=py38_0
51 | - pygments=2.6.1=py_0
52 | - pyrsistent=0.16.0=py38h7b6447c_0
53 | - python=3.8.2=hcf32534_0
54 | - python-dateutil=2.8.1=py_0
55 | - pyzmq=18.1.1=py38he6710b0_0
56 | - readline=8.0=h7b6447c_0
57 | - send2trash=1.5.0=py38_0
58 | - setuptools=46.1.3=py38_0
59 | - six=1.14.0=py38_0
60 | - sqlite=3.31.1=h7b6447c_0
61 | - terminado=0.8.3=py38_0
62 | - testpath=0.4.4=py_0
63 | - tk=8.6.8=hbc83047_0
64 | - tornado=6.0.4=py38h7b6447c_1
65 | - traitlets=4.3.3=py38_0
66 | - wcwidth=0.1.9=py_0
67 | - webencodings=0.5.1=py38_1
68 | - werkzeug=1.0.1=py_0
69 | - wheel=0.34.2=py38_0
70 | - xz=5.2.5=h7b6447c_0
71 | - zeromq=4.3.1=he6710b0_3
72 | - zipp=2.2.0=py_0
73 | - zlib=1.2.11=h7b6447c_3
74 | - pip:
75 | - absl-py==0.9.0
76 | - cachetools==4.1.0
77 | - chardet==3.0.4
78 | - cloudpickle==1.3.0
79 | - cycler==0.10.0
80 | - cython==0.29.16
81 | - detectron2==0.1.1+cu92
82 | - future==0.18.2
83 | - fvcore==0.1.dev200420
84 | - google-auth==1.14.0
85 | - google-auth-oauthlib==0.4.1
86 | - grpcio==1.28.1
87 | - idna==2.9
88 | - kiwisolver==1.2.0
89 | - markdown==3.2.1
90 | - matplotlib==3.2.1
91 | - numpy==1.18.3
92 | - oauthlib==3.1.0
93 | - opencv-python==4.2.0.34
94 | - pillow==7.1.1
95 | - portalocker==1.7.0
96 | - protobuf==3.11.3
97 | - pyasn1==0.4.8
98 | - pyasn1-modules==0.2.8
99 | - pycocotools==2.0
100 | - pydot==1.4.1
101 | - pyparsing==2.4.7
102 | - pyyaml==5.1
103 | - requests==2.23.0
104 | - requests-oauthlib==1.3.0
105 | - rsa==4.0
106 | - tabulate==0.8.7
107 | - tensorboard==2.2.1
108 | - tensorboard-plugin-wit==1.6.0.post3
109 | - termcolor==1.1.0
110 | - torch==1.4.0+cu92
111 | - torchvision==0.5.0+cu92
112 | - tqdm==4.45.0
113 | - urllib3==1.25.9
114 | - yacs==0.1.6
115 | prefix: /home/kakarot/anaconda3/envs/detectron
116 |
117 |
--------------------------------------------------------------------------------
/API/Segmentation_API/requirements.yml:
--------------------------------------------------------------------------------
1 | name: Fba_matting
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - attrs=19.3.0=py_0
7 | - backcall=0.2.0=py_0
8 | - blas=1.0=mkl
9 | - bleach=3.1.5=py_0
10 | - ca-certificates=2020.6.24=0
11 | - certifi=2020.6.20=py37_0
12 | - click=7.1.2=py_0
13 | - dbus=1.13.16=hb2f20db_0
14 | - decorator=4.4.2=py_0
15 | - defusedxml=0.6.0=py_0
16 | - entrypoints=0.3=py37_0
17 | - expat=2.2.9=he6710b0_2
18 | - flask=1.1.2=py_0
19 | - fontconfig=2.13.0=h9420a91_0
20 | - freetype=2.10.2=h5ab3b9f_0
21 | - glib=2.65.0=h3eb4bd4_0
22 | - gst-plugins-base=1.14.0=hbbd80ab_1
23 | - gstreamer=1.14.0=hb31296c_0
24 | - icu=58.2=he6710b0_3
25 | - importlib-metadata=1.7.0=py37_0
26 | - importlib_metadata=1.7.0=0
27 | - intel-openmp=2020.1=217
28 | - ipykernel=5.3.0=py37h5ca1d4c_0
29 | - ipython=7.16.1=py37h5ca1d4c_0
30 | - ipython_genutils=0.2.0=py37_0
31 | - itsdangerous=1.1.0=py37_0
32 | - jedi=0.17.1=py37_0
33 | - jinja2=2.11.2=py_0
34 | - jpeg=9b=h024ee3a_2
35 | - jsonschema=3.2.0=py37_0
36 | - jupyter_client=6.1.3=py_0
37 | - jupyter_core=4.6.3=py37_0
38 | - kiwisolver=1.2.0=py37hfd86e86_0
39 | - lcms2=2.11=h396b838_0
40 | - ld_impl_linux-64=2.33.1=h53a641e_7
41 | - libedit=3.1.20191231=h7b6447c_0
42 | - libffi=3.3=he6710b0_1
43 | - libgcc-ng=9.1.0=hdf63c60_0
44 | - libgfortran-ng=7.3.0=hdf63c60_0
45 | - libpng=1.6.37=hbc83047_0
46 | - libsodium=1.0.18=h7b6447c_0
47 | - libstdcxx-ng=9.1.0=hdf63c60_0
48 | - libtiff=4.1.0=h2733197_1
49 | - libuuid=1.0.3=h1bed415_2
50 | - libxcb=1.14=h7b6447c_0
51 | - libxml2=2.9.10=he19cac6_1
52 | - lz4-c=1.9.2=he6710b0_0
53 | - markupsafe=1.1.1=py37h7b6447c_0
54 | - matplotlib=3.2.2=0
55 | - matplotlib-base=3.2.2=py37hef1b27d_0
56 | - mistune=0.8.4=py37h7b6447c_0
57 | - mkl=2020.1=217
58 | - mkl-service=2.3.0=py37he904b0f_0
59 | - mkl_fft=1.1.0=py37h23d657b_0
60 | - mkl_random=1.1.1=py37h0573a6f_0
61 | - nb_conda=2.2.1=py37_0
62 | - nb_conda_kernels=2.2.3=py37_0
63 | - nbconvert=5.6.1=py37_0
64 | - nbformat=5.0.7=py_0
65 | - ncurses=6.2=he6710b0_1
66 | - notebook=6.0.3=py37_0
67 | - numpy=1.18.5=py37ha1c710e_0
68 | - numpy-base=1.18.5=py37hde5b4d6_0
69 | - olefile=0.46=py37_0
70 | - openssl=1.1.1g=h7b6447c_0
71 | - packaging=20.4=py_0
72 | - pandoc=2.9.2.1=0
73 | - pandocfilters=1.4.2=py37_1
74 | - parso=0.7.0=py_0
75 | - pcre=8.44=he6710b0_0
76 | - pexpect=4.8.0=py37_0
77 | - pickleshare=0.7.5=py37_0
78 | - pillow=7.2.0=py37hb39fc2d_0
79 | - pip=20.1.1=py37_1
80 | - prometheus_client=0.8.0=py_0
81 | - prompt-toolkit=3.0.5=py_0
82 | - ptyprocess=0.6.0=py37_0
83 | - pygments=2.6.1=py_0
84 | - pyparsing=2.4.7=py_0
85 | - pyqt=5.9.2=py37h05f1152_2
86 | - pyrsistent=0.16.0=py37h7b6447c_0
87 | - python=3.7.7=hcff3b4d_5
88 | - python-dateutil=2.8.1=py_0
89 | - pyzmq=19.0.1=py37he6710b0_1
90 | - qt=5.9.7=h5867ecd_1
91 | - readline=8.0=h7b6447c_0
92 | - send2trash=1.5.0=py37_0
93 | - setuptools=47.3.1=py37_0
94 | - sip=4.19.8=py37hf484d3e_0
95 | - six=1.15.0=py_0
96 | - sqlite=3.32.3=h62c20be_0
97 | - terminado=0.8.3=py37_0
98 | - testpath=0.4.4=py_0
99 | - tk=8.6.10=hbc83047_0
100 | - tornado=6.0.4=py37h7b6447c_1
101 | - traitlets=4.3.3=py37_0
102 | - wcwidth=0.2.5=py_0
103 | - webencodings=0.5.1=py37_1
104 | - werkzeug=1.0.1=py_0
105 | - wheel=0.34.2=py37_0
106 | - xz=5.2.5=h7b6447c_0
107 | - zeromq=4.3.2=he6710b0_2
108 | - zipp=3.1.0=py_0
109 | - zlib=1.2.11=h7b6447c_3
110 | - zstd=1.4.5=h0b5b093_0
111 | - pip:
112 | - chardet==3.0.4
113 | - cycler==0.10.0
114 | - filelock==3.0.12
115 | - future==0.18.2
116 | - gdown==3.11.1
117 | - idna==2.10
118 | - opencv-python==4.2.0.34
119 | - pysocks==1.7.1
120 | - requests==2.24.0
121 | - torch==1.5.1
122 | - tqdm==4.47.0
123 | - urllib3==1.25.9
124 | prefix: /home/kakarot/anaconda3/envs/Fba_matting
125 |
126 |
--------------------------------------------------------------------------------
/API/Segmentation_API/segmentation.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torchvision.transforms as T
3 | import torch
4 | import numpy as np
5 | import cv2
6 | import requests
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | class Model:
15 | def __init__(self,confidence_thresh=0.6):
16 | self.model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
17 | self.model.eval();
18 | self.transform = T.Compose([T.ToTensor()])
19 | self.conf_thresh = confidence_thresh
20 |
21 |
22 |     def get_seg_output(self, image: np.ndarray):
23 | image = self.transform(image.copy())
24 |
25 | with torch.no_grad():
26 | pred = self.model([image])
27 |
28 | outputs = [(pred[0]['masks'][i][0],pred[0]['labels'][i]) for i in range(len(pred[0]['boxes'])) if pred[0]['scores'][i]>self.conf_thresh and pred[0]['labels'][i]==1]
29 | # outputs = [(pred[0]['masks'][i][0],pred[0]['labels'][i]) for i in range(len(pred[0]['boxes'])) if pred[0]['scores'][i]>self.conf_thresh]
30 |
31 | return outputs
32 |
33 |
34 |
35 | class Preprocessing:
36 | def __init__(self,kernel,lower_bound=0.1,upper_bound=0.9,dilate_iter=10,erode_iter=10):
37 | self.kernel = kernel
38 | self.low_thresh = lower_bound
39 | self.high_thresh = upper_bound
40 | self.dilate_iter = dilate_iter
41 | self.erode_iter = erode_iter
42 |
43 | def get_target_mask(self,masks):
44 | out = np.zeros(masks[0].shape)
45 | for mask in masks:
46 | out += mask
47 | out = np.clip(out,0,1)
48 | return out
49 |
50 | def get_trimap(self,masks):
51 | target_mask = self.get_target_mask(masks)
52 | foreground = target_mask >= self.high_thresh
53 | ambiguous = (target_mask < self.high_thresh)*(target_mask>=self.low_thresh)
54 |
55 | erode = cv2.erode(foreground.astype('uint8'),self.kernel,iterations=self.erode_iter)
56 | dilate = cv2.dilate(ambiguous.astype('uint8'),self.kernel,iterations=self.dilate_iter)
57 | h, w = target_mask.shape
58 |
59 | bg_giver = np.clip((erode + dilate),0,1 )
60 | trimap = np.zeros((h, w, 2))
61 | trimap[erode == 1, 1] = 1
62 | trimap[bg_giver == 0, 0] = 1
63 |
64 | return trimap
65 |
66 |
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
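
For intuition, a self-contained sketch of the trimap encoding `Preprocessing.get_trimap` produces, run on a dummy soft mask (values chosen for illustration; run from `API/Segmentation_API` so `segmentation` is importable). Channel 1 marks definite foreground (the eroded core), channel 0 definite background, and pixels with neither set form the unknown band that the matting network resolves:

```python
import cv2
import numpy as np
from segmentation import Preprocessing

# Dummy soft mask: a filled square blurred at the edges, mimicking Mask R-CNN probabilities.
mask = np.zeros((100, 100), dtype='float32')
mask[30:70, 30:70] = 1.0
mask = cv2.GaussianBlur(mask, (15, 15), 0)

kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
pre = Preprocessing(kernel, dilate_iter=3, erode_iter=3)
trimap = pre.get_trimap([mask])

fg = trimap[:, :, 1] == 1       # eroded core: definite foreground
bg = trimap[:, :, 0] == 1       # outside the dilated ambiguous band: definite background
unknown = ~(fg | bg)            # the band handed to FBA matting
print(int(fg.sum()), int(bg.sum()), int(unknown.sum()))
```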
/API/app.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, jsonify,request,url_for,redirect
2 | import sys
3 | sys.path.append('..')
4 | from networks.models import build_model
5 | from matting import pred
6 | import requests
7 | import numpy as np
8 | import cv2
9 | import base64
10 | from PIL import Image
11 | import io
12 |
13 | app = Flask(__name__)
14 |
15 | def get_array(arg):
16 | return np.array(arg).astype('uint8')
17 |
18 | class Matting_Args:
19 | def __init__(self):
20 | self.encoder = 'resnet50_GN_WS'
21 | self.decoder = 'fba_decoder'
22 | self.weights = '../models/FBA.pth'
23 |
24 | args = Matting_Args()
25 |
26 | matting_model = build_model(args)
27 | matting_model.eval();
28 |
29 | def get_response(new_bg,data):
30 | image = get_array(data.get('image'))
31 | response = requests.post('http://127.0.0.1:3000/',json = data)
32 | if response.status_code == 406:
33 | return jsonify({'output':image.tolist()})
34 | h,w,_ = image.shape
35 | trimap = get_array(response.json()['trimap'])
36 | fg, bg, alpha = pred(image/255.0,trimap,matting_model)
37 | combined = ((alpha[...,None]*image)).astype('uint8') + ((1-alpha)[...,None]*cv2.resize(new_bg,(w,h))).astype('uint8')
38 | return jsonify({'output':combined.tolist()})
39 |
40 | @app.route('/with_bg',methods=['POST'])
41 | def extraction():
42 | data = request.get_json()
43 | new_bg = get_array(data.get('bg'))
44 | return get_response(new_bg,data)
45 |
46 | @app.route('/',methods=["POST"])
47 | def extraction_without_bg():
48 | data = request.get_json()
49 | new_bg = cv2.imread('1.jpg')[:,:,::-1]
50 | return get_response(new_bg,data)
51 |
52 |
53 | if __name__ == '__main__':
54 | app.run(debug=True,threaded=True)
55 |
--------------------------------------------------------------------------------
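
The compositing step in `get_response` is plain alpha blending, `combined = alpha * frame + (1 - alpha) * new_bg`; note it blends the original frame rather than the predicted foreground `fg`. A standalone numeric check with dummy data:

```python
import numpy as np

h, w = 4, 4
image = np.full((h, w, 3), 200, dtype='uint8')   # dummy foreground frame
new_bg = np.full((h, w, 3), 50, dtype='uint8')   # dummy replacement background
alpha = np.zeros((h, w))
alpha[:, :2] = 1.0                               # left half is foreground

combined = (alpha[..., None] * image).astype('uint8') + \
           ((1 - alpha)[..., None] * new_bg).astype('uint8')
print(combined[0, 0], combined[0, 3])            # [200 200 200] and [50 50 50]
```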
/API/matting.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 |
4 |
5 | from networks.transforms import trimap_transform, groupnorm_normalise_image
6 | from networks.models import build_model
7 | from dataloader import PredDataset
8 |
9 | # System libs
10 | import os
11 | import argparse
12 |
13 | # External libs
14 | import cv2
15 | import numpy as np
16 | import torch
17 |
18 |
19 | def np_to_torch(x):
20 | return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float().cuda()
21 |
22 |
23 | def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray:
24 | ''' Scales inputs to multiple of 8. '''
25 | h, w = x.shape[:2]
26 | h1 = int(np.ceil(scale * h / 8) * 8)
27 | w1 = int(np.ceil(scale * w / 8) * 8)
28 | x_scale = cv2.resize(x, (w1, h1), interpolation=scale_type)
29 | return x_scale
30 |
31 |
32 |
33 | def pred(image_np: np.ndarray, trimap_np: np.ndarray, model):
34 | ''' Predict alpha, foreground and background.
35 | Parameters:
36 | image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3)
37 | trimap_np -- two channel trimap, first background then foreground. Dimensions: (h, w, 2)
38 | Returns:
39 | fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3)
40 | bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3)
41 | alpha: alpha matte image between 0 and 1. Dimensions: (h, w)
42 | '''
43 | h, w = trimap_np.shape[:2]
44 |
45 | image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
46 | trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)
47 |
48 | with torch.no_grad():
49 |
50 | image_torch = np_to_torch(image_scale_np)
51 | trimap_torch = np_to_torch(trimap_scale_np)
52 |
53 | trimap_transformed_torch = np_to_torch(trimap_transform(trimap_scale_np))
54 | image_transformed_torch = groupnorm_normalise_image(image_torch.clone(), format='nchw')
55 |
56 | output = model(image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch)
57 |
58 |     output = cv2.resize(output[0].cpu().numpy().transpose((1, 2, 0)), (w, h), interpolation=cv2.INTER_LANCZOS4)
59 | alpha = output[:, :, 0]
60 | fg = output[:, :, 1:4]
61 | bg = output[:, :, 4:7]
62 |
63 | alpha[trimap_np[:, :, 0] == 1] = 0
64 | alpha[trimap_np[:, :, 1] == 1] = 1
65 | fg[alpha == 1] = image_np[alpha == 1]
66 | bg[alpha == 0] = image_np[alpha == 0]
67 | return fg, bg, alpha
68 |
69 |
--------------------------------------------------------------------------------
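
`scale_input` rounds the working resolution up to the next multiple of 8 because the dilated ResNet encoder has an output stride of 8 and the decoder upsamples by 2 three times, so feature maps only concatenate cleanly at such sizes; the prediction is resized back to the original `(w, h)` afterwards. A quick check of the rounding (the helper name is mine):

```python
import numpy as np

def next_multiple_of_8(n, scale=1.0):
    # Mirrors the arithmetic in scale_input above.
    return int(np.ceil(scale * n / 8) * 8)

print(next_multiple_of_8(240), next_multiple_of_8(360))    # 240 360 (already multiples)
print(next_multiple_of_8(1080), next_multiple_of_8(1919))  # 1080 1920
```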
/API/requirements.yml:
--------------------------------------------------------------------------------
1 | name: fastai
2 | channels:
3 | - pytorch
4 | - fastai
5 | - anaconda
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _libgcc_mutex=0.1=main
10 | - attrs=19.3.0=py_0
11 | - backcall=0.1.0=py37_0
12 | - beautifulsoup4=4.9.0=py37hc8dfbb8_0
13 | - blas=1.0=openblas
14 | - bleach=3.1.4=py_0
15 | - bottleneck=1.3.2=py37h03ebfcd_1
16 | - brotlipy=0.7.0=py37h8f50634_1000
17 | - bzip2=1.0.8=h7b6447c_0
18 | - c-ares=1.16.1=h516909a_0
19 | - ca-certificates=2020.6.20=hecda079_0
20 | - cachetools=4.1.1=py_0
21 | - cairo=1.14.12=h8948797_3
22 | - certifi=2020.6.20=py37hc8dfbb8_0
23 | - cffi=1.14.0=py37hd463f26_0
24 | - chardet=3.0.4=py37hc8dfbb8_1006
25 | - click=7.1.1=py_0
26 | - cryptography=2.8=py37hb09aad4_2
27 | - cudatoolkit=10.1.243=h6bb024c_0
28 | - cycler=0.10.0=py_2
29 | - cymem=2.0.2=py37he1b5a44_0
30 | - cython-blis=0.2.4=py37h516909a_1
31 | - dataclasses=0.6=py_0
32 | - dbus=1.13.6=he372182_0
33 | - decorator=4.4.2=py_0
34 | - defusedxml=0.6.0=py_0
35 | - entrypoints=0.3=py37_0
36 | - expat=2.2.9=he1b5a44_2
37 | - fastai=1.0.60=1
38 | - fastprogress=0.2.2=py_0
39 | - ffmpeg=4.0=hcdf2ecd_0
40 | - flask=1.1.1=py_1
41 | - fontconfig=2.13.1=he4413a7_1000
42 | - freeglut=3.0.0=hf484d3e_5
43 | - freetype=2.10.1=he06d7ca_0
44 | - gettext=0.19.8.1=hc5be6a0_1002
45 | - glib=2.64.2=h6f030ca_0
46 | - gmp=6.1.2=h6c8ec71_1
47 | - google-api-core-grpc=1.20.1=hc8dfbb8_0
48 | - google-cloud-vision=0.42.0=py37_0
49 | - graphite2=1.3.13=h23475e2_0
50 | - grpcio=1.30.0=py37hb0870dc_0
51 | - gst-plugins-base=1.14.5=h0935bb2_2
52 | - gstreamer=1.14.5=h36ae1b5_2
53 | - harfbuzz=1.8.8=hffaf4a1_0
54 | - hdf5=1.10.2=hba1933b_1
55 | - icu=58.2=hf484d3e_1000
56 | - importlib_metadata=1.5.0=py37_0
57 | - intel-openmp=2020.0=166
58 | - ipykernel=5.1.4=py37h39e3cac_0
59 | - ipython=7.13.0=py37h5ca1d4c_0
60 | - ipython_genutils=0.2.0=py37_0
61 | - ipywidgets=7.5.1=py_0
62 | - itsdangerous=1.1.0=py37_0
63 | - jasper=2.0.14=h07fcdf6_1
64 | - jedi=0.16.0=py37_1
65 | - jinja2=2.11.1=py_0
66 | - joblib=0.14.1=py_0
67 | - jpeg=9c=h14c3975_1001
68 | - jsonschema=3.2.0=py37_0
69 | - jupyter_client=6.1.2=py_0
70 | - jupyter_core=4.6.3=py37_0
71 | - kiwisolver=1.2.0=py37h99015e2_0
72 | - ld_impl_linux-64=2.33.1=h53a641e_7
73 | - libblas=3.8.0=14_openblas
74 | - libcblas=3.8.0=14_openblas
75 | - libedit=3.1.20181209=hc058e9b_0
76 | - libffi=3.2.1=hd88cf55_4
77 | - libgcc=7.2.0=h69d50b8_2
78 | - libgcc-ng=9.1.0=hdf63c60_0
79 | - libgfortran-ng=7.3.0=hdf63c60_5
80 | - libglu=9.0.0=hf484d3e_1
81 | - libiconv=1.15=h516909a_1006
82 | - liblapack=3.8.0=14_openblas
83 | - libopenblas=0.3.7=h5ec1e0e_6
84 | - libopencv=3.4.2=hb342d67_1
85 | - libopus=1.3=h7b6447c_0
86 | - libpng=1.6.37=hed695b0_1
87 | - libprotobuf=3.12.3=h8b12597_2
88 | - libsodium=1.0.16=h1bed415_0
89 | - libstdcxx-ng=9.1.0=hdf63c60_0
90 | - libtiff=4.1.0=hc7e4089_6
91 | - libuuid=2.32.1=h14c3975_1000
92 | - libvpx=1.7.0=h439df22_0
93 | - libwebp-base=1.1.0=h516909a_3
94 | - libxcb=1.13=h14c3975_1002
95 | - libxml2=2.9.9=hea5a465_1
96 | - lz4-c=1.9.2=he1b5a44_0
97 | - markupsafe=1.1.1=py37h7b6447c_0
98 | - matplotlib=3.1.3=py37_0
99 | - matplotlib-base=3.1.3=py37hef1b27d_0
100 | - mistune=0.8.4=py37h7b6447c_0
101 | - mkl=2020.0=166
102 | - murmurhash=1.0.0=py37h3340039_0
103 | - nb_conda_kernels=2.2.3=py37_0
104 | - nbconvert=5.6.1=py37_0
105 | - nbformat=5.0.4=py_0
106 | - ncurses=6.2=he6710b0_0
107 | - ninja=1.10.0=hc9558a2_0
108 | - nodejs=6.13.1=0
109 | - notebook=6.0.3=py37_0
110 | - numexpr=2.7.1=py37h0da4684_1
111 | - numpy=1.18.1=py37h8960a57_1
112 | - olefile=0.46=py_0
113 | - opencv=3.4.2=py37h6fd60c2_1
114 | - openssl=1.1.1g=h516909a_0
115 | - packaging=20.1=py_0
116 | - pandas=1.0.3=py37h0da4684_1
117 | - pandoc=2.2.3.2=0
118 | - pandocfilters=1.4.2=py37_1
119 | - parso=0.6.2=py_0
120 | - patsy=0.5.1=py_0
121 | - pcre=8.44=he1b5a44_0
122 | - pexpect=4.8.0=py37_0
123 | - pickleshare=0.7.5=py37_0
124 | - pillow=7.0.0=py37hb39fc2d_0
125 | - pip=20.0.2=py37_1
126 | - pixman=0.38.0=h7b6447c_0
127 | - plac=0.9.6=py37_0
128 | - preshed=2.0.1=py37he1b5a44_0
129 | - prometheus_client=0.7.1=py_0
130 | - prompt-toolkit=3.0.4=py_0
131 | - prompt_toolkit=3.0.4=0
132 | - pthread-stubs=0.4=h14c3975_1001
133 | - ptyprocess=0.6.0=py37_0
134 | - py-opencv=3.4.2=py37hb342d67_1
135 | - pyasn1=0.4.8=py_0
136 | - pycparser=2.20=py_0
137 | - pygments=2.6.1=py_0
138 | - pyopenssl=19.1.0=py_1
139 | - pyparsing=2.4.7=pyh9f0ad1d_0
140 | - pyqt=5.9.2=py37hcca6a23_4
141 | - pyrsistent=0.16.0=py37h7b6447c_0
142 | - pysocks=1.7.1=py37hc8dfbb8_1
143 | - python=3.7.7=hcf32534_0_cpython
144 | - python-dateutil=2.8.1=py_0
145 | - python_abi=3.7=1_cp37m
146 | - pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0
147 | - pyyaml=5.3.1=py37h8f50634_0
148 | - pyzmq=18.1.1=py37he6710b0_0
149 | - qt=5.9.7=h52cfd70_2
150 | - readline=8.0=h7b6447c_0
151 | - rsa=4.6=pyh9f0ad1d_0
152 | - scikit-learn=0.22.1=py37h22eb022_0
153 | - scipy=1.4.1=py37ha3d9a3c_3
154 | - seaborn=0.10.0=py_1
155 | - send2trash=1.5.0=py37_0
156 | - setuptools=46.1.3=py37_0
157 | - sip=4.19.8=py37hf484d3e_0
158 | - soupsieve=1.9.4=py37hc8dfbb8_1
159 | - spacy=2.1.8=py37hc9558a2_0
160 | - sqlite=3.31.1=h7b6447c_0
161 | - srsly=0.1.0=py37he1b5a44_0
162 | - statsmodels=0.11.1=py37h8f50634_1
163 | - terminado=0.8.3=py37_0
164 | - testpath=0.4.4=py_0
165 | - thinc=7.0.8=py37hc9558a2_0
166 | - tk=8.6.8=hbc83047_0
167 | - torchvision=0.5.0=py37_cu101
168 | - tornado=6.0.4=py37h7b6447c_1
169 | - tqdm=4.47.0=py_0
170 | - traitlets=4.3.3=py37_0
171 | - urllib3=1.25.9=py_0
172 | - wasabi=0.2.2=py_0
173 | - wcwidth=0.1.9=py_0
174 | - webencodings=0.5.1=py37_1
175 | - werkzeug=1.0.1=py_0
176 | - wheel=0.34.2=py37_0
177 | - widgetsnbextension=3.5.1=py37_0
178 | - xlrd=1.2.0=py37_0
179 | - xorg-libxau=1.0.9=h14c3975_0
180 | - xorg-libxdmcp=1.1.3=h516909a_0
181 | - xz=5.2.5=h7b6447c_0
182 | - yaml=0.2.3=h516909a_0
183 | - zeromq=4.3.1=he6710b0_3
184 | - zipp=2.2.0=py_0
185 | - zlib=1.2.11=h7b6447c_3
186 | - zstd=1.4.4=h6597ccf_3
187 | - pip:
188 | - blessings==1.7
189 | - colorthief==0.2.1
190 | - easyprocess==0.2.10
191 | - flask-cors==3.0.8
192 | - google-api-core==1.20.1
193 | - google-auth==1.19.1
194 | - googleapis-common-protos==1.52.0
195 | - gpustat==0.6.0
196 | - idna==2.10
197 | - nvidia-ml-py3==7.352.0
198 | - protobuf==3.12.2
199 | - psutil==5.7.0
200 | - pyasn1-modules==0.2.8
201 | - pytz==2020.1
202 | - pyunpack==0.1.2
203 | - requests==2.24.0
204 | - six==1.15.0
205 | - torch-summary==1.3.3
206 | prefix: /home/kakarot/anaconda3/envs/fastai
207 |
208 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Marco Forte
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Background Removal
3 |
4 |
5 |
6 |
7 |
8 |
9 | ## Requirements
10 | GPU memory >= 11 GB is required for inference on the Adobe Composition-1K testing set and, more generally, for resolutions above 1920x1080.
11 |
12 | #### Packages:
13 | - torch >= 1.4
14 | - numpy
15 | - opencv-python
16 | #### Additional packages for the Jupyter notebook
17 | - matplotlib
18 |
19 |
20 |
21 | ## Prediction
22 | The script `inference.py` removes the background from a given image by posting it to the running API (`API/app.py`); `demo.py` runs the FBA matting model directly on the image/trimap pairs in `examples/`.
23 |
24 | ## Citation
25 |
26 | Original GitHub repository of FBA_Matting: [link](https://github.com/MarcoForte/FBA_Matting)
27 | ```
28 | @article{forte2020fbamatting,
29 | title = {F, B, Alpha Matting},
30 | author = {Marco Forte and François Pitié},
31 | journal = {CoRR},
32 | volume = {abs/2003.07711},
33 | year = {2020},
34 | }
35 | ```
36 |
37 |
--------------------------------------------------------------------------------
/Videos/frienvengers.avi:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/Videos/frienvengers.avi
--------------------------------------------------------------------------------
/Videos/frienvengers_using_detectron.avi:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/Videos/frienvengers_using_detectron.avi
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import numpy as np
3 | import cv2
4 | import os
5 |
6 |
7 | class PredDataset(Dataset):
8 | ''' Reads image and trimap pairs from folder.
9 |
10 | '''
11 |
12 | def __init__(self, img_dir, trimap_dir):
13 | self.img_dir, self.trimap_dir = img_dir, trimap_dir
14 | self.img_names = [x for x in os.listdir(self.img_dir) if 'png' in x]
15 |
16 | def __len__(self):
17 | return len(self.img_names)
18 |
19 | def __getitem__(self, idx):
20 | img_name = self.img_names[idx]
21 |
22 | image = read_image(os.path.join(self.img_dir, img_name))
23 | trimap = read_trimap(os.path.join(self.trimap_dir, img_name))
24 | pred_dict = {'image': image, 'trimap': trimap, 'name': img_name}
25 |
26 | return pred_dict
27 |
28 |
29 | def read_image(name):
30 | return (cv2.imread(name) / 255.0)[:, :, ::-1]
31 |
32 |
33 | def read_trimap(name):
34 | trimap_im = cv2.imread(name, 0) / 255.0
35 | h, w = trimap_im.shape
36 | trimap = np.zeros((h, w, 2))
37 | trimap[trimap_im == 1, 1] = 1
38 | trimap[trimap_im == 0, 0] = 1
39 | return trimap
40 |
--------------------------------------------------------------------------------
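
`read_trimap` turns a single-channel trimap PNG into the two-channel encoding `pred` expects: pixel value 255 (1.0 after scaling) becomes definite foreground, 0 definite background, and any intermediate grey is left unknown. A small worked example without file I/O:

```python
import numpy as np

trimap_im = np.array([[0.0, 0.5, 1.0]])  # one row: background, unknown, foreground
h, w = trimap_im.shape
trimap = np.zeros((h, w, 2))
trimap[trimap_im == 1, 1] = 1            # channel 1: definite foreground
trimap[trimap_im == 0, 0] = 1            # channel 0: definite background
print(trimap[0])                         # [[1. 0.] [0. 0.] [0. 1.]]
```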
/demo.py:
--------------------------------------------------------------------------------
1 | # Our libs
2 | from networks.transforms import trimap_transform, groupnorm_normalise_image
3 | from networks.models import build_model
4 | from dataloader import PredDataset
5 |
6 | # System libs
7 | import os
8 | import argparse
9 |
10 | # External libs
11 | import cv2
12 | import numpy as np
13 | import torch
14 |
15 |
16 | def np_to_torch(x):
17 | return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float().cuda()
18 |
19 |
20 | def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray:
21 | ''' Scales inputs to multiple of 8. '''
22 | h, w = x.shape[:2]
23 | h1 = int(np.ceil(scale * h / 8) * 8)
24 | w1 = int(np.ceil(scale * w / 8) * 8)
25 | x_scale = cv2.resize(x, (w1, h1), interpolation=scale_type)
26 | return x_scale
27 |
28 |
29 | def predict_fba_folder(model, args):
30 | save_dir = args.output_dir
31 |
32 | dataset_test = PredDataset(args.image_dir, args.trimap_dir)
33 |
34 | gen = iter(dataset_test)
35 | for item_dict in gen:
36 | image_np = item_dict['image']
37 | trimap_np = item_dict['trimap']
38 |
39 | fg, bg, alpha = pred(image_np, trimap_np, model)
40 |
41 | cv2.imwrite(os.path.join(save_dir, item_dict['name'][:-4] + '_fg.png'), fg[:, :, ::-1] * 255)
42 | cv2.imwrite(os.path.join(save_dir, item_dict['name'][:-4] + '_bg.png'), bg[:, :, ::-1] * 255)
43 | cv2.imwrite(os.path.join(save_dir, item_dict['name'][:-4] + '_alpha.png'), alpha * 255)
44 |
45 |
46 | def pred(image_np: np.ndarray, trimap_np: np.ndarray, model):
47 | ''' Predict alpha, foreground and background.
48 | Parameters:
49 | image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3)
50 | trimap_np -- two channel trimap, first background then foreground. Dimensions: (h, w, 2)
51 | Returns:
52 | fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3)
53 | bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3)
54 | alpha: alpha matte image between 0 and 1. Dimensions: (h, w)
55 | '''
56 | h, w = trimap_np.shape[:2]
57 |
58 | image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
59 | trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)
60 |
61 | with torch.no_grad():
62 |
63 | image_torch = np_to_torch(image_scale_np)
64 | trimap_torch = np_to_torch(trimap_scale_np)
65 |
66 | trimap_transformed_torch = np_to_torch(trimap_transform(trimap_scale_np))
67 | image_transformed_torch = groupnorm_normalise_image(image_torch.clone(), format='nchw')
68 |
69 | output = model(image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch)
70 |
71 |     output = cv2.resize(output[0].cpu().numpy().transpose((1, 2, 0)), (w, h), interpolation=cv2.INTER_LANCZOS4)
72 | alpha = output[:, :, 0]
73 | fg = output[:, :, 1:4]
74 | bg = output[:, :, 4:7]
75 |
76 | alpha[trimap_np[:, :, 0] == 1] = 0
77 | alpha[trimap_np[:, :, 1] == 1] = 1
78 | fg[alpha == 1] = image_np[alpha == 1]
79 | bg[alpha == 0] = image_np[alpha == 0]
80 | return fg, bg, alpha
81 |
82 |
83 | if __name__ == '__main__':
84 |
85 | parser = argparse.ArgumentParser()
86 | # Model related arguments
87 | parser.add_argument('--encoder', default='resnet50_GN_WS', help="encoder model")
88 | parser.add_argument('--decoder', default='fba_decoder', help="Decoder model")
89 | parser.add_argument('--weights', default='FBA.pth')
90 | parser.add_argument('--image_dir', default='./examples/images', help="")
91 | parser.add_argument('--trimap_dir', default='./examples/trimaps', help="")
92 | parser.add_argument('--output_dir', default='./examples/predictions', help="")
93 |
94 | args = parser.parse_args()
95 | model = build_model(args)
96 | model.eval()
97 | predict_fba_folder(model, args)
98 |
--------------------------------------------------------------------------------
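
A minimal sketch of driving this folder prediction programmatically instead of via the CLI (assumes a CUDA device and the downloaded `FBA.pth` weights; paths mirror the argparse defaults above):

```python
from types import SimpleNamespace

from demo import predict_fba_folder
from networks.models import build_model

args = SimpleNamespace(
    encoder='resnet50_GN_WS',
    decoder='fba_decoder',
    weights='FBA.pth',
    image_dir='./examples/images',
    trimap_dir='./examples/trimaps',
    output_dir='./examples/predictions',
)
model = build_model(args)  # build_model calls .cuda(), so a GPU is required
model.eval()
predict_fba_folder(model, args)
```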
/examples/example_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/example_results.png
--------------------------------------------------------------------------------
/examples/images/elephant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/images/elephant.png
--------------------------------------------------------------------------------
/examples/images/troll.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/images/troll.png
--------------------------------------------------------------------------------
/examples/predictions/elephant_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/elephant_alpha.png
--------------------------------------------------------------------------------
/examples/predictions/elephant_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/elephant_bg.png
--------------------------------------------------------------------------------
/examples/predictions/elephant_fg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/elephant_fg.png
--------------------------------------------------------------------------------
/examples/predictions/troll_alpha.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/troll_alpha.png
--------------------------------------------------------------------------------
/examples/predictions/troll_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/troll_bg.png
--------------------------------------------------------------------------------
/examples/predictions/troll_fg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/predictions/troll_fg.png
--------------------------------------------------------------------------------
/examples/trimaps/elephant.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/trimaps/elephant.png
--------------------------------------------------------------------------------
/examples/trimaps/troll.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/examples/trimaps/troll.png
--------------------------------------------------------------------------------
/frienvengers.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import requests\n",
10 | "import pickle\n",
11 | "import numpy as np\n",
12 | "import cv2\n",
13 | "import matplotlib.pyplot as plt"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "cap1 = cv2.VideoCapture('Videos/friends.avi') #friends_video "
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 3,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "cap2 = cv2.VideoCapture('Videos/avengers.avi')#avengers_video"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 4,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "fourcc = cv2.VideoWriter_fourcc(*'XVID')\n",
41 | "writer = cv2.VideoWriter(f'Videos/frienvengers_using_detectron.avi', fourcc, 25.0, (360,240)) #output video"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {
48 | "scrolled": true
49 | },
50 | "outputs": [],
51 | "source": [
52 | "ret1,org_fr = cap1.read()\n",
53 | "ret2,new_bg = cap2.read()\n",
54 | "\n",
55 | "count = 0\n",
56 | "while(ret1 and ret2):\n",
57 | " print(count)\n",
58 | " \n",
59 | " org_fr = cv2.resize(org_fr,(360,240))[:,:,::-1]\n",
60 | " new_bg = cv2.resize(new_bg,(360,240))[:,:,::-1]\n",
61 | " response = requests.post('http://127.0.0.1:5000/with_bg',json={'image':org_fr.tolist(),'bg':new_bg.tolist()})\n",
62 | " combined = np.array(response.json()['output'])\n",
63 | " count+=1\n",
64 | " ret1,org_fr = cap1.read()\n",
65 | " ret2,new_bg = cap2.read()\n",
66 | " writer.write(combined[:,:,::-1].astype('uint8'))"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": []
75 | }
76 | ],
77 | "metadata": {
78 | "kernelspec": {
79 | "display_name": "Python [conda env:Fba_matting]",
80 | "language": "python",
81 | "name": "conda-env-Fba_matting-py"
82 | },
83 | "language_info": {
84 | "codemirror_mode": {
85 | "name": "ipython",
86 | "version": 3
87 | },
88 | "file_extension": ".py",
89 | "mimetype": "text/x-python",
90 | "name": "python",
91 | "nbconvert_exporter": "python",
92 | "pygments_lexer": "ipython3",
93 | "version": "3.7.7"
94 | }
95 | },
96 | "nbformat": 4,
97 | "nbformat_minor": 4
98 | }
99 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import argparse
4 | import requests
5 | import matplotlib.pyplot as plt
6 |
7 | if __name__ == '__main__':
8 | parser = argparse.ArgumentParser()
9 |
10 | parser.add_argument('--image_path',type=str)
11 | parser.add_argument('--output_path', type=str)
12 | args = parser.parse_args()
13 | img_path = args.image_path
14 |
15 | out_path = args.output_path
16 | img = cv2.imread(img_path)[:,:,::-1]
17 | # print(img.shape)
18 | response = requests.post('http://127.0.0.1:5000/',json={'image':img.tolist()})
19 | out = np.array(response.json()['output'])
20 | plt.imsave(out_path+'/out.jpg',out.astype('uint8'))
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from networks.models import build_model
2 |
3 |
4 | class Args:
5 | def __init__(self):
6 | self.encoder = 'resnet50_GN_WS'
7 | self.decoder = 'fba_decoder'
8 | self.weights = 'FBA.pth'
9 |
10 | args = Args()
11 | Model = build_model(args)
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/harsh2912/Background-removal/5149d7439d9f761abba1b2f3718cc3ad0c35a7f7/networks/__init__.py
--------------------------------------------------------------------------------
/networks/layers_WS.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class Conv2d(nn.Conv2d):
7 |
8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
9 | padding=0, dilation=1, groups=1, bias=True):
10 | super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
11 | padding, dilation, groups, bias)
12 |
13 | def forward(self, x):
14 | # return super(Conv2d, self).forward(x)
15 | weight = self.weight
16 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
17 | keepdim=True).mean(dim=3, keepdim=True)
18 | weight = weight - weight_mean
19 | # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
20 | std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1) + 1e-5
21 | weight = weight / std.expand_as(weight)
22 | return F.conv2d(x, weight, self.bias, self.stride,
23 | self.padding, self.dilation, self.groups)
24 |
25 |
26 | def BatchNorm2d(num_features):
27 | return nn.GroupNorm(num_channels=num_features, num_groups=32)
28 |
--------------------------------------------------------------------------------
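
This `Conv2d` implements Weight Standardization: each filter is shifted to zero mean and scaled to roughly unit standard deviation before the convolution, pairing with the GroupNorm layers (note the `BatchNorm2d` factory above actually returns `nn.GroupNorm`). A quick sketch verifying the standardization, assuming the repo root is on `PYTHONPATH`:

```python
import torch
from networks.layers_WS import Conv2d

conv = Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 16, 16)
_ = conv(x)  # forward() standardizes the weights internally

# Reproduce the standardization from forward():
w = conv.weight
w = w - w.mean(dim=(1, 2, 3), keepdim=True)
w = w / (torch.sqrt(torch.var(w.view(w.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1) + 1e-5)
print(float(w.mean().abs()))        # ~0: zero mean per filter
print(w.view(8, -1).std(dim=1))     # ~1: unit std per filter
```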
/networks/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import networks.resnet_GN_WS as resnet_GN_WS
4 | import networks.layers_WS as L
5 | import networks.resnet_bn as resnet_bn
6 |
7 |
8 | def build_model(args):
9 | builder = ModelBuilder()
10 | net_encoder = builder.build_encoder(arch=args.encoder)
11 |
12 | if('BN' in args.encoder):
13 | batch_norm = True
14 | else:
15 | batch_norm = False
16 | net_decoder = builder.build_decoder(arch=args.decoder, batch_norm=batch_norm)
17 |
18 | model = MattingModule(net_encoder, net_decoder)
19 |
20 | model.cuda()
21 |
22 | if(args.weights != 'default'):
23 | sd = torch.load(args.weights)
24 | model.load_state_dict(sd, strict=True)
25 |
26 | return model
27 |
28 |
29 | class MattingModule(nn.Module):
30 | def __init__(self, net_enc, net_dec):
31 | super(MattingModule, self).__init__()
32 | self.encoder = net_enc
33 | self.decoder = net_dec
34 |
35 | def forward(self, image, two_chan_trimap, image_n, trimap_transformed):
36 | resnet_input = torch.cat((image_n, trimap_transformed, two_chan_trimap), 1)
37 | conv_out, indices = self.encoder(resnet_input, return_feature_maps=True)
38 | return self.decoder(conv_out, image, indices, two_chan_trimap)
39 |
40 |
41 | class ModelBuilder():
42 | def build_encoder(self, arch='resnet50_GN'):
43 | if arch == 'resnet50_GN_WS':
44 | orig_resnet = resnet_GN_WS.__dict__['l_resnet50']()
45 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
46 | elif arch == 'resnet50_BN':
47 | orig_resnet = resnet_bn.__dict__['l_resnet50']()
48 | net_encoder = ResnetDilatedBN(orig_resnet, dilate_scale=8)
49 |
50 | else:
51 | raise Exception('Architecture undefined!')
52 |
53 | num_channels = 3 + 6 + 2
54 |
55 | if(num_channels > 3):
56 | print(f'modifying input layer to accept {num_channels} channels')
57 | net_encoder_sd = net_encoder.state_dict()
58 | conv1_weights = net_encoder_sd['conv1.weight']
59 |
60 | c_out, c_in, h, w = conv1_weights.size()
61 | conv1_mod = torch.zeros(c_out, num_channels, h, w)
62 | conv1_mod[:, :3, :, :] = conv1_weights
63 |
64 | conv1 = net_encoder.conv1
65 | conv1.in_channels = num_channels
66 | conv1.weight = torch.nn.Parameter(conv1_mod)
67 |
68 | net_encoder.conv1 = conv1
69 |
70 | net_encoder_sd['conv1.weight'] = conv1_mod
71 |
72 | net_encoder.load_state_dict(net_encoder_sd)
73 | return net_encoder
74 |
75 | def build_decoder(self, arch='fba_decoder', batch_norm=False):
76 | if arch == 'fba_decoder':
77 | net_decoder = fba_decoder(batch_norm=batch_norm)
78 |
79 | return net_decoder
80 |
81 |
82 | class ResnetDilatedBN(nn.Module):
83 | def __init__(self, orig_resnet, dilate_scale=8):
84 | super(ResnetDilatedBN, self).__init__()
85 | from functools import partial
86 |
87 | if dilate_scale == 8:
88 | orig_resnet.layer3.apply(
89 | partial(self._nostride_dilate, dilate=2))
90 | orig_resnet.layer4.apply(
91 | partial(self._nostride_dilate, dilate=4))
92 | elif dilate_scale == 16:
93 | orig_resnet.layer4.apply(
94 | partial(self._nostride_dilate, dilate=2))
95 |
96 | # take pretrained resnet, except AvgPool and FC
97 | self.conv1 = orig_resnet.conv1
98 | self.bn1 = orig_resnet.bn1
99 | self.relu1 = orig_resnet.relu1
100 | self.conv2 = orig_resnet.conv2
101 | self.bn2 = orig_resnet.bn2
102 | self.relu2 = orig_resnet.relu2
103 | self.conv3 = orig_resnet.conv3
104 | self.bn3 = orig_resnet.bn3
105 | self.relu3 = orig_resnet.relu3
106 | self.maxpool = orig_resnet.maxpool
107 | self.layer1 = orig_resnet.layer1
108 | self.layer2 = orig_resnet.layer2
109 | self.layer3 = orig_resnet.layer3
110 | self.layer4 = orig_resnet.layer4
111 |
112 | def _nostride_dilate(self, m, dilate):
113 | classname = m.__class__.__name__
114 | if classname.find('Conv') != -1:
115 | # the convolution with stride
116 | if m.stride == (2, 2):
117 | m.stride = (1, 1)
118 | if m.kernel_size == (3, 3):
119 | m.dilation = (dilate // 2, dilate // 2)
120 | m.padding = (dilate // 2, dilate // 2)
121 |             # other convolutions
122 | else:
123 | if m.kernel_size == (3, 3):
124 | m.dilation = (dilate, dilate)
125 | m.padding = (dilate, dilate)
126 |
127 | def forward(self, x, return_feature_maps=False):
128 | conv_out = [x]
129 | x = self.relu1(self.bn1(self.conv1(x)))
130 | x = self.relu2(self.bn2(self.conv2(x)))
131 | x = self.relu3(self.bn3(self.conv3(x)))
132 | conv_out.append(x)
133 | x, indices = self.maxpool(x)
134 | x = self.layer1(x)
135 | conv_out.append(x)
136 | x = self.layer2(x)
137 | conv_out.append(x)
138 | x = self.layer3(x)
139 | conv_out.append(x)
140 | x = self.layer4(x)
141 | conv_out.append(x)
142 |
143 | if return_feature_maps:
144 | return conv_out, indices
145 | return [x]
146 |
147 |
148 | class Resnet(nn.Module):
149 | def __init__(self, orig_resnet):
150 | super(Resnet, self).__init__()
151 |
152 | # take pretrained resnet, except AvgPool and FC
153 | self.conv1 = orig_resnet.conv1
154 | self.bn1 = orig_resnet.bn1
155 | self.relu1 = orig_resnet.relu1
156 | self.conv2 = orig_resnet.conv2
157 | self.bn2 = orig_resnet.bn2
158 | self.relu2 = orig_resnet.relu2
159 | self.conv3 = orig_resnet.conv3
160 | self.bn3 = orig_resnet.bn3
161 | self.relu3 = orig_resnet.relu3
162 | self.maxpool = orig_resnet.maxpool
163 | self.layer1 = orig_resnet.layer1
164 | self.layer2 = orig_resnet.layer2
165 | self.layer3 = orig_resnet.layer3
166 | self.layer4 = orig_resnet.layer4
167 |
168 | def forward(self, x, return_feature_maps=False):
169 | conv_out = []
170 |
171 | x = self.relu1(self.bn1(self.conv1(x)))
172 | x = self.relu2(self.bn2(self.conv2(x)))
173 | x = self.relu3(self.bn3(self.conv3(x)))
174 | conv_out.append(x)
175 | x, indices = self.maxpool(x)
176 |
177 | x = self.layer1(x)
178 | conv_out.append(x)
179 | x = self.layer2(x)
180 | conv_out.append(x)
181 | x = self.layer3(x)
182 | conv_out.append(x)
183 | x = self.layer4(x)
184 | conv_out.append(x)
185 |
186 | if return_feature_maps:
187 | return conv_out
188 | return [x]
189 |
190 |
191 | class ResnetDilated(nn.Module):
192 | def __init__(self, orig_resnet, dilate_scale=8):
193 | super(ResnetDilated, self).__init__()
194 | from functools import partial
195 |
196 | if dilate_scale == 8:
197 | orig_resnet.layer3.apply(
198 | partial(self._nostride_dilate, dilate=2))
199 | orig_resnet.layer4.apply(
200 | partial(self._nostride_dilate, dilate=4))
201 | elif dilate_scale == 16:
202 | orig_resnet.layer4.apply(
203 | partial(self._nostride_dilate, dilate=2))
204 |
205 | # take pretrained resnet, except AvgPool and FC
206 | self.conv1 = orig_resnet.conv1
207 | self.bn1 = orig_resnet.bn1
208 | self.relu = orig_resnet.relu
209 | self.maxpool = orig_resnet.maxpool
210 | self.layer1 = orig_resnet.layer1
211 | self.layer2 = orig_resnet.layer2
212 | self.layer3 = orig_resnet.layer3
213 | self.layer4 = orig_resnet.layer4
214 |
215 | def _nostride_dilate(self, m, dilate):
216 | classname = m.__class__.__name__
217 | if classname.find('Conv') != -1:
218 | # the convolution with stride
219 | if m.stride == (2, 2):
220 | m.stride = (1, 1)
221 | if m.kernel_size == (3, 3):
222 | m.dilation = (dilate // 2, dilate // 2)
223 | m.padding = (dilate // 2, dilate // 2)
224 |             # other convolutions
225 | else:
226 | if m.kernel_size == (3, 3):
227 | m.dilation = (dilate, dilate)
228 | m.padding = (dilate, dilate)
229 |
230 | def forward(self, x, return_feature_maps=False):
231 | conv_out = [x]
232 | x = self.relu(self.bn1(self.conv1(x)))
233 | conv_out.append(x)
234 | x, indices = self.maxpool(x)
235 | x = self.layer1(x)
236 | conv_out.append(x)
237 | x = self.layer2(x)
238 | conv_out.append(x)
239 | x = self.layer3(x)
240 | conv_out.append(x)
241 | x = self.layer4(x)
242 | conv_out.append(x)
243 |
244 | if return_feature_maps:
245 | return conv_out, indices
246 | return [x]
247 |
248 |
249 | def norm(dim, bn=False):
250 | if(bn is False):
251 | return nn.GroupNorm(32, dim)
252 | else:
253 | return nn.BatchNorm2d(dim)
254 |
255 |
256 | def fba_fusion(alpha, img, F, B):
257 | F = ((alpha * img + (1 - alpha**2) * F - alpha * (1 - alpha) * B))
258 | B = ((1 - alpha) * img + (2 * alpha - alpha**2) * B - alpha * (1 - alpha) * F)
259 |
260 | F = torch.clamp(F, 0, 1)
261 | B = torch.clamp(B, 0, 1)
262 | la = 0.1
263 | alpha = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
264 | alpha = torch.clamp(alpha, 0, 1)
265 | return alpha, F, B
266 |
267 |
268 | class fba_decoder(nn.Module):
269 | def __init__(self, batch_norm=False):
270 | super(fba_decoder, self).__init__()
271 | pool_scales = (1, 2, 3, 6)
272 | self.batch_norm = batch_norm
273 |
274 | self.ppm = []
275 |
276 | for scale in pool_scales:
277 | self.ppm.append(nn.Sequential(
278 | nn.AdaptiveAvgPool2d(scale),
279 | L.Conv2d(2048, 256, kernel_size=1, bias=True),
280 | norm(256, self.batch_norm),
281 | nn.LeakyReLU()
282 | ))
283 | self.ppm = nn.ModuleList(self.ppm)
284 |
285 | self.conv_up1 = nn.Sequential(
286 | L.Conv2d(2048 + len(pool_scales) * 256, 256,
287 | kernel_size=3, padding=1, bias=True),
288 |
289 | norm(256, self.batch_norm),
290 | nn.LeakyReLU(),
291 | L.Conv2d(256, 256, kernel_size=3, padding=1),
292 | norm(256, self.batch_norm),
293 | nn.LeakyReLU()
294 | )
295 |
296 | self.conv_up2 = nn.Sequential(
297 | L.Conv2d(256 + 256, 256,
298 | kernel_size=3, padding=1, bias=True),
299 | norm(256, self.batch_norm),
300 | nn.LeakyReLU()
301 | )
302 | if(self.batch_norm):
303 | d_up3 = 128
304 | else:
305 | d_up3 = 64
306 | self.conv_up3 = nn.Sequential(
307 | L.Conv2d(256 + d_up3, 64,
308 | kernel_size=3, padding=1, bias=True),
309 | norm(64, self.batch_norm),
310 | nn.LeakyReLU()
311 | )
312 |
313 | self.unpool = nn.MaxUnpool2d(2, stride=2)
314 |
315 | self.conv_up4 = nn.Sequential(
316 | nn.Conv2d(64 + 3 + 3 + 2, 32,
317 | kernel_size=3, padding=1, bias=True),
318 | nn.LeakyReLU(),
319 | nn.Conv2d(32, 16,
320 | kernel_size=3, padding=1, bias=True),
321 |
322 | nn.LeakyReLU(),
323 | nn.Conv2d(16, 7, kernel_size=1, padding=0, bias=True)
324 | )
325 |
326 | def forward(self, conv_out, img, indices, two_chan_trimap):
327 | conv5 = conv_out[-1]
328 |
329 | input_size = conv5.size()
330 | ppm_out = [conv5]
331 | for pool_scale in self.ppm:
332 | ppm_out.append(nn.functional.interpolate(
333 | pool_scale(conv5),
334 | (input_size[2], input_size[3]),
335 | mode='bilinear', align_corners=False))
336 | ppm_out = torch.cat(ppm_out, 1)
337 | x = self.conv_up1(ppm_out)
338 |
339 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
340 |
341 | x = torch.cat((x, conv_out[-4]), 1)
342 |
343 | x = self.conv_up2(x)
344 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
345 |
346 | x = torch.cat((x, conv_out[-5]), 1)
347 | x = self.conv_up3(x)
348 |
349 | x = torch.nn.functional.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
350 | x = torch.cat((x, conv_out[-6][:, :3], img, two_chan_trimap), 1)
351 |
352 | output = self.conv_up4(x)
353 |
354 | alpha = torch.clamp(output[:, 0][:, None], 0, 1)
355 | F = torch.sigmoid(output[:, 1:4])
356 | B = torch.sigmoid(output[:, 4:7])
357 |
358 | # FBA Fusion
359 | alpha, F, B = fba_fusion(alpha, img, F, B)
360 |
361 | output = torch.cat((alpha, F, B), 1)
362 |
363 | return output
364 |
--------------------------------------------------------------------------------
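
`fba_fusion` enforces the compositing equation `img = alpha*F + (1-alpha)*B` on the raw network outputs: F and B are refit from the current alpha, then alpha is re-solved per pixel as a regularized least squares along the channel axis, `alpha = (la*alpha + sum_c (img-B)*(F-B)) / (sum_c (F-B)^2 + la)` with `la = 0.1`. A numeric sanity check of that alpha update on a single grey pixel:

```python
import torch

la = 0.1
img = torch.full((1, 3, 1, 1), 0.6)   # grey pixel
F = torch.ones_like(img)              # pure white foreground estimate
B = torch.zeros_like(img)             # pure black background estimate
alpha = torch.full((1, 1, 1, 1), 0.5)

alpha_new = (alpha * la + torch.sum((img - B) * (F - B), 1, keepdim=True)) / \
            (torch.sum((F - B) * (F - B), 1, keepdim=True) + la)
print(alpha_new)  # ~0.597, pulled toward the exact solution alpha = 0.6
```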
/networks/resnet_GN_WS.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import networks.layers_WS as L
3 |
4 | __all__ = ['ResNet', 'l_resnet50']
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | """3x3 convolution with padding"""
9 | return L.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=1, bias=False)
11 |
12 |
13 | def conv1x1(in_planes, out_planes, stride=1):
14 | """1x1 convolution"""
15 | return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
16 |
17 |
18 | class BasicBlock(nn.Module):
19 | expansion = 1
20 |
21 | def __init__(self, inplanes, planes, stride=1, downsample=None):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(inplanes, planes, stride)
24 | self.bn1 = L.BatchNorm2d(planes)
25 | self.relu = nn.ReLU(inplace=True)
26 | self.conv2 = conv3x3(planes, planes)
27 | self.bn2 = L.BatchNorm2d(planes)
28 | self.downsample = downsample
29 | self.stride = stride
30 |
31 | def forward(self, x):
32 | identity = x
33 |
34 | out = self.conv1(x)
35 | out = self.bn1(out)
36 | out = self.relu(out)
37 |
38 | out = self.conv2(out)
39 | out = self.bn2(out)
40 |
41 | if self.downsample is not None:
42 | identity = self.downsample(x)
43 |
44 | out += identity
45 | out = self.relu(out)
46 |
47 | return out
48 |
49 |
50 | class Bottleneck(nn.Module):
51 | expansion = 4
52 |
53 | def __init__(self, inplanes, planes, stride=1, downsample=None):
54 | super(Bottleneck, self).__init__()
55 | self.conv1 = conv1x1(inplanes, planes)
56 | self.bn1 = L.BatchNorm2d(planes)
57 | self.conv2 = conv3x3(planes, planes, stride)
58 | self.bn2 = L.BatchNorm2d(planes)
59 | self.conv3 = conv1x1(planes, planes * self.expansion)
60 | self.bn3 = L.BatchNorm2d(planes * self.expansion)
61 | self.relu = nn.ReLU(inplace=True)
62 | self.downsample = downsample
63 | self.stride = stride
64 |
65 | def forward(self, x):
66 | identity = x
67 |
68 | out = self.conv1(x)
69 | out = self.bn1(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv2(out)
73 | out = self.bn2(out)
74 | out = self.relu(out)
75 |
76 | out = self.conv3(out)
77 | out = self.bn3(out)
78 |
79 | if self.downsample is not None:
80 | identity = self.downsample(x)
81 |
82 | out += identity
83 | out = self.relu(out)
84 |
85 | return out
86 |
87 |
88 | class ResNet(nn.Module):
89 |
90 | def __init__(self, block, layers, num_classes=1000):
91 | super(ResNet, self).__init__()
92 | self.inplanes = 64
93 | self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
94 | bias=False)
95 | self.bn1 = L.BatchNorm2d(64)
96 | self.relu = nn.ReLU(inplace=True)
97 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
98 | self.layer1 = self._make_layer(block, 64, layers[0])
99 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
100 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
101 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
102 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
103 | self.fc = nn.Linear(512 * block.expansion, num_classes)
104 |
105 | def _make_layer(self, block, planes, blocks, stride=1):
106 | downsample = None
107 | if stride != 1 or self.inplanes != planes * block.expansion:
108 | downsample = nn.Sequential(
109 | conv1x1(self.inplanes, planes * block.expansion, stride),
110 | L.BatchNorm2d(planes * block.expansion),
111 | )
112 |
113 | layers = []
114 | layers.append(block(self.inplanes, planes, stride, downsample))
115 | self.inplanes = planes * block.expansion
116 | for _ in range(1, blocks):
117 | layers.append(block(self.inplanes, planes))
118 |
119 | return nn.Sequential(*layers)
120 |
121 | def forward(self, x):
122 | x = self.conv1(x)
123 | x = self.bn1(x)
124 | x = self.relu(x)
125 | x, _ = self.maxpool(x)  # maxpool is built with return_indices=True, so unpack and discard the indices
126 |
127 | x = self.layer1(x)
128 | x = self.layer2(x)
129 | x = self.layer3(x)
130 | x = self.layer4(x)
131 |
132 | x = self.avgpool(x)
133 | x = x.view(x.size(0), -1)
134 | x = self.fc(x)
135 |
136 | return x
137 |
138 |
139 | def l_resnet50(pretrained=False, **kwargs):
140 | """Constructs a ResNet-50 model with GN + WS layers.
141 | Args:
142 | pretrained (bool): accepted for interface compatibility but currently ignored
143 | """
144 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
145 | return model
146 |
--------------------------------------------------------------------------------
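A quick smoke test for this backbone (a sketch; it assumes you run it from the repo root so that `networks` is importable). Since the head uses AdaptiveAvgPool2d, the input resolution is flexible:

import torch
from networks.resnet_GN_WS import l_resnet50

model = l_resnet50()
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])

Within this repo the network is used as the matting encoder, so the classification head above mainly matters for smoke tests like this one.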
/networks/resnet_bn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | from torch.nn import BatchNorm2d
4 |
5 | __all__ = ['ResNet', 'l_resnet50']
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | "3x3 convolution with padding"
10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 | padding=1, bias=False)
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = conv3x3(inplanes, planes, stride)
20 | self.bn1 = BatchNorm2d(planes)
21 | self.relu = nn.ReLU(inplace=True)
22 | self.conv2 = conv3x3(planes, planes)
23 | self.bn2 = BatchNorm2d(planes)
24 | self.downsample = downsample
25 | self.stride = stride
26 |
27 | def forward(self, x):
28 | residual = x
29 |
30 | out = self.conv1(x)
31 | out = self.bn1(out)
32 | out = self.relu(out)
33 |
34 | out = self.conv2(out)
35 | out = self.bn2(out)
36 |
37 | if self.downsample is not None:
38 | residual = self.downsample(x)
39 |
40 | out += residual
41 | out = self.relu(out)
42 |
43 | return out
44 |
45 |
46 | class Bottleneck(nn.Module):
47 | expansion = 4
48 |
49 | def __init__(self, inplanes, planes, stride=1, downsample=None):
50 | super(Bottleneck, self).__init__()
51 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
52 | self.bn1 = BatchNorm2d(planes)
53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
54 | padding=1, bias=False)
55 | self.bn2 = BatchNorm2d(planes, momentum=0.01)  # note: only this layer overrides the default BN momentum of 0.1
56 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
57 | self.bn3 = BatchNorm2d(planes * 4)
58 | self.relu = nn.ReLU(inplace=True)
59 | self.downsample = downsample
60 | self.stride = stride
61 |
62 | def forward(self, x):
63 | residual = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv3(out)
74 | out = self.bn3(out)
75 |
76 | if self.downsample is not None:
77 | residual = self.downsample(x)
78 |
79 | out += residual
80 | out = self.relu(out)
81 |
82 | return out
83 |
84 |
85 | class ResNet(nn.Module):
86 |
87 | def __init__(self, block, layers, num_classes=1000):
88 | self.inplanes = 128
89 | super(ResNet, self).__init__()
90 | self.conv1 = conv3x3(3, 64, stride=2)
91 | self.bn1 = BatchNorm2d(64)
92 | self.relu1 = nn.ReLU(inplace=True)
93 | self.conv2 = conv3x3(64, 64)
94 | self.bn2 = BatchNorm2d(64)
95 | self.relu2 = nn.ReLU(inplace=True)
96 | self.conv3 = conv3x3(64, 128)
97 | self.bn3 = BatchNorm2d(128)
98 | self.relu3 = nn.ReLU(inplace=True)
99 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
100 |
101 | self.layer1 = self._make_layer(block, 64, layers[0])
102 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
103 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
104 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
105 | self.avgpool = nn.AvgPool2d(7, stride=1)
106 | self.fc = nn.Linear(512 * block.expansion, num_classes)
107 |
108 | for m in self.modules():
109 | if isinstance(m, nn.Conv2d):
110 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
111 | m.weight.data.normal_(0, math.sqrt(2. / n))
112 | elif isinstance(m, BatchNorm2d):
113 | m.weight.data.fill_(1)
114 | m.bias.data.zero_()
115 |
116 | def _make_layer(self, block, planes, blocks, stride=1):
117 | downsample = None
118 | if stride != 1 or self.inplanes != planes * block.expansion:
119 | downsample = nn.Sequential(
120 | nn.Conv2d(self.inplanes, planes * block.expansion,
121 | kernel_size=1, stride=stride, bias=False),
122 | BatchNorm2d(planes * block.expansion),
123 | )
124 |
125 | layers = []
126 | layers.append(block(self.inplanes, planes, stride, downsample))
127 | self.inplanes = planes * block.expansion
128 | for i in range(1, blocks):
129 | layers.append(block(self.inplanes, planes))
130 |
131 | return nn.Sequential(*layers)
132 |
133 | def forward(self, x):
134 | x = self.relu1(self.bn1(self.conv1(x)))
135 | x = self.relu2(self.bn2(self.conv2(x)))
136 | x = self.relu3(self.bn3(self.conv3(x)))
137 | x, indices = self.maxpool(x)
138 |
139 | x = self.layer1(x)
140 | x = self.layer2(x)
141 | x = self.layer3(x)
142 | x = self.layer4(x)
143 |
144 | x = self.avgpool(x)
145 | x = x.view(x.size(0), -1)
146 | x = self.fc(x)
147 | return x
148 |
149 |
150 | def l_resnet50():
151 | """Constructs a BatchNorm ResNet-50 model.
152 | 
153 | Unlike the GN + WS variant, this constructor takes no ``pretrained`` flag and never loads pretrained weights.
154 | """
155 | model = ResNet(Bottleneck, [3, 4, 6, 3])
156 | return model
157 |
--------------------------------------------------------------------------------
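Two details distinguish this BatchNorm variant from the GN + WS one above: the stem is three stacked 3x3 convolutions (a "deep stem") rather than a single 7x7, and the head uses a fixed AvgPool2d(7) instead of adaptive pooling, so the classification forward assumes a 7x7 final feature map, i.e. roughly 224x224 input. A minimal, illustrative check:

import torch
from networks.resnet_bn import l_resnet50

model = l_resnet50().eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # 224 / 32 = 7, matching AvgPool2d(7)
print(out.shape)  # torch.Size([1, 1000]); other input sizes generally break the fixed pooling + fc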
/networks/transforms.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | import cv2
5 |
6 |
7 | def dt(a):
8 | return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0)  # exact L2 distance to the nearest zero pixel
9 |
10 |
11 | def trimap_transform(trimap):
12 | h, w = trimap.shape[0], trimap.shape[1]
13 |
14 | clicks = np.zeros((h, w, 6))
15 | for k in range(2):
16 | if(np.count_nonzero(trimap[:, :, k]) > 0):
17 | dt_mask = -dt(1 - trimap[:, :, k])**2
18 | L = 320  # scale constant for the Gaussian widths below (0.02L, 0.08L, 0.16L)
19 | clicks[:, :, 3*k] = np.exp(dt_mask / (2 * ((0.02 * L)**2)))
20 | clicks[:, :, 3*k+1] = np.exp(dt_mask / (2 * ((0.08 * L)**2)))
21 | clicks[:, :, 3*k+2] = np.exp(dt_mask / (2 * ((0.16 * L)**2)))
22 |
23 | return clicks
24 |
25 |
26 | # ImageNet normalisation statistics, RGB channel order
27 | group_norm_std = [0.229, 0.224, 0.225]
28 | group_norm_mean = [0.485, 0.456, 0.406]
29 |
30 |
31 | def groupnorm_normalise_image(img, format='nhwc'):
32 | '''
33 | Accepts an RGB image with values in [0, 1]; normalises it in place with the ImageNet statistics above.
34 | '''
35 | if(format == 'nhwc'):
36 | for i in range(3):
37 | img[..., i] = (img[..., i] - group_norm_mean[i]) / group_norm_std[i]
38 | else:
39 | for i in range(3):
40 | img[..., i, :, :] = (img[..., i, :, :] - group_norm_mean[i]) / group_norm_std[i]
41 |
42 | return img
43 |
44 |
45 | def groupnorm_denormalise_image(img, format='nhwc'):
46 | '''
47 | Accepts a normalised RGB image and maps it back to the range [0, 1].
48 | '''
49 | if(format == 'nhwc'):
50 | for i in range(3):
51 | img[:, :, :, i] = img[:, :, :, i] * group_norm_std[i] + group_norm_mean[i]
52 | else:
53 | img1 = torch.zeros_like(img)  # zeros_like already matches img's device; hard-coding .cuda() broke CPU-only use
54 | for i in range(3):
55 | img1[:, i, :, :] = img[:, i, :, :] * group_norm_std[i] + group_norm_mean[i]
56 | return img1
57 | return img
58 |
--------------------------------------------------------------------------------
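Putting transforms.py together, here is a hedged sketch of the preprocessing a caller might run before inference. Array names are illustrative; the trimap convention assumed here, with definite background in channel 0 and definite foreground in channel 1, follows the two-channel encoding used elsewhere in this repo:

import numpy as np
from networks.transforms import trimap_transform, groupnorm_normalise_image

h, w = 64, 64
trimap = np.zeros((h, w, 2))
trimap[:10, :10, 0] = 1.0  # a definite-background region
trimap[40:, 40:, 1] = 1.0  # a definite-foreground region

clicks = trimap_transform(trimap)  # (H, W, 6): three Gaussian scales per trimap channel
print(clicks.shape)                # (64, 64, 6)

image = np.random.rand(1, h, w, 3)                       # RGB in [0, 1], NHWC
image = groupnorm_normalise_image(image, format='nhwc')  # normalised in place and returned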