├── .gitignore
├── LICENSE
├── README.md
├── demo.py
├── human_inst_seg
│   ├── __init__.py
│   ├── dataset.py
│   ├── train.py
│   └── unet.py
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | data/
107 | 
108 | .DS_Store
109 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Hiroharu Kato
4 | Copyright (c) 2018 Nikos Kolotouros
5 | Copyright (c) 2019 Shichen Liu
6 | 
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 | 
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 | 
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Human Instance Segmentation
2 | 
3 | A single-human instance segmenter that runs at **50 FPS** on a GV100. Both training and inference are included in this repo.
4 | 
5 | ## Install
6 | 
7 | ```
8 | # via pip
9 | pip install git+https://github.com/Project-Splinter/human_inst_seg --upgrade
10 | 
11 | # via git clone
12 | git clone https://github.com/Project-Splinter/human_inst_seg
13 | cd human_inst_seg
14 | python setup.py develop
15 | ```
16 | 
17 | Note: to run `demo.py`, you also need to install [streamer_pytorch](https://github.com/Project-Splinter/streamer_pytorch) through:
18 | ```
19 | pip install git+https://github.com/Project-Splinter/streamer_pytorch --upgrade
20 | ```
21 | 
22 | ## Train
23 | First, download the dataset from [here](https://github.com/Project-Splinter/ATR_RemoveBG).
24 | 
25 | ```
26 | git clone https://github.com/Project-Splinter/human_inst_seg; cd human_inst_seg;
27 | mkdir ./data  # put the dataset zips under here and unzip them; it should contain two folders: `ATR_RemoveBG` and `alignment`
28 | python human_inst_seg/train.py
29 | ```
30 | 
31 | ## Usage
32 | 
33 | ```
34 | # images
35 | python demo.py --images <IMAGE_PATHS> --loop --vis
36 | # videos
37 | python demo.py --videos <VIDEO_PATHS> --vis
38 | # capture device
39 | python demo.py --camera --vis
40 | ```
41 | 
42 | ## API
43 | ```
44 | seg_engine = Segmentation(ckpt=None, device="cuda:0", init=True)
45 | seg_engine.init(pretrained="")
46 | seg_engine.forward(input)
47 | ```
48 | **Note**: `Segmentation` **is** an instance of `nn.Module`, so be careful when you integrate it into another trainable model.
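49 | 
50 | A minimal inference sketch (the random tensor below just stands in for a real `(B, 3, H, W)` RGB batch normalized to `[-1, 1]`; a CUDA device is assumed):
51 | ```
52 | import torch
53 | import human_inst_seg
54 | 
55 | seg_engine = human_inst_seg.Segmentation(device="cuda:0", init=True)
56 | seg_engine.eval()
57 | 
58 | image = torch.randn(1, 3, 512, 512)  # stand-in for a real normalized image batch
59 | with torch.no_grad():
60 |     output, bboxes, probs = seg_engine(image)
61 | # in eval mode, output is (B, 4, H, W): the input RGB plus a binary mask channel
62 | ```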
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import cv2
3 | import argparse
4 | import numpy as np
5 | import torch
6 | 
7 | import human_inst_seg
8 | # this can be installed via:
9 | # pip install git+https://github.com/Project-Splinter/streamer_pytorch --upgrade
10 | import streamer_pytorch as streamer
11 | 
12 | parser = argparse.ArgumentParser(description='Human instance segmentation demo.')
13 | parser.add_argument(
14 |     '--camera', action="store_true")
15 | parser.add_argument(
16 |     '--images', default="", nargs="*")
17 | parser.add_argument(
18 |     '--videos', default="", nargs="*")
19 | parser.add_argument(
20 |     '--loop', action="store_true")
21 | parser.add_argument(
22 |     '--vis', action="store_true")
23 | args = parser.parse_args()
24 | 
25 | def visualization(data):
26 |     image, bboxes, probs = data
27 |     image = torch.cat([
28 |         image[:, 0:3], image[:, 0:3]*image[:, 3:4]], dim=3)
29 |     probs = probs.unsqueeze(3)
30 |     bboxes = (bboxes * probs).sum(dim=1, keepdim=True) / probs.sum(dim=1, keepdim=True)
31 |     window = image[0].cpu().numpy().transpose(1, 2, 0)
32 |     window = (window * 0.5 + 0.5) * 255.0
33 |     window = np.uint8(window).copy()
34 |     bbox = bboxes[0, 0, 0].cpu().numpy()
35 |     window = cv2.rectangle(
36 |         window,
37 |         (int(bbox[0]), int(bbox[1])),
38 |         (int(bbox[2]), int(bbox[3])),
39 |         (255,0,0), 2)
40 | 
41 |     window = cv2.cvtColor(window, cv2.COLOR_BGR2RGB)
42 |     window = cv2.resize(window, (0, 0), fx=2, fy=2)
43 | 
44 |     cv2.imshow('window', window)
45 |     cv2.waitKey(30)
46 | 
47 | seg_engine = human_inst_seg.Segmentation()
48 | seg_engine.eval()
49 | 
50 | if args.camera:
51 |     data_stream = streamer.CaptureStreamer()
52 | elif len(args.videos) > 0:
53 |     data_stream = streamer.VideoListStreamer(
54 |         args.videos * (10000 if args.loop else 1))
55 | elif len(args.images) > 0:
56 |     data_stream = streamer.ImageListStreamer(
57 |         args.images * (10000 if args.loop else 1))
58 | else:
59 |     raise ValueError("specify an input: --camera, --videos, or --images")
60 | 
61 | loader = torch.utils.data.DataLoader(
62 |     data_stream,
63 |     batch_size=1,
64 |     num_workers=1,
65 |     pin_memory=False,
66 | )
67 | 
68 | try:
69 |     # no vis: ~ 50 fps
70 |     for data in tqdm.tqdm(loader):
71 |         outputs, bboxes, probs = seg_engine(data)
72 |         if args.vis:
73 |             visualization([outputs, bboxes, probs])
74 | except Exception as e:
75 |     print(e)
76 |     del data_stream
--------------------------------------------------------------------------------
/human_inst_seg/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet import Segmentation
--------------------------------------------------------------------------------
/human_inst_seg/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import random
4 | 
5 | import cv2
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | 
11 | def aug_matrix(w1, h1, w2, h2):
12 |     dx = (w2 - w1) / 2.0
13 |     dy = (h2 - h1) / 2.0
14 |     matrix_trans = np.array([[1.0, 0, dx],
15 |                              [0, 1.0, dy],
16 |                              [0, 0, 1.0]])
17 | 
18 |     scale = np.min([float(w2)/w1, float(h2)/h1])  # fit entirely inside (use max to fill & crop instead)
19 | 
20 |     M = get_affine_matrix(
21 |         center = (w2 / 2.0, h2 / 2.0),
22 |         translate = (0, 0),
23 |         scale = scale)
24 |     M = np.array(M + [0., 0., 1.]).reshape(3, 3)
25 |     M = M.dot(matrix_trans)
26 |     return M
27 | 
28 | 
29 | def get_affine_matrix(center, translate, scale):
30 |     cx, cy = center
31 |     tx, ty = translate
32 | 
33 |     M = [1, 0, 0,
34 |          0, 1, 0]
35 |     M = [x * scale for x in M]
36 | 
37 |     # Apply the inverse of the center translation: RSS * C^-1
38 |     M[2] += M[0] * (-cx) + M[1] * (-cy)
39 |     M[5] += M[3] * (-cx) + M[4] * (-cy)
40 | 
41 |     # Apply center translation: T * C * RSS * C^-1
42 |     M[2] += cx + tx
43 |     M[5] += cy + ty
44 |     return M
45 | 
46 | class Dataset(object):
47 |     def __init__(self,
48 |                  input_size=512,
49 |                  image_dir="./data/images",
50 |                  label_dir="./data/labels",
51 |                  train=True,
52 |                  ):
53 |         super().__init__()
54 |         self.input_size = input_size
55 |         self.train = train
56 | 
57 |         image_names = [f for f in os.listdir(image_dir) if f[-3:] == "jpg"]
58 |         image_files = [os.path.join(image_dir, f) for f in image_names]
59 |         label_files = [
60 |             os.path.join(
61 |                 label_dir,
62 |                 f.replace(".jpg", "-removebg-preview.png")
63 |             ) for f in image_names
64 |         ]
65 | 
66 |         self.image_files = []
67 |         self.label_files = []
68 |         for image_file, label_file in zip(image_files, label_files):
69 |             if os.path.exists(image_file) and os.path.exists(label_file):
70 |                 self.image_files.append(image_file)
71 |                 self.label_files.append(label_file)
72 | 
73 |         # both lists now hold only (image, label) pairs where the files exist
74 | 
75 | 
76 |         self.color_aug = transforms.Compose([
77 |             transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0),
78 |         ])
79 | 
80 |         if self.train:
81 |             self.image_to_tensor = transforms.Compose([
82 |                 transforms.ToTensor(),
83 |                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
84 |                 transforms.RandomErasing(p=0.5, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=0),
85 |             ])
86 |         else:
87 |             self.image_to_tensor = transforms.Compose([
88 |                 transforms.ToTensor(),
89 |                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 
0.5))
90 |             ])
91 |         self.mask_to_tensor = transforms.Compose([
92 |             transforms.ToTensor(),
93 |             transforms.Normalize((0.0,), (1.0,))
94 |         ])
95 | 
96 |         print(f"Dataset: {len(self)}")
97 | 
98 |     def __len__(self):
99 |         return len(self.image_files)
100 | 
101 |     def __getitem__(self, index):
102 |         image_file = self.image_files[index]
103 |         label_file = self.label_files[index]
104 | 
105 |         image = Image.open(image_file).convert("RGB")
106 |         width, height = image.size
107 |         mask = Image.open(label_file).split()[-1]
108 |         mask = mask.resize((width, height), Image.BILINEAR)
109 | 
110 |         if self.train:
111 |             image = self.color_aug(image)
112 | 
113 |         M = aug_matrix(width, height, self.input_size, self.input_size)
114 | 
115 |         M_inv = np.linalg.inv(M)
116 |         M_inv = M_inv[0:2].reshape(-1).tolist()
117 | 
118 |         image = image.transform(
119 |             (self.input_size, self.input_size), Image.AFFINE, M_inv,
120 |             Image.BILINEAR, fillcolor=(128, 128, 128))
121 |         mask = mask.transform(
122 |             (self.input_size, self.input_size), Image.AFFINE, M_inv,
123 |             Image.BILINEAR, fillcolor=(0,))
124 | 
125 |         if self.train and random.random() < 0.5:
126 |             image = transforms.functional.hflip(image)
127 |             mask = transforms.functional.hflip(mask)
128 | 
129 |         input = self.image_to_tensor(image).float()
130 |         label = self.mask_to_tensor(mask).long().squeeze(0)  # alpha == 255 -> 1, everything else -> 0
131 | 
132 |         return input, label
133 | 
134 | if __name__ == "__main__":
135 |     import torchvision
136 | 
137 |     dataset = Dataset(
138 |         input_size=256,
139 |         image_dir="./JPEGImages/",
140 |         label_dir="./RemoveBG/",
141 |     )
142 | 
143 |     images = []
144 |     for i in range(16):
145 |         image, mask = dataset[i]
146 |         images.append(image)
147 |     images = torch.stack(images)
148 | 
149 |     input_norm = images * 0.5 + 0.5  # [-1, 1] -> [0, 1]
150 |     torchvision.utils.save_image(
151 |         input_norm,
152 |         "./example.jpg",
153 |         normalize=True, range=(0, 1), nrow=4, padding=10, pad_value=0.5
154 |     )
--------------------------------------------------------------------------------
/human_inst_seg/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import os
4 | 
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | from torch.optim.lr_scheduler import StepLR, MultiStepLR
10 | import torchvision
11 | 
12 | from dataset import Dataset
13 | from unet import Segmentation
14 | 
15 | loss_Softmax = nn.CrossEntropyLoss(ignore_index=255)
16 | 
17 | os.makedirs("./data/snapshots/", exist_ok=True)
18 | os.makedirs("./data/visualize/train/", exist_ok=True)
19 | os.makedirs("./data/visualize/test/", exist_ok=True)
20 | 
21 | 
22 | def train(args, model, device, train_loader, optimizer, epoch):
23 |     model.train()
24 |     for batch_idx, (input, target) in enumerate(train_loader):
25 |         input, target = input.to(device), target.to(device)
26 |         optimizer.zero_grad()
27 |         output = model(input)
28 |         loss = loss_Softmax(output, target)
29 |         loss.backward()
30 |         optimizer.step()
31 |         if batch_idx % args.log_interval == 0:
32 |             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
33 |                 epoch, batch_idx * len(input), len(train_loader.dataset),
34 |                 100. 
* batch_idx / len(train_loader), loss.item()))
35 | 
36 |             input_norm = input[0:8] * 0.5 + 0.5  # [-1, 1] -> [0, 1]
37 |             output_norm = F.softmax(output, dim=1)[0:8, 1:2].repeat(1, 3, 1, 1).float()
38 |             target_norm = target[0:8].unsqueeze(1).repeat(1, 3, 1, 1).float()
39 | 
40 |             torchvision.utils.save_image(
41 |                 torch.cat([input_norm, output_norm, target_norm], dim=0),
42 |                 f"./data/visualize/train/latest.jpg",
43 |                 normalize=True, range=(0, 1), nrow=len(input_norm), padding=10, pad_value=0.5
44 |             )
45 | 
46 |             torch.save(
47 |                 model.state_dict(),
48 |                 f"./data/snapshots/latest.pt",
49 |             )
50 | 
51 | best = 0.0
52 | def test(args, model, device, test_loader):
53 |     model.eval()
54 |     correct = 0
55 |     iou = 0
56 |     with torch.no_grad():
57 |         for batch_idx, (data, target) in enumerate(test_loader):
58 |             data, target = data.to(device), target.to(device)
59 |             output, _, _ = model(data)
60 | 
61 |             input_norm = data * 0.5 + 0.5  # [-1, 1] -> [0, 1]
62 |             output_norm = output[:, 3:4, :, :].repeat(1, 3, 1, 1).float()
63 |             target_norm = target.unsqueeze(1).repeat(1, 3, 1, 1).float()
64 |             torchvision.utils.save_image(
65 |                 torch.cat([input_norm, output_norm, target_norm], dim=0),
66 |                 f"./data/visualize/test/latest_{batch_idx}.jpg",
67 |                 normalize=True, range=(0, 1), nrow=len(input_norm), padding=10, pad_value=0.5
68 |             )
69 | 
70 |             pred = output_norm > 0.5
71 |             gt = target_norm > 0.5
72 |             correct += pred.eq(gt).sum().item() / gt.numel()
73 |             iou += ((pred & gt).sum() + 1e-6) / ((pred | gt).sum() + 1e-6)
74 | 
75 |     # `correct` and `iou` accumulate one value per batch, so average over the number of batches
76 |     print('\nTest set: Accuracy: {:.2f}%, IOU: {:.4f}\n'.format(
77 |         100. * correct / len(test_loader),
78 |         iou / len(test_loader)
79 |     ))
80 | 
81 |     global best
82 |     if iou > best:
83 |         best = iou
84 |         torch.save(
85 |             model.state_dict(),
86 |             f"./data/snapshots/best-{iou * 100:.2f}.pt",
87 |         )
88 | 
89 | def main():
90 |     # Training settings
91 |     parser = argparse.ArgumentParser(description='Human instance segmentation training')
92 |     parser.add_argument('--batch-size', type=int, default=64, metavar='N',
93 |                         help='input batch size for training (default: 64)')
94 |     parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
95 |                         help='input batch size for testing (default: 1000)')
96 |     parser.add_argument('--epochs', type=int, default=10000, metavar='N',
97 |                         help='number of epochs to train (default: 10000)')
98 |     parser.add_argument('--lr', type=float, default=20.0, metavar='LR',
99 |                         help='learning rate (default: 20.0)')
100 |     parser.add_argument('--gamma', type=float, default=0.99, metavar='M',
101 |                         help='learning rate step gamma (default: 0.99)')
102 |     parser.add_argument('--no-cuda', action='store_true', default=False,
103 |                         help='disables CUDA training')
104 |     parser.add_argument('--seed', type=int, default=1, metavar='S',
105 |                         help='random seed (default: 1)')
106 |     parser.add_argument('--log-interval', type=int, default=10, metavar='N',
107 |                         help='how many batches to wait before logging training status')
108 |     parser.add_argument('--ckpt', type=str, default="",
109 |                         help='checkpoint file to resume from.')
110 | 
111 |     args = parser.parse_args()
112 |     use_cuda = not args.no_cuda and torch.cuda.is_available()
113 | 
114 |     torch.manual_seed(args.seed)
115 | 
116 |     device = torch.device("cuda" if use_cuda else "cpu")
117 | 
118 |     kwargs = {'num_workers': 20, 'pin_memory': True} if use_cuda else {}
119 |     train_loader = torch.utils.data.DataLoader(
120 |         Dataset(input_size=256, train=True,
121 |                 image_dir="./data/ATR_RemoveBG/JPEGImages/",
122 |                 label_dir="./data/ATR_RemoveBG/RemoveBG/"),
123 |         batch_size=args.batch_size, shuffle=True, **kwargs)
124 |     test_loader = torch.utils.data.DataLoader(
125 |         Dataset(input_size=256, train=False,
126 |                 image_dir="./data/alignment",
127 |                 label_dir="./data/alignment"),
128 |         batch_size=args.test_batch_size, shuffle=False, **kwargs)
129 | 
130 |     model = Segmentation().to(device)
131 |     model.train()
132 | 
133 |     if os.path.exists(args.ckpt):
134 |         model.load_state_dict(torch.load(args.ckpt))
135 |         print(f"load from snapshot: {args.ckpt}")
136 |     optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
137 | 
138 |     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
139 |     for epoch in range(1, args.epochs + 1):
140 |         print(f"lr: {scheduler.get_lr()[0]}")
141 |         test(args, model, device, test_loader)
142 |         train(args, model, device, train_loader, optimizer, epoch)
143 |         scheduler.step()
144 | 
145 | 
146 | if __name__ == '__main__':
147 |     main()
--------------------------------------------------------------------------------
/human_inst_seg/unet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision
7 | 
8 | # this can be installed via:
9 | # pip install git+https://github.com/liruilong940607/human_det --upgrade
10 | from human_det import Detection
11 | 
12 | # this can be installed via:
13 | # pip install git+https://github.com/qubvel/segmentation_models.pytorch --upgrade
14 | import segmentation_models_pytorch as smp
15 | 
16 | try:
17 |     from torch.hub import load_state_dict_from_url
18 | except ImportError:
19 |     from torch.utils.model_zoo import load_url as load_state_dict_from_url
20 | 
21 | def scale_boxes(boxes, scale):
22 |     """
23 |     Args:
24 |         boxes (tensor): A tensor of shape (B, 4) representing B boxes with 4
25 |             coords representing the corners x0, y0, x1, y1.
26 |         scale (float, float): The box scaling factor (w, h).
27 |     Returns:
28 |         Scaled boxes.
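29 | 
30 |     Example (illustrative; the values follow from the definition below):
31 |         >>> boxes = torch.tensor([[0., 0., 10., 20.]])
32 |         >>> scale_boxes(boxes, (2.0, 0.5))
33 |         tensor([[-5.,  5., 15., 15.]])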
34 |     """
35 |     w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5
36 |     h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5
37 |     x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5
38 |     y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5
39 | 
40 |     w_half *= scale[0]
41 |     h_half *= scale[1]
42 | 
43 |     scaled_boxes = torch.zeros_like(boxes)
44 |     scaled_boxes[:, 0] = x_c - w_half
45 |     scaled_boxes[:, 2] = x_c + w_half
46 |     scaled_boxes[:, 1] = y_c - h_half
47 |     scaled_boxes[:, 3] = y_c + h_half
48 |     return scaled_boxes
49 | 
50 | class Segmentation(nn.Module):
51 |     def __init__(self, ckpt=None, device="cuda:0", init=True, verbose=False):
52 |         super().__init__()
53 |         model = smp.Unet(
54 |             'resnet18',
55 |             encoder_weights='imagenet',
56 |             classes=2,
57 |             # activation='softmax'
58 |         ).to(device)
59 | 
60 |         if ckpt is not None and os.path.exists(ckpt):
61 |             print(f"load ckpt from: {ckpt}")
62 |             model.load_state_dict(torch.load(ckpt))
63 | 
64 |         self.device = device
65 |         self.model = model
66 |         self.verbose = verbose
67 | 
68 |         self.det_engine = Detection(device=device)
69 | 
70 |         if init:
71 |             self.init()
72 | 
73 |     def init(self, pretrained=""):
74 |         if os.path.exists(pretrained):
75 |             state_dict = torch.load(pretrained)
76 |         else:
77 |             state_dict = load_state_dict_from_url(
78 |                 "https://drive.google.com/uc?export=download&id=18d2yeCx62Gup-YzgsI866uxpEo9kIl2T")
79 |         self.load_state_dict(state_dict)
80 | 
81 |     def forward(self, input, scaled_boxes=None):
82 |         # input is B x 3 x H x W, normalized to [-1, 1]
83 |         Batch, _, H, W = input.size()
84 |         input = input.to(self.device)
85 | 
86 |         # det: estimate a person box, then expand it to a fixed 288:192 aspect ratio
87 |         if scaled_boxes is None:
88 |             with torch.no_grad():
89 |                 bboxes_det, probs_det = self.det_engine(input)
90 | 
91 |             probs = probs_det.unsqueeze(3)
92 |             bboxes = (bboxes_det * probs).sum(dim=1, keepdim=True) / probs.sum(dim=1, keepdim=True)
93 |             bboxes = bboxes[:, 0, 0, :]
94 | 
95 |             w_half = (bboxes[:, 2] - bboxes[:, 0]) * 0.5
96 |             h_half = (bboxes[:, 3] - bboxes[:, 1]) * 0.5
97 |             x_c = (bboxes[:, 2] + bboxes[:, 0]) * 0.5
98 |             y_c = (bboxes[:, 3] + bboxes[:, 1]) * 0.5
99 |             h_half *= 1.2 if not self.training else random.uniform(1.0, 1.5)
100 |             w_half = h_half / 288 * 192
101 |             scaled_boxes = torch.zeros_like(bboxes)
102 |             scaled_boxes[:, 0] = x_c - w_half
103 |             scaled_boxes[:, 2] = x_c + w_half
104 |             scaled_boxes[:, 1] = y_c - h_half
105 |             scaled_boxes[:, 3] = y_c + h_half
106 |             scaled_boxes = [box.unsqueeze(0) for box in scaled_boxes]
107 |         else:
108 |             bboxes_det, probs_det = None, None
109 | 
110 |         if self.verbose:
111 |             print(scaled_boxes)
112 | 
113 |         # seg
114 |         output = self.model(
115 |             torchvision.ops.roi_align(input, scaled_boxes, (288, 192)))
116 | 
117 |         x0_int, y0_int = 0, 0
118 |         x1_int, y1_int = W, H
119 |         scaled_boxes = torch.cat(scaled_boxes, dim=0)
120 |         x0, y0, x1, y1 = torch.split(scaled_boxes, 1, dim=1)  # each is Nx1
121 | 
122 |         img_y = torch.arange(y0_int, y1_int, device=self.device, dtype=torch.float32) + 0.5
123 |         img_x = torch.arange(x0_int, x1_int, device=self.device, dtype=torch.float32) + 0.5
124 |         img_y = (img_y - y0) / (y1 - y0) * 2 - 1
125 |         img_x = (img_x - x0) / (x1 - x0) * 2 - 1
126 |         # img_x, img_y have shapes (N, w), (N, h)
127 | 
128 |         gx = img_x[:, None, :].expand(Batch, img_y.size(1), img_x.size(1))
129 |         gy = img_y[:, :, None].expand(Batch, img_y.size(1), img_x.size(1))
130 |         grid = torch.stack([gx, gy], dim=3)
131 | 
132 |         # training path: return raw logits resampled to the full frame (see train.py)
133 |         if self.training:
134 |             output = F.grid_sample(output, grid, align_corners=False)
135 |             output = F.interpolate(output, size=(H, W), mode="bilinear")
136 |             return output
137 | 
138 |         else:
139 |             output = 
F.softmax(output, dim=1)[:, 1:2]
140 |             output = F.grid_sample(output, grid, align_corners=False)
141 |             output = F.interpolate(output, size=(H, W), mode="bilinear")
142 |             output = (output > 0.5).float()
143 |             output = torch.cat([input, output], dim=1)
144 |             return output, bboxes_det, probs_det
145 | 
146 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | INSTALL_REQUIREMENTS = ['numpy', 'torch', 'torchvision', 'Pillow', 'scikit-image', 'opencv-python', 'tqdm', 'imageio']
4 | 
5 | setuptools.setup(
6 |     name='human_inst_seg',
7 |     url='https://github.com/Project-Splinter/human_inst_seg',
8 |     description='A single-human instance segmenter that runs at 50 FPS on a GV100',
9 |     version='0.0.2',
10 |     author='Ruilong Li',
11 |     author_email='ruilongl@usc.edu',
12 |     license='MIT License',
13 |     packages=setuptools.find_packages(),
14 |     install_requires=INSTALL_REQUIREMENTS + [
15 |         'segmentation_models_pytorch@git+https://github.com/qubvel/segmentation_models.pytorch',
16 |         'human_det@git+https://github.com/Project-Splinter/human_det',
17 |     ]
18 | )
19 | 
20 | 
--------------------------------------------------------------------------------