├── cloths_segmentation ├── __init__.py ├── metrics.py ├── utils.py ├── pre_trained_models.py ├── dataloaders.py ├── configs │ ├── 2020-10-29.yaml │ ├── 2020-10-29a.yaml │ └── 2020-10-30.yaml ├── inference.py └── train.py ├── test.jpg ├── test.png ├── README.md └── rb.py
/cloths_segmentation/__init__.py:
--------------------------------------------------------------------------------
 1 | __version__ = "0.0.2" 2 |
--------------------------------------------------------------------------------
/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/normalclone/clothes_segmentation/HEAD/test.jpg
--------------------------------------------------------------------------------
/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/normalclone/clothes_segmentation/HEAD/test.png
--------------------------------------------------------------------------------
/cloths_segmentation/metrics.py:
--------------------------------------------------------------------------------
 1 | import torch 2 | 3 | EPSILON = 1e-15 4 | 5 | 6 | def binary_mean_iou(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 7 | output = (logits > 0).int() 8 | 9 | if output.shape != targets.shape: 10 | targets = torch.squeeze(targets, 1) 11 | 12 | intersection = (targets * output).sum() 13 | 14 | union = targets.sum() + output.sum() - intersection 15 | 16 | result = (intersection + EPSILON) / (union + EPSILON) 17 | 18 | return result 19 |
--------------------------------------------------------------------------------
/cloths_segmentation/utils.py:
--------------------------------------------------------------------------------
 1 | from pathlib import Path 2 | from typing import Union, Dict, List, Tuple 3 | 4 | 5 | def get_id2_file_paths(path: Union[str, Path]) -> Dict[str, Path]: 6 | return {x.stem: x for x in Path(path).glob("*.*")} 7 | 8 | 9 | def get_samples(image_path: Path, mask_path: Path) -> List[Tuple[Path, Path]]: 10 | """Couple masks and images. 11 | 12 | Args: 13 | image_path: path to the folder with images. 14 | mask_path: path to the folder with the corresponding masks. 15 | 16 | Returns: list of (image_path, mask_path) pairs matched by file stem. 17 | """ 18 | 19 | image2path = get_id2_file_paths(image_path) 20 | mask2path = get_id2_file_paths(mask_path) 21 | 22 | return [(image_file_path, mask2path[file_id]) for file_id, image_file_path in image2path.items()] 23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Clothes Segmentation 2 | This is my implementation of [this project](https://github.com/ternaus/cloths_segmentation)! 
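The snippet below is a minimal sketch of how the released pre-trained weights can be loaded and applied to a single photo. It mirrors `rb.py`; `test.jpg` is the sample image shipped with the repo, while the `mask.png` output name is just a placeholder:

```python
import albumentations as albu
import cv2
import numpy as np
import torch
from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image
from iglovikov_helper_functions.utils.image_utils import pad, unpad

from cloths_segmentation.pre_trained_models import create_model

model = create_model("Unet_2020-10-30")  # downloads the released checkpoint on first use
model.eval()

image = cv2.cvtColor(cv2.imread("test.jpg"), cv2.COLOR_BGR2RGB)

# The U-Net expects image sides divisible by 32: pad before inference, unpad the mask afterwards.
padded_image, pads = pad(image, factor=32, border=cv2.BORDER_CONSTANT)
x = albu.Compose([albu.Normalize(p=1)], p=1)(image=padded_image)["image"]

with torch.no_grad():
    prediction = model(torch.unsqueeze(tensor_from_rgb_image(x), 0))[0][0]

mask = unpad((prediction > 0).cpu().numpy().astype(np.uint8), pads)  # 1 = clothes, 0 = background
cv2.imwrite("mask.png", mask * 255)
```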
3 | 4 | ## Dependencies 5 | - python >= 3.6 6 | - [pytorch](https://pytorch.org/) >= 1.2 7 | - opencv 8 | - matplotlib 9 | - albumentations, iglovikov_helper_functions, pytorch_lightning, pytorch_toolbelt, segmentation-models-pytorch, tqdm, wandb 10 | 11 | ## Installation 12 | 1. Download & install the CUDA 10.2 toolkit [here](https://developer.nvidia.com/cuda-10.2-download-archive?target_os=Linux&target_arch=x86_64&target_distro=Ubuntu&target_version=1804&target_type=debnetwork) 13 | 2. Download & install Anaconda (Python 3.7 version) 14 | 3. Install the dependencies 15 | 4. Run `rb.py` 16 | 17 | ## An example 18 | 19 | 20 |
--------------------------------------------------------------------------------
/cloths_segmentation/pre_trained_models.py:
--------------------------------------------------------------------------------
 1 | from collections import namedtuple 2 | from torch import nn 3 | from torch.utils import model_zoo 4 | from iglovikov_helper_functions.dl.pytorch.utils import rename_layers 5 | 6 | from segmentation_models_pytorch import Unet 7 | 8 | model = namedtuple("model", ["url", "model"]) 9 | 10 | models = { 11 | "Unet_2020-10-30": model( 12 | url="https://github.com/ternaus/cloths_segmentation/releases/download/0.0.1/weights.zip", 13 | model=Unet(encoder_name="timm-efficientnet-b3", classes=1, encoder_weights=None), 14 | ) 15 | } 16 | 17 | 18 | def create_model(model_name: str) -> nn.Module: 19 | model = models[model_name].model 20 | state_dict = model_zoo.load_url(models[model_name].url, progress=True, map_location="cpu")["state_dict"] 21 | state_dict = rename_layers(state_dict, {"model.": ""}) 22 | model.load_state_dict(state_dict) 23 | return model 24 |
--------------------------------------------------------------------------------
/rb.py:
--------------------------------------------------------------------------------
 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import albumentations as albu 5 | from iglovikov_helper_functions.utils.image_utils import pad, unpad 6 | from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image 7 | 8 | from cloths_segmentation.pre_trained_models import create_model 9 | model = create_model("Unet_2020-10-30") 10 | model.eval() 11 | 12 | image = cv2.imread("test.jpg") 13 | image_2_extract = image 14 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 15 | transform = albu.Compose([albu.Normalize(p=1)], p=1) 16 | padded_image, pads = pad(image, factor=32, border=cv2.BORDER_CONSTANT)  # the network needs sides divisible by 32 17 | x = transform(image=padded_image)["image"] 18 | x = torch.unsqueeze(tensor_from_rgb_image(x), 0) 19 | 20 | with torch.no_grad(): 21 | prediction = model(x)[0][0] 22 | mask = (prediction > 0).cpu().numpy().astype(np.uint8)  # binarize the logits 23 | mask = unpad(mask, pads) 24 | rmask = (cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB) * 255).astype(np.uint8)  # the mask is single-channel 25 | mask2 = np.where((rmask < 255), 0, 1).astype('uint8') 26 | image_2_extract = image_2_extract * mask2[:, :, 1, np.newaxis]  # keep only the clothes pixels 27 | 28 | tmp = cv2.cvtColor(image_2_extract, cv2.COLOR_BGR2GRAY) 29 | _, alpha = cv2.threshold(tmp, 0, 255, cv2.THRESH_BINARY)  # alpha channel: opaque where clothes were kept 30 | b, g, r = cv2.split(image_2_extract) 31 | rgba = [b, g, r, alpha] 32 | dst = cv2.merge(rgba) 33 | cv2.imwrite("test.png", dst) 34 |
--------------------------------------------------------------------------------
/cloths_segmentation/dataloaders.py:
--------------------------------------------------------------------------------
 1 | from pathlib import Path 2 | from typing import List, Dict, Any, Tuple 3 | 4 | import 
albumentations as albu 5 | import numpy as np 6 | import torch 7 | from iglovikov_helper_functions.utils.image_utils import load_rgb, load_grayscale 8 | from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class SegmentationDataset(Dataset): 13 | def __init__( 14 | self, 15 | samples: List[Tuple[Path, Path]], 16 | transform: albu.Compose, 17 | length: int = None, 18 | ) -> None: 19 | self.samples = samples 20 | self.transform = transform 21 | 22 | if length is None: 23 | self.length = len(self.samples) 24 | else: 25 | self.length = length 26 | 27 | def __len__(self) -> int: 28 | return self.length 29 | 30 | def __getitem__(self, idx: int) -> Dict[str, Any]: 31 | idx = idx % len(self.samples) 32 | 33 | image_path, mask_path = self.samples[idx] 34 | 35 | image = load_rgb(image_path, lib="cv2") 36 | mask = load_grayscale(mask_path) 37 | 38 | # apply augmentations 39 | sample = self.transform(image=image, mask=mask) 40 | image, mask = sample["image"], sample["mask"] 41 | 42 | mask = (mask > 0).astype(np.uint8) 43 | 44 | mask = torch.from_numpy(mask) 45 | 46 | return { 47 | "image_id": image_path.stem, 48 | "features": tensor_from_rgb_image(image), 49 | "masks": torch.unsqueeze(mask, 0).float(), 50 | } 51 | -------------------------------------------------------------------------------- /cloths_segmentation/configs/2020-10-29.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | seed: 1984 3 | 4 | num_workers: 4 5 | experiment_name: "2020-10-29" 6 | 7 | val_split: 0.2 8 | 9 | model: 10 | type: segmentation_models_pytorch.Unet 11 | encoder_name: timm-efficientnet-b3 12 | classes: 1 13 | encoder_weights: noisy-student 14 | 15 | trainer: 16 | type: pytorch_lightning.Trainer 17 | gpus: 4 18 | max_epochs: 30 19 | distributed_backend: ddp 20 | progress_bar_refresh_rate: 1 21 | benchmark: True 22 | precision: 16 23 | gradient_clip_val: 5.0 24 | num_sanity_val_steps: 2 25 | sync_batchnorm: True 26 | 27 | 28 | scheduler: 29 | type: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts 30 | T_0: 10 31 | T_mult: 2 32 | 33 | train_parameters: 34 | batch_size: 8 35 | 36 | checkpoint_callback: 37 | type: pytorch_lightning.callbacks.ModelCheckpoint 38 | filepath: "2020-10-29" 39 | monitor: val_iou 40 | verbose: True 41 | mode: max 42 | save_top_k: -1 43 | 44 | val_parameters: 45 | batch_size: 2 46 | 47 | optimizer: 48 | type: adamp.AdamP 49 | lr: 0.0001 50 | 51 | 52 | train_aug: 53 | transform: 54 | __class_fullname__: albumentations.core.composition.Compose 55 | bbox_params: null 56 | keypoint_params: null 57 | p: 1 58 | transforms: 59 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize 60 | always_apply: False 61 | max_size: 800 62 | p: 1 63 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded 64 | always_apply: False 65 | min_height: 800 66 | min_width: 800 67 | border_mode: 0 # cv2.BORDER_CONSTANT 68 | value: 0 69 | mask_value: 0 70 | p: 1 71 | - __class_fullname__: albumentations.augmentations.transforms.RandomCrop 72 | always_apply: False 73 | height: 512 74 | width: 512 75 | p: 1 76 | - __class_fullname__: albumentations.augmentations.transforms.HorizontalFlip 77 | always_apply: False 78 | p: 0.5 79 | - __class_fullname__: albumentations.augmentations.transforms.Normalize 80 | always_apply: false 81 | max_pixel_value: 255.0 82 | mean: 83 | - 0.485 84 | - 0.456 85 | - 0.406 86 | p: 1 87 | std: 88 | - 0.229 89 | - 0.224 90 | - 0.225 91 | 92 
| val_aug: 93 | transform: 94 | __class_fullname__: albumentations.core.composition.Compose 95 | bbox_params: null 96 | keypoint_params: null 97 | p: 1 98 | transforms: 99 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize 100 | always_apply: False 101 | max_size: 800 102 | p: 1 103 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded 104 | always_apply: False 105 | min_height: 800 106 | min_width: 800 107 | border_mode: 0 # cv2.BORDER_CONSTANT 108 | value: 0 109 | mask_value: 0 110 | p: 1 111 | - __class_fullname__: albumentations.augmentations.transforms.Normalize 112 | always_apply: false 113 | max_pixel_value: 255.0 114 | mean: 115 | - 0.485 116 | - 0.456 117 | - 0.406 118 | p: 1 119 | std: 120 | - 0.229 121 | - 0.224 122 | - 0.225 123 | -------------------------------------------------------------------------------- /cloths_segmentation/configs/2020-10-29a.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | seed: 1984 3 | 4 | num_workers: 4 5 | experiment_name: "2020-10-29a" 6 | 7 | val_split: 0.1 8 | 9 | resume_from_checkpoint: 2020-10-29/epoch=4.ckpt 10 | 11 | model: 12 | type: segmentation_models_pytorch.Unet 13 | encoder_name: timm-efficientnet-b3 14 | classes: 1 15 | encoder_weights: noisy-student 16 | 17 | trainer: 18 | type: pytorch_lightning.Trainer 19 | gpus: 4 20 | max_epochs: 30 21 | distributed_backend: ddp 22 | progress_bar_refresh_rate: 1 23 | benchmark: True 24 | precision: 16 25 | gradient_clip_val: 5.0 26 | num_sanity_val_steps: 2 27 | sync_batchnorm: True 28 | 29 | 30 | scheduler: 31 | type: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts 32 | T_0: 10 33 | T_mult: 2 34 | 35 | train_parameters: 36 | batch_size: 8 37 | 38 | checkpoint_callback: 39 | type: pytorch_lightning.callbacks.ModelCheckpoint 40 | filepath: "2020-10-29a" 41 | monitor: val_iou 42 | verbose: True 43 | mode: max 44 | save_top_k: -1 45 | 46 | val_parameters: 47 | batch_size: 2 48 | 49 | optimizer: 50 | type: adamp.AdamP 51 | lr: 0.0001 52 | 53 | 54 | train_aug: 55 | transform: 56 | __class_fullname__: albumentations.core.composition.Compose 57 | bbox_params: null 58 | keypoint_params: null 59 | p: 1 60 | transforms: 61 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize 62 | always_apply: False 63 | max_size: 800 64 | p: 1 65 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded 66 | always_apply: False 67 | min_height: 800 68 | min_width: 800 69 | border_mode: 0 # cv2.BORDER_CONSTANT 70 | value: 0 71 | mask_value: 0 72 | p: 1 73 | - __class_fullname__: albumentations.augmentations.transforms.RandomCrop 74 | always_apply: False 75 | height: 512 76 | width: 512 77 | p: 1 78 | - __class_fullname__: albumentations.augmentations.transforms.HorizontalFlip 79 | always_apply: False 80 | p: 0.5 81 | - __class_fullname__: albumentations.augmentations.transforms.Normalize 82 | always_apply: false 83 | max_pixel_value: 255.0 84 | mean: 85 | - 0.485 86 | - 0.456 87 | - 0.406 88 | p: 1 89 | std: 90 | - 0.229 91 | - 0.224 92 | - 0.225 93 | 94 | val_aug: 95 | transform: 96 | __class_fullname__: albumentations.core.composition.Compose 97 | bbox_params: null 98 | keypoint_params: null 99 | p: 1 100 | transforms: 101 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize 102 | always_apply: False 103 | max_size: 800 104 | p: 1 105 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded 106 | always_apply: False 107 | 
min_height: 800 108 | min_width: 800 109 | border_mode: 0 # cv2.BORDER_CONSTANT 110 | value: 0 111 | mask_value: 0 112 | p: 1 113 | - __class_fullname__: albumentations.augmentations.transforms.Normalize 114 | always_apply: false 115 | max_pixel_value: 255.0 116 | mean: 117 | - 0.485 118 | - 0.456 119 | - 0.406 120 | p: 1 121 | std: 122 | - 0.229 123 | - 0.224 124 | - 0.225 125 | -------------------------------------------------------------------------------- /cloths_segmentation/configs/2020-10-30.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | seed: 1984 3 | 4 | num_workers: 4 5 | experiment_name: "2020-10-30" 6 | 7 | val_split: 0.1 8 | 9 | model: 10 | type: segmentation_models_pytorch.Unet 11 | encoder_name: timm-efficientnet-b3 12 | classes: 1 13 | encoder_weights: noisy-student 14 | 15 | trainer: 16 | type: pytorch_lightning.Trainer 17 | gpus: 4 18 | max_epochs: 70 19 | distributed_backend: ddp 20 | progress_bar_refresh_rate: 1 21 | benchmark: True 22 | precision: 16 23 | gradient_clip_val: 5.0 24 | num_sanity_val_steps: 2 25 | sync_batchnorm: True 26 | # resume_from_checkpoint: 2020-10-30/epoch=67.ckpt 27 | 28 | 29 | scheduler: 30 | type: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts 31 | T_0: 10 32 | T_mult: 2 33 | 34 | train_parameters: 35 | batch_size: 8 36 | 37 | checkpoint_callback: 38 | type: pytorch_lightning.callbacks.ModelCheckpoint 39 | filepath: "2020-10-30" 40 | monitor: val_iou 41 | verbose: True 42 | mode: max 43 | save_top_k: -1 44 | 45 | val_parameters: 46 | batch_size: 2 47 | 48 | optimizer: 49 | type: adamp.AdamP 50 | lr: 0.0001 51 | 52 | 53 | train_aug: 54 | transform: 55 | __class_fullname__: albumentations.core.composition.Compose 56 | bbox_params: null 57 | keypoint_params: null 58 | p: 1 59 | transforms: 60 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize 61 | always_apply: False 62 | max_size: 800 63 | p: 1 64 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded 65 | always_apply: False 66 | min_height: 800 67 | min_width: 800 68 | border_mode: 0 # cv2.BORDER_CONSTANT 69 | value: 0 70 | mask_value: 0 71 | p: 1 72 | - __class_fullname__: albumentations.augmentations.transforms.RandomCrop 73 | always_apply: False 74 | height: 512 75 | width: 512 76 | p: 1 77 | - __class_fullname__: albumentations.augmentations.transforms.HorizontalFlip 78 | always_apply: False 79 | p: 0.5 80 | - __class_fullname__: albumentations.augmentations.transforms.Normalize 81 | always_apply: false 82 | max_pixel_value: 255.0 83 | mean: 84 | - 0.485 85 | - 0.456 86 | - 0.406 87 | p: 1 88 | std: 89 | - 0.229 90 | - 0.224 91 | - 0.225 92 | 93 | val_aug: 94 | transform: 95 | __class_fullname__: albumentations.core.composition.Compose 96 | bbox_params: null 97 | keypoint_params: null 98 | p: 1 99 | transforms: 100 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize 101 | always_apply: False 102 | max_size: 800 103 | p: 1 104 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded 105 | always_apply: False 106 | min_height: 800 107 | min_width: 800 108 | border_mode: 0 # cv2.BORDER_CONSTANT 109 | value: 0 110 | mask_value: 0 111 | p: 1 112 | - __class_fullname__: albumentations.augmentations.transforms.Normalize 113 | always_apply: false 114 | max_pixel_value: 255.0 115 | mean: 116 | - 0.485 117 | - 0.456 118 | - 0.406 119 | p: 1 120 | std: 121 | - 0.229 122 | - 0.224 123 | - 0.225 124 | 125 | test_aug: 126 | transform: 127 | 
__class_fullname__: albumentations.core.composition.Compose 128 | bbox_params: null 129 | keypoint_params: null 130 | p: 1 131 | transforms: 132 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize 133 | always_apply: False 134 | max_size: 800 135 | p: 1 136 | - __class_fullname__: albumentations.augmentations.transforms.Normalize 137 | always_apply: false 138 | max_pixel_value: 255.0 139 | mean: 140 | - 0.485 141 | - 0.456 142 | - 0.406 143 | p: 1 144 | std: 145 | - 0.229 146 | - 0.224 147 | - 0.225 148 | -------------------------------------------------------------------------------- /cloths_segmentation/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from typing import Dict, List, Optional, Any 4 | 5 | import albumentations as albu 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torch.nn.parallel 10 | import torch.utils.data 11 | import torch.utils.data.distributed 12 | import yaml 13 | from albumentations.core.serialization import from_dict 14 | from iglovikov_helper_functions.config_parsing.utils import object_from_dict 15 | from iglovikov_helper_functions.dl.pytorch.utils import state_dict_from_disk, tensor_from_rgb_image 16 | from iglovikov_helper_functions.utils.image_utils import load_rgb, pad_to_size, unpad_from_size 17 | from torch.utils.data import Dataset 18 | from torch.utils.data.distributed import DistributedSampler 19 | from tqdm import tqdm 20 | 21 | 22 | def get_args(): 23 | parser = argparse.ArgumentParser() 24 | arg = parser.add_argument 25 | arg("-i", "--input_path", type=Path, help="Path with images.", required=True) 26 | arg("-c", "--config_path", type=Path, help="Path to config.", required=True) 27 | arg("-o", "--output_path", type=Path, help="Path to save masks.", required=True) 28 | arg("-b", "--batch_size", type=int, help="batch_size", default=1) 29 | arg("-j", "--num_workers", type=int, help="num_workers", default=12) 30 | arg("-w", "--weight_path", type=str, help="Path to weights.", required=True) 31 | arg("--world_size", default=-1, type=int, help="number of nodes for distributed training") 32 | arg("--local_rank", default=-1, type=int, help="node rank for distributed training") 33 | arg("--fp16", action="store_true", help="Use fp6") 34 | return parser.parse_args() 35 | 36 | 37 | class InferenceDataset(Dataset): 38 | def __init__(self, file_paths: List[Path], transform: albu.Compose) -> None: 39 | self.file_paths = file_paths 40 | self.transform = transform 41 | 42 | def __len__(self) -> int: 43 | return len(self.file_paths) 44 | 45 | def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]: 46 | image_path = self.file_paths[idx] 47 | 48 | image = load_rgb(image_path) 49 | height, width = image.shape[:2] 50 | 51 | image = self.transform(image=image)["image"] 52 | pad_dict = pad_to_size((max(image.shape[:2]), max(image.shape[:2])), image) 53 | 54 | return { 55 | "torched_image": tensor_from_rgb_image(pad_dict["image"]), 56 | "image_path": str(image_path), 57 | "pads": pad_dict["pads"], 58 | "original_width": width, 59 | "original_height": height, 60 | } 61 | 62 | 63 | def main(): 64 | args = get_args() 65 | torch.distributed.init_process_group(backend="nccl") 66 | 67 | with open(args.config_path) as f: 68 | hparams = yaml.load(f, Loader=yaml.SafeLoader) 69 | 70 | hparams.update( 71 | { 72 | "local_rank": args.local_rank, 73 | "fp16": args.fp16, 74 | } 75 | ) 76 | 77 | output_mask_path = args.output_path 78 | 
output_mask_path.mkdir(parents=True, exist_ok=True) 79 | hparams["output_mask_path"] = output_mask_path 80 | 81 | device = torch.device("cuda", args.local_rank) 82 | 83 | model = object_from_dict(hparams["model"]) 84 | model = model.to(device) 85 | 86 | if args.fp16: 87 | model = model.half() 88 | 89 | corrections: Dict[str, str] = {"model.": ""} 90 | state_dict = state_dict_from_disk(file_path=args.weight_path, rename_in_layers=corrections) 91 | model.load_state_dict(state_dict) 92 | 93 | model = torch.nn.parallel.DistributedDataParallel( 94 | model, device_ids=[args.local_rank], output_device=args.local_rank 95 | ) 96 | 97 | file_paths = [] 98 | 99 | for regexp in ["*.jpg", "*.png", "*.jpeg", "*.JPG"]: 100 | file_paths += sorted([x for x in tqdm(args.input_path.rglob(regexp))]) 101 | 102 | # Filter file paths for which we already have predictions 103 | file_paths = [x for x in file_paths if not (args.output_path / x.parent.name / f"{x.stem}.png").exists()] 104 | 105 | dataset = InferenceDataset(file_paths, transform=from_dict(hparams["test_aug"])) 106 | 107 | sampler = DistributedSampler(dataset, shuffle=False) 108 | 109 | dataloader = torch.utils.data.DataLoader( 110 | dataset, 111 | batch_size=args.batch_size, 112 | num_workers=args.num_workers, 113 | pin_memory=True, 114 | shuffle=False, 115 | drop_last=False, 116 | sampler=sampler, 117 | ) 118 | 119 | predict(dataloader, model, hparams, device) 120 | 121 | 122 | def predict(dataloader, model, hparams, device): 123 | model.eval() 124 | 125 | if hparams["local_rank"] == 0: 126 | loader = tqdm(dataloader) 127 | else: 128 | loader = dataloader 129 | 130 | with torch.no_grad(): 131 | for batch in loader: 132 | torched_images = batch["torched_image"] # images that are rescaled and padded 133 | 134 | if hparams["fp16"]: 135 | torched_images = torched_images.half() 136 | 137 | image_paths = batch["image_path"] 138 | pads = batch["pads"] 139 | heights = batch["original_height"] 140 | widths = batch["original_width"] 141 | 142 | batch_size = torched_images.shape[0] 143 | 144 | predictions = model(torched_images.to(device)) 145 | 146 | for batch_id in range(batch_size): 147 | file_id = Path(image_paths[batch_id]).stem 148 | folder_name = Path(image_paths[batch_id]).parent.name 149 | 150 | mask = (predictions[batch_id][0].cpu().numpy() > 0).astype(np.uint8) * 255 151 | mask = unpad_from_size(pads, image=mask)["image"] 152 | mask = cv2.resize( 153 | mask, (widths[batch_id].item(), heights[batch_id].item()), interpolation=cv2.INTER_NEAREST 154 | ) 155 | 156 | (hparams["output_mask_path"] / folder_name).mkdir(exist_ok=True, parents=True) 157 | cv2.imwrite(str(hparams["output_mask_path"] / folder_name / f"{file_id}.png"), mask) 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /cloths_segmentation/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | from typing import Dict 5 | 6 | import pytorch_lightning as pl 7 | import torch 8 | import yaml 9 | from albumentations.core.serialization import from_dict 10 | from iglovikov_helper_functions.config_parsing.utils import object_from_dict 11 | from iglovikov_helper_functions.dl.pytorch.lightning import find_average 12 | from iglovikov_helper_functions.dl.pytorch.utils import state_dict_from_disk 13 | from pytorch_lightning.loggers import WandbLogger 14 | from pytorch_toolbelt.losses import JaccardLoss, 
BinaryFocalLoss 15 | from torch.utils.data import DataLoader 16 | 17 | from cloths_segmentation.dataloaders import SegmentationDataset 18 | from cloths_segmentation.metrics import binary_mean_iou 19 | from cloths_segmentation.utils import get_samples 20 | 21 | image_path = Path(os.environ["IMAGE_PATH"]) 22 | mask_path = Path(os.environ["MASK_PATH"]) 23 | 24 | 25 | def get_args(): 26 | parser = argparse.ArgumentParser() 27 | arg = parser.add_argument 28 | arg("-c", "--config_path", type=Path, help="Path to the config.", required=True) 29 | return parser.parse_args() 30 | 31 | 32 | class SegmentPeople(pl.LightningModule): 33 | def __init__(self, hparams): 34 | super().__init__() 35 | self.hparams = hparams 36 | 37 | self.model = object_from_dict(self.hparams["model"]) 38 | if "resume_from_checkpoint" in self.hparams: 39 | corrections: Dict[str, str] = {"model.": ""} 40 | 41 | state_dict = state_dict_from_disk( 42 | file_path=self.hparams["resume_from_checkpoint"], 43 | rename_in_layers=corrections, 44 | ) 45 | self.model.load_state_dict(state_dict) 46 | 47 | self.losses = [ 48 | ("jaccard", 0.1, JaccardLoss(mode="binary", from_logits=True)), 49 | ("focal", 0.9, BinaryFocalLoss()), 50 | ] 51 | 52 | def forward(self, batch: torch.Tensor) -> torch.Tensor: # type: ignore 53 | return self.model(batch) 54 | 55 | def setup(self, stage=0): 56 | samples = get_samples(image_path, mask_path) 57 | 58 | num_train = int((1 - self.hparams["val_split"]) * len(samples)) 59 | 60 | self.train_samples = samples[:num_train] 61 | self.val_samples = samples[num_train:] 62 | 63 | print("Len train samples = ", len(self.train_samples)) 64 | print("Len val samples = ", len(self.val_samples)) 65 | 66 | def train_dataloader(self): 67 | train_aug = from_dict(self.hparams["train_aug"]) 68 | 69 | if "epoch_length" not in self.hparams["train_parameters"]: 70 | epoch_length = None 71 | else: 72 | epoch_length = self.hparams["train_parameters"]["epoch_length"] 73 | 74 | result = DataLoader( 75 | SegmentationDataset(self.train_samples, train_aug, epoch_length), 76 | batch_size=self.hparams["train_parameters"]["batch_size"], 77 | num_workers=self.hparams["num_workers"], 78 | shuffle=True, 79 | pin_memory=True, 80 | drop_last=True, 81 | ) 82 | 83 | print("Train dataloader = ", len(result)) 84 | return result 85 | 86 | def val_dataloader(self): 87 | val_aug = from_dict(self.hparams["val_aug"]) 88 | 89 | result = DataLoader( 90 | SegmentationDataset(self.val_samples, val_aug, length=None), 91 | batch_size=self.hparams["val_parameters"]["batch_size"], 92 | num_workers=self.hparams["num_workers"], 93 | shuffle=False, 94 | pin_memory=True, 95 | drop_last=False, 96 | ) 97 | 98 | print("Val dataloader = ", len(result)) 99 | 100 | return result 101 | 102 | def configure_optimizers(self): 103 | optimizer = object_from_dict( 104 | self.hparams["optimizer"], 105 | params=[x for x in self.model.parameters() if x.requires_grad], 106 | ) 107 | 108 | scheduler = object_from_dict(self.hparams["scheduler"], optimizer=optimizer) 109 | self.optimizers = [optimizer] 110 | 111 | return self.optimizers, [scheduler] 112 | 113 | def training_step(self, batch, batch_idx): 114 | features = batch["features"] 115 | masks = batch["masks"] 116 | 117 | logits = self.forward(features) 118 | 119 | total_loss = 0 120 | logs = {} 121 | for loss_name, weight, loss in self.losses: 122 | ls_mask = loss(logits, masks) 123 | total_loss += weight * ls_mask 124 | logs[f"train_mask_{loss_name}"] = ls_mask 125 | 126 | logs["train_loss"] = total_loss 127 | 128 | logs["lr"] 
= self._get_current_lr() 129 | 130 | return {"loss": total_loss, "log": logs} 131 | 132 | def _get_current_lr(self) -> torch.Tensor: 133 | lr = [x["lr"] for x in self.optimizers[0].param_groups][0] # type: ignore 134 | return torch.Tensor([lr])[0].cuda() 135 | 136 | def validation_step(self, batch, batch_id): 137 | features = batch["features"] 138 | masks = batch["masks"] 139 | 140 | logits = self.forward(features) 141 | 142 | result = {} 143 | for loss_name, _, loss in self.losses: 144 | result[f"val_mask_{loss_name}"] = loss(logits, masks) 145 | 146 | result["val_iou"] = binary_mean_iou(logits, masks) 147 | 148 | return result 149 | 150 | def validation_epoch_end(self, outputs): 151 | logs = {"epoch": self.trainer.current_epoch} 152 | 153 | avg_val_iou = find_average(outputs, "val_iou") 154 | 155 | logs["val_iou"] = avg_val_iou 156 | 157 | return {"val_iou": avg_val_iou, "log": logs} 158 | 159 | 160 | def main(): 161 | args = get_args() 162 | 163 | with open(args.config_path) as f: 164 | hparams = yaml.load(f, Loader=yaml.SafeLoader) 165 | 166 | pipeline = SegmentPeople(hparams) 167 | 168 | Path(hparams["checkpoint_callback"]["filepath"]).mkdir(exist_ok=True, parents=True) 169 | 170 | trainer = object_from_dict( 171 | hparams["trainer"], 172 | logger=WandbLogger(hparams["experiment_name"]), 173 | checkpoint_callback=object_from_dict(hparams["checkpoint_callback"]), 174 | ) 175 | 176 | trainer.fit(pipeline) 177 | 178 | 179 | if __name__ == "__main__": 180 | main() 181 | --------------------------------------------------------------------------------
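A closing note on `train.py`: the total training loss is a fixed blend of 0.1 × Jaccard (soft IoU) loss and 0.9 × binary focal loss, both applied to raw logits. The sketch below reproduces that combination outside Lightning; the random tensors and their shapes are only stand-ins for a real batch:

```python
import torch
from pytorch_toolbelt.losses import BinaryFocalLoss, JaccardLoss

# Same (name, weight, loss) triples as in SegmentPeople.__init__.
losses = [
    ("jaccard", 0.1, JaccardLoss(mode="binary", from_logits=True)),
    ("focal", 0.9, BinaryFocalLoss()),
]

logits = torch.randn(4, 1, 64, 64)                   # raw model outputs for a batch of 4
masks = torch.randint(0, 2, (4, 1, 64, 64)).float()  # binary ground-truth masks

total_loss = sum(weight * loss(logits, masks) for _, weight, loss in losses)
print(float(total_loss))
```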