├── cloths_segmentation
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── pre_trained_models.cpython-37.pyc
│   ├── metrics.py
│   ├── utils.py
│   ├── pre_trained_models.py
│   ├── dataloaders.py
│   ├── configs
│   │   ├── 2020-10-29.yaml
│   │   ├── 2020-10-29a.yaml
│   │   └── 2020-10-30.yaml
│   ├── inference.py
│   └── train.py
├── test.jpg
├── test.png
├── README.md
└── rb.py
/cloths_segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.2"
2 |
--------------------------------------------------------------------------------
/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/normalclone/clothes_segmentation/HEAD/test.jpg
--------------------------------------------------------------------------------
/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/normalclone/clothes_segmentation/HEAD/test.png
--------------------------------------------------------------------------------
/cloths_segmentation/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/normalclone/clothes_segmentation/HEAD/cloths_segmentation/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/cloths_segmentation/__pycache__/pre_trained_models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/normalclone/clothes_segmentation/HEAD/cloths_segmentation/__pycache__/pre_trained_models.cpython-37.pyc
--------------------------------------------------------------------------------
/cloths_segmentation/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | EPSILON = 1e-15
4 |
5 |
6 | def binary_mean_iou(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
7 | output = (logits > 0).int()
8 |
9 | if output.shape != targets.shape:
10 | targets = torch.squeeze(targets, 1)
11 |
12 | intersection = (targets * output).sum()
13 |
14 | union = targets.sum() + output.sum() - intersection
15 |
16 | result = (intersection + EPSILON) / (union + EPSILON)
17 |
18 | return result
19 |
--------------------------------------------------------------------------------
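
Note (not a repository file): a quick sanity check of `binary_mean_iou` on toy tensors. Predictions are obtained by thresholding the logits at zero, so the example below has intersection 2 and union 3.

--------------------------------------------------------------------------------
import torch

from cloths_segmentation.metrics import binary_mean_iou

# Logits > 0 count as positive predictions -> predicted mask is [1, 0, 1, 0].
logits = torch.tensor([[2.0, -1.0, 3.0, -0.5]])
targets = torch.tensor([[1, 0, 1, 1]])

# intersection = 2, union = 3 + 2 - 2 = 3, so the IoU is approximately 0.6667.
print(binary_mean_iou(logits, targets))
--------------------------------------------------------------------------------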
/cloths_segmentation/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Union, Dict, List, Tuple
3 |
4 |
5 | def get_id2_file_paths(path: Union[str, Path]) -> Dict[str, Path]:
6 | return {x.stem: x for x in Path(path).glob("*.*")}
7 |
8 |
9 | def get_samples(image_path: Path, mask_path: Path) -> List[Tuple[Path, Path]]:
10 | """Couple masks and images.
11 |
12 | Args:
13 | image_path: Folder with images.
14 | mask_path: Folder with the corresponding masks.
15 |
16 | Returns: List of (image, mask) path pairs matched by file stem.
17 | """
18 |
19 | image2path = get_id2_file_paths(image_path)
20 | mask2path = get_id2_file_paths(mask_path)
21 |
22 | return [(image_file_path, mask2path[file_id]) for file_id, image_file_path in image2path.items()]
23 |
--------------------------------------------------------------------------------
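
Note (not a repository file): a minimal sketch of how `get_samples` pairs files. Paths are hypothetical; each mask must share its image's file stem (e.g. `images/0001.jpg` pairs with `masks/0001.png`).

--------------------------------------------------------------------------------
from pathlib import Path

from cloths_segmentation.utils import get_samples

# Hypothetical layout: data/images/0001.jpg, data/masks/0001.png, ...
samples = get_samples(Path("data/images"), Path("data/masks"))
# -> [(Path("data/images/0001.jpg"), Path("data/masks/0001.png")), ...]
print(len(samples))
--------------------------------------------------------------------------------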
/README.md:
--------------------------------------------------------------------------------
1 | # Clothes Segmentation
2 | This is my implementation of [this project](https://github.com/ternaus/cloths_segmentation)!
3 |
4 | ## Dependencies
5 | - python >= 3.6
6 | - [pytorch](https://pytorch.org/) >= 1.2
7 | - opencv
8 | - matplotlib
9 | - albumentations, iglovikov_helper_functions, pytorch_lightning, pytorch_toolbelt, segmentation-models-pytorch, tqdm, wandb
10 |
11 | ## Installation
12 | 1. Download & install the CUDA 10.2 toolkit [here](https://developer.nvidia.com/cuda-10.2-download-archive?target_os=Linux&target_arch=x86_64&target_distro=Ubuntu&target_version=1804&target_type=debnetwork)
13 | 2. Download & install Anaconda (Python 3.7 version)
14 | 3. Install the dependencies listed above
15 | 4. Run `rb.py`
16 |
17 | ## An example
18 |
19 |
20 |
--------------------------------------------------------------------------------
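
Note (not part of the original README): a minimal quick-start sketch along the lines of `rb.py`, assuming the `Unet_2020-10-30` weights can be downloaded and `test.jpg` is in the working directory.

--------------------------------------------------------------------------------
import albumentations as albu
import cv2
import numpy as np
import torch
from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image
from iglovikov_helper_functions.utils.image_utils import pad, unpad

from cloths_segmentation.pre_trained_models import create_model

model = create_model("Unet_2020-10-30")
model.eval()

# Load the image, pad it to a multiple of 32 and normalize with ImageNet stats.
image = cv2.cvtColor(cv2.imread("test.jpg"), cv2.COLOR_BGR2RGB)
padded, pads = pad(image, factor=32, border=cv2.BORDER_CONSTANT)
x = albu.Compose([albu.Normalize(p=1)], p=1)(image=padded)["image"]

with torch.no_grad():
    prediction = model(torch.unsqueeze(tensor_from_rgb_image(x), 0))[0][0]

# Threshold the logits, crop the padding away and save a binary mask.
mask = unpad((prediction > 0).cpu().numpy().astype(np.uint8), pads) * 255
cv2.imwrite("mask.png", mask)
--------------------------------------------------------------------------------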
/cloths_segmentation/pre_trained_models.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from torch import nn
3 | from torch.utils import model_zoo
4 | from iglovikov_helper_functions.dl.pytorch.utils import rename_layers
5 |
6 | from segmentation_models_pytorch import Unet
7 |
8 | model = namedtuple("model", ["url", "model"])
9 |
10 | models = {
11 | "Unet_2020-10-30": model(
12 | url="https://github.com/ternaus/cloths_segmentation/releases/download/0.0.1/weights.zip",
13 | model=Unet(encoder_name="timm-efficientnet-b3", classes=1, encoder_weights=None),
14 | )
15 | }
16 |
17 |
18 | def create_model(model_name: str) -> nn.Module:
19 | model = models[model_name].model
20 | state_dict = model_zoo.load_url(models[model_name].url, progress=True, map_location="cpu")["state_dict"]
21 | state_dict = rename_layers(state_dict, {"model.": ""})
22 | model.load_state_dict(state_dict)
23 | return model
24 |
--------------------------------------------------------------------------------
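
Note (not a repository file): a short sketch of `create_model`. The call downloads the released weights on first use; input height and width must be multiples of 32 for the U-Net decoder.

--------------------------------------------------------------------------------
import torch

from cloths_segmentation.pre_trained_models import create_model

model = create_model("Unet_2020-10-30")
model.eval()

# Dummy RGB batch: 1 image, 3 channels, 256x256 (divisible by 32).
with torch.no_grad():
    logits = model(torch.randn(1, 3, 256, 256))

print(logits.shape)  # torch.Size([1, 1, 256, 256]) -- a single binary-mask channel
--------------------------------------------------------------------------------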
/rb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import torch
4 | import albumentations as albu
5 | from iglovikov_helper_functions.utils.image_utils import load_rgb, pad, unpad
6 | from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image
7 |
8 | from cloths_segmentation.pre_trained_models import create_model
9 | model = create_model("Unet_2020-10-30")
10 | model.eval()
11 |
12 | image = cv2.imread("test.jpg")
13 | image_2_extract = image
14 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
15 | transform = albu.Compose([albu.Normalize(p=1)], p=1)
16 | padded_image, pads = pad(image, factor=32, border=cv2.BORDER_CONSTANT)
17 | x = transform(image=padded_image)["image"]
18 | x = torch.unsqueeze(tensor_from_rgb_image(x), 0)
19 |
20 | with torch.no_grad():
21 | prediction = model(x)[0][0]
22 | mask = (prediction > 0).cpu().numpy().astype(np.uint8)
23 | mask = unpad(mask, pads)
24 | rmask = (cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB) * 255).astype(np.uint8)  # mask is single-channel at this point
25 | mask2 = np.where((rmask < 255), 0, 1).astype('uint8')
26 | image_2_extract = image_2_extract * mask2[:, :, 1, np.newaxis]
27 |
28 | tmp = cv2.cvtColor(image_2_extract, cv2.COLOR_BGR2GRAY)
29 | _, alpha = cv2.threshold(tmp, 0, 255, cv2.THRESH_BINARY)
30 | b, g, r = cv2.split(image_2_extract)
31 | rgba = [b, g, r, alpha]
32 | dst = cv2.merge(rgba, 4)
33 | cv2.imwrite("test.png", dst)
34 | cv2.waitKey(0)
35 |
--------------------------------------------------------------------------------
/cloths_segmentation/dataloaders.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any, Dict, List, Optional, Tuple
3 |
4 | import albumentations as albu
5 | import numpy as np
6 | import torch
7 | from iglovikov_helper_functions.utils.image_utils import load_rgb, load_grayscale
8 | from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image
9 | from torch.utils.data import Dataset
10 |
11 |
12 | class SegmentationDataset(Dataset):
13 | def __init__(
14 | self,
15 | samples: List[Tuple[Path, Path]],
16 | transform: albu.Compose,
17 | length: Optional[int] = None,
18 | ) -> None:
19 | self.samples = samples
20 | self.transform = transform
21 |
22 | if length is None:
23 | self.length = len(self.samples)
24 | else:
25 | self.length = length
26 |
27 | def __len__(self) -> int:
28 | return self.length
29 |
30 | def __getitem__(self, idx: int) -> Dict[str, Any]:
31 | idx = idx % len(self.samples)
32 |
33 | image_path, mask_path = self.samples[idx]
34 |
35 | image = load_rgb(image_path, lib="cv2")
36 | mask = load_grayscale(mask_path)
37 |
38 | # apply augmentations
39 | sample = self.transform(image=image, mask=mask)
40 | image, mask = sample["image"], sample["mask"]
41 |
42 | mask = (mask > 0).astype(np.uint8)
43 |
44 | mask = torch.from_numpy(mask)
45 |
46 | return {
47 | "image_id": image_path.stem,
48 | "features": tensor_from_rgb_image(image),
49 | "masks": torch.unsqueeze(mask, 0).float(),
50 | }
51 |
--------------------------------------------------------------------------------
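
Note (not a repository file): a minimal sketch of wiring `SegmentationDataset` into a `DataLoader`. Paths are hypothetical; the transform mirrors the validation augmentations used in the configs.

--------------------------------------------------------------------------------
from pathlib import Path

import albumentations as albu
from torch.utils.data import DataLoader

from cloths_segmentation.dataloaders import SegmentationDataset
from cloths_segmentation.utils import get_samples

samples = get_samples(Path("data/images"), Path("data/masks"))

transform = albu.Compose(
    [
        albu.LongestMaxSize(max_size=800, p=1),
        albu.PadIfNeeded(min_height=800, min_width=800, border_mode=0, value=0, mask_value=0, p=1),
        albu.Normalize(p=1),
    ],
    p=1,
)

loader = DataLoader(SegmentationDataset(samples, transform), batch_size=2, shuffle=False)
batch = next(iter(loader))
print(batch["features"].shape, batch["masks"].shape)  # (2, 3, 800, 800) and (2, 1, 800, 800)
--------------------------------------------------------------------------------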
/cloths_segmentation/configs/2020-10-29.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | seed: 1984
3 |
4 | num_workers: 4
5 | experiment_name: "2020-10-29"
6 |
7 | val_split: 0.2
8 |
9 | model:
10 | type: segmentation_models_pytorch.Unet
11 | encoder_name: timm-efficientnet-b3
12 | classes: 1
13 | encoder_weights: noisy-student
14 |
15 | trainer:
16 | type: pytorch_lightning.Trainer
17 | gpus: 4
18 | max_epochs: 30
19 | distributed_backend: ddp
20 | progress_bar_refresh_rate: 1
21 | benchmark: True
22 | precision: 16
23 | gradient_clip_val: 5.0
24 | num_sanity_val_steps: 2
25 | sync_batchnorm: True
26 |
27 |
28 | scheduler:
29 | type: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
30 | T_0: 10
31 | T_mult: 2
32 |
33 | train_parameters:
34 | batch_size: 8
35 |
36 | checkpoint_callback:
37 | type: pytorch_lightning.callbacks.ModelCheckpoint
38 | filepath: "2020-10-29"
39 | monitor: val_iou
40 | verbose: True
41 | mode: max
42 | save_top_k: -1
43 |
44 | val_parameters:
45 | batch_size: 2
46 |
47 | optimizer:
48 | type: adamp.AdamP
49 | lr: 0.0001
50 |
51 |
52 | train_aug:
53 | transform:
54 | __class_fullname__: albumentations.core.composition.Compose
55 | bbox_params: null
56 | keypoint_params: null
57 | p: 1
58 | transforms:
59 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize
60 | always_apply: False
61 | max_size: 800
62 | p: 1
63 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded
64 | always_apply: False
65 | min_height: 800
66 | min_width: 800
67 | border_mode: 0 # cv2.BORDER_CONSTANT
68 | value: 0
69 | mask_value: 0
70 | p: 1
71 | - __class_fullname__: albumentations.augmentations.transforms.RandomCrop
72 | always_apply: False
73 | height: 512
74 | width: 512
75 | p: 1
76 | - __class_fullname__: albumentations.augmentations.transforms.HorizontalFlip
77 | always_apply: False
78 | p: 0.5
79 | - __class_fullname__: albumentations.augmentations.transforms.Normalize
80 | always_apply: false
81 | max_pixel_value: 255.0
82 | mean:
83 | - 0.485
84 | - 0.456
85 | - 0.406
86 | p: 1
87 | std:
88 | - 0.229
89 | - 0.224
90 | - 0.225
91 |
92 | val_aug:
93 | transform:
94 | __class_fullname__: albumentations.core.composition.Compose
95 | bbox_params: null
96 | keypoint_params: null
97 | p: 1
98 | transforms:
99 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize
100 | always_apply: False
101 | max_size: 800
102 | p: 1
103 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded
104 | always_apply: False
105 | min_height: 800
106 | min_width: 800
107 | border_mode: 0 # cv2.BORDER_CONSTANT
108 | value: 0
109 | mask_value: 0
110 | p: 1
111 | - __class_fullname__: albumentations.augmentations.transforms.Normalize
112 | always_apply: false
113 | max_pixel_value: 255.0
114 | mean:
115 | - 0.485
116 | - 0.456
117 | - 0.406
118 | p: 1
119 | std:
120 | - 0.229
121 | - 0.224
122 | - 0.225
123 |
--------------------------------------------------------------------------------
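
Note (not a repository file): the config is not read by a custom parser. `model`, `trainer`, `optimizer`, `scheduler` and `checkpoint_callback` are instantiated via `object_from_dict`, while `train_aug`/`val_aug` are Albumentations pipelines stored in the library's own serialization format and restored with `from_dict`. A small sketch of that round trip, assuming the file path below:

--------------------------------------------------------------------------------
import yaml
from albumentations.core.serialization import from_dict
from iglovikov_helper_functions.config_parsing.utils import object_from_dict

with open("cloths_segmentation/configs/2020-10-29.yaml") as f:
    hparams = yaml.load(f, Loader=yaml.SafeLoader)

model = object_from_dict(hparams["model"])    # segmentation_models_pytorch.Unet(...)
train_aug = from_dict(hparams["train_aug"])   # albumentations Compose with 5 transforms
print(type(model).__name__, len(train_aug.transforms))
--------------------------------------------------------------------------------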
/cloths_segmentation/configs/2020-10-29a.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | seed: 1984
3 |
4 | num_workers: 4
5 | experiment_name: "2020-10-29a"
6 |
7 | val_split: 0.1
8 |
9 | resume_from_checkpoint: 2020-10-29/epoch=4.ckpt
10 |
11 | model:
12 | type: segmentation_models_pytorch.Unet
13 | encoder_name: timm-efficientnet-b3
14 | classes: 1
15 | encoder_weights: noisy-student
16 |
17 | trainer:
18 | type: pytorch_lightning.Trainer
19 | gpus: 4
20 | max_epochs: 30
21 | distributed_backend: ddp
22 | progress_bar_refresh_rate: 1
23 | benchmark: True
24 | precision: 16
25 | gradient_clip_val: 5.0
26 | num_sanity_val_steps: 2
27 | sync_batchnorm: True
28 |
29 |
30 | scheduler:
31 | type: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
32 | T_0: 10
33 | T_mult: 2
34 |
35 | train_parameters:
36 | batch_size: 8
37 |
38 | checkpoint_callback:
39 | type: pytorch_lightning.callbacks.ModelCheckpoint
40 | filepath: "2020-10-29a"
41 | monitor: val_iou
42 | verbose: True
43 | mode: max
44 | save_top_k: -1
45 |
46 | val_parameters:
47 | batch_size: 2
48 |
49 | optimizer:
50 | type: adamp.AdamP
51 | lr: 0.0001
52 |
53 |
54 | train_aug:
55 | transform:
56 | __class_fullname__: albumentations.core.composition.Compose
57 | bbox_params: null
58 | keypoint_params: null
59 | p: 1
60 | transforms:
61 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize
62 | always_apply: False
63 | max_size: 800
64 | p: 1
65 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded
66 | always_apply: False
67 | min_height: 800
68 | min_width: 800
69 | border_mode: 0 # cv2.BORDER_CONSTANT
70 | value: 0
71 | mask_value: 0
72 | p: 1
73 | - __class_fullname__: albumentations.augmentations.transforms.RandomCrop
74 | always_apply: False
75 | height: 512
76 | width: 512
77 | p: 1
78 | - __class_fullname__: albumentations.augmentations.transforms.HorizontalFlip
79 | always_apply: False
80 | p: 0.5
81 | - __class_fullname__: albumentations.augmentations.transforms.Normalize
82 | always_apply: false
83 | max_pixel_value: 255.0
84 | mean:
85 | - 0.485
86 | - 0.456
87 | - 0.406
88 | p: 1
89 | std:
90 | - 0.229
91 | - 0.224
92 | - 0.225
93 |
94 | val_aug:
95 | transform:
96 | __class_fullname__: albumentations.core.composition.Compose
97 | bbox_params: null
98 | keypoint_params: null
99 | p: 1
100 | transforms:
101 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize
102 | always_apply: False
103 | max_size: 800
104 | p: 1
105 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded
106 | always_apply: False
107 | min_height: 800
108 | min_width: 800
109 | border_mode: 0 # cv2.BORDER_CONSTANT
110 | value: 0
111 | mask_value: 0
112 | p: 1
113 | - __class_fullname__: albumentations.augmentations.transforms.Normalize
114 | always_apply: false
115 | max_pixel_value: 255.0
116 | mean:
117 | - 0.485
118 | - 0.456
119 | - 0.406
120 | p: 1
121 | std:
122 | - 0.229
123 | - 0.224
124 | - 0.225
125 |
--------------------------------------------------------------------------------
/cloths_segmentation/configs/2020-10-30.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | seed: 1984
3 |
4 | num_workers: 4
5 | experiment_name: "2020-10-30"
6 |
7 | val_split: 0.1
8 |
9 | model:
10 | type: segmentation_models_pytorch.Unet
11 | encoder_name: timm-efficientnet-b3
12 | classes: 1
13 | encoder_weights: noisy-student
14 |
15 | trainer:
16 | type: pytorch_lightning.Trainer
17 | gpus: 4
18 | max_epochs: 70
19 | distributed_backend: ddp
20 | progress_bar_refresh_rate: 1
21 | benchmark: True
22 | precision: 16
23 | gradient_clip_val: 5.0
24 | num_sanity_val_steps: 2
25 | sync_batchnorm: True
26 | # resume_from_checkpoint: 2020-10-30/epoch=67.ckpt
27 |
28 |
29 | scheduler:
30 | type: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
31 | T_0: 10
32 | T_mult: 2
33 |
34 | train_parameters:
35 | batch_size: 8
36 |
37 | checkpoint_callback:
38 | type: pytorch_lightning.callbacks.ModelCheckpoint
39 | filepath: "2020-10-30"
40 | monitor: val_iou
41 | verbose: True
42 | mode: max
43 | save_top_k: -1
44 |
45 | val_parameters:
46 | batch_size: 2
47 |
48 | optimizer:
49 | type: adamp.AdamP
50 | lr: 0.0001
51 |
52 |
53 | train_aug:
54 | transform:
55 | __class_fullname__: albumentations.core.composition.Compose
56 | bbox_params: null
57 | keypoint_params: null
58 | p: 1
59 | transforms:
60 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize
61 | always_apply: False
62 | max_size: 800
63 | p: 1
64 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded
65 | always_apply: False
66 | min_height: 800
67 | min_width: 800
68 | border_mode: 0 # cv2.BORDER_CONSTANT
69 | value: 0
70 | mask_value: 0
71 | p: 1
72 | - __class_fullname__: albumentations.augmentations.transforms.RandomCrop
73 | always_apply: False
74 | height: 512
75 | width: 512
76 | p: 1
77 | - __class_fullname__: albumentations.augmentations.transforms.HorizontalFlip
78 | always_apply: False
79 | p: 0.5
80 | - __class_fullname__: albumentations.augmentations.transforms.Normalize
81 | always_apply: false
82 | max_pixel_value: 255.0
83 | mean:
84 | - 0.485
85 | - 0.456
86 | - 0.406
87 | p: 1
88 | std:
89 | - 0.229
90 | - 0.224
91 | - 0.225
92 |
93 | val_aug:
94 | transform:
95 | __class_fullname__: albumentations.core.composition.Compose
96 | bbox_params: null
97 | keypoint_params: null
98 | p: 1
99 | transforms:
100 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize
101 | always_apply: False
102 | max_size: 800
103 | p: 1
104 | - __class_fullname__: albumentations.augmentations.transforms.PadIfNeeded
105 | always_apply: False
106 | min_height: 800
107 | min_width: 800
108 | border_mode: 0 # cv2.BORDER_CONSTANT
109 | value: 0
110 | mask_value: 0
111 | p: 1
112 | - __class_fullname__: albumentations.augmentations.transforms.Normalize
113 | always_apply: false
114 | max_pixel_value: 255.0
115 | mean:
116 | - 0.485
117 | - 0.456
118 | - 0.406
119 | p: 1
120 | std:
121 | - 0.229
122 | - 0.224
123 | - 0.225
124 |
125 | test_aug:
126 | transform:
127 | __class_fullname__: albumentations.core.composition.Compose
128 | bbox_params: null
129 | keypoint_params: null
130 | p: 1
131 | transforms:
132 | - __class_fullname__: albumentations.augmentations.transforms.LongestMaxSize
133 | always_apply: False
134 | max_size: 800
135 | p: 1
136 | - __class_fullname__: albumentations.augmentations.transforms.Normalize
137 | always_apply: false
138 | max_pixel_value: 255.0
139 | mean:
140 | - 0.485
141 | - 0.456
142 | - 0.406
143 | p: 1
144 | std:
145 | - 0.229
146 | - 0.224
147 | - 0.225
148 |
--------------------------------------------------------------------------------
/cloths_segmentation/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from typing import Dict, List, Optional, Any
4 |
5 | import albumentations as albu
6 | import cv2
7 | import numpy as np
8 | import torch
9 | import torch.nn.parallel
10 | import torch.utils.data
11 | import torch.utils.data.distributed
12 | import yaml
13 | from albumentations.core.serialization import from_dict
14 | from iglovikov_helper_functions.config_parsing.utils import object_from_dict
15 | from iglovikov_helper_functions.dl.pytorch.utils import state_dict_from_disk, tensor_from_rgb_image
16 | from iglovikov_helper_functions.utils.image_utils import load_rgb, pad_to_size, unpad_from_size
17 | from torch.utils.data import Dataset
18 | from torch.utils.data.distributed import DistributedSampler
19 | from tqdm import tqdm
20 |
21 |
22 | def get_args():
23 | parser = argparse.ArgumentParser()
24 | arg = parser.add_argument
25 | arg("-i", "--input_path", type=Path, help="Path with images.", required=True)
26 | arg("-c", "--config_path", type=Path, help="Path to config.", required=True)
27 | arg("-o", "--output_path", type=Path, help="Path to save masks.", required=True)
28 | arg("-b", "--batch_size", type=int, help="batch_size", default=1)
29 | arg("-j", "--num_workers", type=int, help="num_workers", default=12)
30 | arg("-w", "--weight_path", type=str, help="Path to weights.", required=True)
31 | arg("--world_size", default=-1, type=int, help="number of nodes for distributed training")
32 | arg("--local_rank", default=-1, type=int, help="node rank for distributed training")
33 | arg("--fp16", action="store_true", help="Use fp16")
34 | return parser.parse_args()
35 |
36 |
37 | class InferenceDataset(Dataset):
38 | def __init__(self, file_paths: List[Path], transform: albu.Compose) -> None:
39 | self.file_paths = file_paths
40 | self.transform = transform
41 |
42 | def __len__(self) -> int:
43 | return len(self.file_paths)
44 |
45 | def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]:
46 | image_path = self.file_paths[idx]
47 |
48 | image = load_rgb(image_path)
49 | height, width = image.shape[:2]
50 |
51 | image = self.transform(image=image)["image"]
52 | pad_dict = pad_to_size((max(image.shape[:2]), max(image.shape[:2])), image)
53 |
54 | return {
55 | "torched_image": tensor_from_rgb_image(pad_dict["image"]),
56 | "image_path": str(image_path),
57 | "pads": pad_dict["pads"],
58 | "original_width": width,
59 | "original_height": height,
60 | }
61 |
62 |
63 | def main():
64 | args = get_args()
65 | torch.distributed.init_process_group(backend="nccl")
66 |
67 | with open(args.config_path) as f:
68 | hparams = yaml.load(f, Loader=yaml.SafeLoader)
69 |
70 | hparams.update(
71 | {
72 | "local_rank": args.local_rank,
73 | "fp16": args.fp16,
74 | }
75 | )
76 |
77 | output_mask_path = args.output_path
78 | output_mask_path.mkdir(parents=True, exist_ok=True)
79 | hparams["output_mask_path"] = output_mask_path
80 |
81 | device = torch.device("cuda", args.local_rank)
82 |
83 | model = object_from_dict(hparams["model"])
84 | model = model.to(device)
85 |
86 | if args.fp16:
87 | model = model.half()
88 |
89 | corrections: Dict[str, str] = {"model.": ""}
90 | state_dict = state_dict_from_disk(file_path=args.weight_path, rename_in_layers=corrections)
91 | model.load_state_dict(state_dict)
92 |
93 | model = torch.nn.parallel.DistributedDataParallel(
94 | model, device_ids=[args.local_rank], output_device=args.local_rank
95 | )
96 |
97 | file_paths = []
98 |
99 | for regexp in ["*.jpg", "*.png", "*.jpeg", "*.JPG"]:
100 | file_paths += sorted([x for x in tqdm(args.input_path.rglob(regexp))])
101 |
102 | # Filter file paths for which we already have predictions
103 | file_paths = [x for x in file_paths if not (args.output_path / x.parent.name / f"{x.stem}.png").exists()]
104 |
105 | dataset = InferenceDataset(file_paths, transform=from_dict(hparams["test_aug"]))
106 |
107 | sampler = DistributedSampler(dataset, shuffle=False)
108 |
109 | dataloader = torch.utils.data.DataLoader(
110 | dataset,
111 | batch_size=args.batch_size,
112 | num_workers=args.num_workers,
113 | pin_memory=True,
114 | shuffle=False,
115 | drop_last=False,
116 | sampler=sampler,
117 | )
118 |
119 | predict(dataloader, model, hparams, device)
120 |
121 |
122 | def predict(dataloader, model, hparams, device):
123 | model.eval()
124 |
125 | if hparams["local_rank"] == 0:
126 | loader = tqdm(dataloader)
127 | else:
128 | loader = dataloader
129 |
130 | with torch.no_grad():
131 | for batch in loader:
132 | torched_images = batch["torched_image"] # images that are rescaled and padded
133 |
134 | if hparams["fp16"]:
135 | torched_images = torched_images.half()
136 |
137 | image_paths = batch["image_path"]
138 | pads = batch["pads"]
139 | heights = batch["original_height"]
140 | widths = batch["original_width"]
141 |
142 | batch_size = torched_images.shape[0]
143 |
144 | predictions = model(torched_images.to(device))
145 |
146 | for batch_id in range(batch_size):
147 | file_id = Path(image_paths[batch_id]).stem
148 | folder_name = Path(image_paths[batch_id]).parent.name
149 |
150 | mask = (predictions[batch_id][0].cpu().numpy() > 0).astype(np.uint8) * 255
151 | mask = unpad_from_size(pads, image=mask)["image"]
152 | mask = cv2.resize(
153 | mask, (widths[batch_id].item(), heights[batch_id].item()), interpolation=cv2.INTER_NEAREST
154 | )
155 |
156 | (hparams["output_mask_path"] / folder_name).mkdir(exist_ok=True, parents=True)
157 | cv2.imwrite(str(hparams["output_mask_path"] / folder_name / f"{file_id}.png"), mask)
158 |
159 |
160 | if __name__ == "__main__":
161 | main()
162 |
--------------------------------------------------------------------------------
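
Note: `main` calls `torch.distributed.init_process_group(backend="nccl")` and addresses GPUs via `--local_rank`, so `inference.py` is meant to be started through a distributed launcher (for example `python -m torch.distributed.launch --nproc_per_node=<num_gpus> cloths_segmentation/inference.py -i <image_dir> -c <config> -o <mask_dir> -w <weights>`); running it as a plain single-process script fails at the process-group initialization.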
/cloths_segmentation/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 | from typing import Dict
5 |
6 | import pytorch_lightning as pl
7 | import torch
8 | import yaml
9 | from albumentations.core.serialization import from_dict
10 | from iglovikov_helper_functions.config_parsing.utils import object_from_dict
11 | from iglovikov_helper_functions.dl.pytorch.lightning import find_average
12 | from iglovikov_helper_functions.dl.pytorch.utils import state_dict_from_disk
13 | from pytorch_lightning.loggers import WandbLogger
14 | from pytorch_toolbelt.losses import JaccardLoss, BinaryFocalLoss
15 | from torch.utils.data import DataLoader
16 |
17 | from cloths_segmentation.dataloaders import SegmentationDataset
18 | from cloths_segmentation.metrics import binary_mean_iou
19 | from cloths_segmentation.utils import get_samples
20 |
21 | image_path = Path(os.environ["IMAGE_PATH"])
22 | mask_path = Path(os.environ["MASK_PATH"])
23 |
24 |
25 | def get_args():
26 | parser = argparse.ArgumentParser()
27 | arg = parser.add_argument
28 | arg("-c", "--config_path", type=Path, help="Path to the config.", required=True)
29 | return parser.parse_args()
30 |
31 |
32 | class SegmentPeople(pl.LightningModule):
33 | def __init__(self, hparams):
34 | super().__init__()
35 | self.hparams = hparams
36 |
37 | self.model = object_from_dict(self.hparams["model"])
38 | if "resume_from_checkpoint" in self.hparams:
39 | corrections: Dict[str, str] = {"model.": ""}
40 |
41 | state_dict = state_dict_from_disk(
42 | file_path=self.hparams["resume_from_checkpoint"],
43 | rename_in_layers=corrections,
44 | )
45 | self.model.load_state_dict(state_dict)
46 |
47 | self.losses = [
48 | ("jaccard", 0.1, JaccardLoss(mode="binary", from_logits=True)),
49 | ("focal", 0.9, BinaryFocalLoss()),
50 | ]
51 |
52 | def forward(self, batch: torch.Tensor) -> torch.Tensor: # type: ignore
53 | return self.model(batch)
54 |
55 | def setup(self, stage=0):
56 | samples = get_samples(image_path, mask_path)
57 |
58 | num_train = int((1 - self.hparams["val_split"]) * len(samples))
59 |
60 | self.train_samples = samples[:num_train]
61 | self.val_samples = samples[num_train:]
62 |
63 | print("Len train samples = ", len(self.train_samples))
64 | print("Len val samples = ", len(self.val_samples))
65 |
66 | def train_dataloader(self):
67 | train_aug = from_dict(self.hparams["train_aug"])
68 |
69 | if "epoch_length" not in self.hparams["train_parameters"]:
70 | epoch_length = None
71 | else:
72 | epoch_length = self.hparams["train_parameters"]["epoch_length"]
73 |
74 | result = DataLoader(
75 | SegmentationDataset(self.train_samples, train_aug, epoch_length),
76 | batch_size=self.hparams["train_parameters"]["batch_size"],
77 | num_workers=self.hparams["num_workers"],
78 | shuffle=True,
79 | pin_memory=True,
80 | drop_last=True,
81 | )
82 |
83 | print("Train dataloader = ", len(result))
84 | return result
85 |
86 | def val_dataloader(self):
87 | val_aug = from_dict(self.hparams["val_aug"])
88 |
89 | result = DataLoader(
90 | SegmentationDataset(self.val_samples, val_aug, length=None),
91 | batch_size=self.hparams["val_parameters"]["batch_size"],
92 | num_workers=self.hparams["num_workers"],
93 | shuffle=False,
94 | pin_memory=True,
95 | drop_last=False,
96 | )
97 |
98 | print("Val dataloader = ", len(result))
99 |
100 | return result
101 |
102 | def configure_optimizers(self):
103 | optimizer = object_from_dict(
104 | self.hparams["optimizer"],
105 | params=[x for x in self.model.parameters() if x.requires_grad],
106 | )
107 |
108 | scheduler = object_from_dict(self.hparams["scheduler"], optimizer=optimizer)
109 | self.optimizers = [optimizer]
110 |
111 | return self.optimizers, [scheduler]
112 |
113 | def training_step(self, batch, batch_idx):
114 | features = batch["features"]
115 | masks = batch["masks"]
116 |
117 | logits = self.forward(features)
118 |
119 | total_loss = 0
120 | logs = {}
121 | for loss_name, weight, loss in self.losses:
122 | ls_mask = loss(logits, masks)
123 | total_loss += weight * ls_mask
124 | logs[f"train_mask_{loss_name}"] = ls_mask
125 |
126 | logs["train_loss"] = total_loss
127 |
128 | logs["lr"] = self._get_current_lr()
129 |
130 | return {"loss": total_loss, "log": logs}
131 |
132 | def _get_current_lr(self) -> torch.Tensor:
133 | lr = [x["lr"] for x in self.optimizers[0].param_groups][0] # type: ignore
134 | return torch.Tensor([lr])[0].cuda()
135 |
136 | def validation_step(self, batch, batch_id):
137 | features = batch["features"]
138 | masks = batch["masks"]
139 |
140 | logits = self.forward(features)
141 |
142 | result = {}
143 | for loss_name, _, loss in self.losses:
144 | result[f"val_mask_{loss_name}"] = loss(logits, masks)
145 |
146 | result["val_iou"] = binary_mean_iou(logits, masks)
147 |
148 | return result
149 |
150 | def validation_epoch_end(self, outputs):
151 | logs = {"epoch": self.trainer.current_epoch}
152 |
153 | avg_val_iou = find_average(outputs, "val_iou")
154 |
155 | logs["val_iou"] = avg_val_iou
156 |
157 | return {"val_iou": avg_val_iou, "log": logs}
158 |
159 |
160 | def main():
161 | args = get_args()
162 |
163 | with open(args.config_path) as f:
164 | hparams = yaml.load(f, Loader=yaml.SafeLoader)
165 |
166 | pipeline = SegmentPeople(hparams)
167 |
168 | Path(hparams["checkpoint_callback"]["filepath"]).mkdir(exist_ok=True, parents=True)
169 |
170 | trainer = object_from_dict(
171 | hparams["trainer"],
172 | logger=WandbLogger(hparams["experiment_name"]),
173 | checkpoint_callback=object_from_dict(hparams["checkpoint_callback"]),
174 | )
175 |
176 | trainer.fit(pipeline)
177 |
178 |
179 | if __name__ == "__main__":
180 | main()
181 |
--------------------------------------------------------------------------------
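
Note (not a repository file): training is driven by the `IMAGE_PATH` and `MASK_PATH` environment variables (folders of images and masks) plus `-c <config>` on the command line. The objective is a fixed 0.1/0.9 blend of Jaccard and binary focal loss; below is a standalone sketch of that combination on toy tensors (shapes are illustrative only).

--------------------------------------------------------------------------------
import torch
from pytorch_toolbelt.losses import BinaryFocalLoss, JaccardLoss

# Raw model outputs (logits) and binary ground-truth masks, N x 1 x H x W.
logits = torch.randn(2, 1, 64, 64)
masks = torch.randint(0, 2, (2, 1, 64, 64)).float()

losses = [
    ("jaccard", 0.1, JaccardLoss(mode="binary", from_logits=True)),
    ("focal", 0.9, BinaryFocalLoss()),
]

# Same weighting as SegmentPeople.training_step.
total_loss = sum(weight * loss(logits, masks) for _, weight, loss in losses)
print(total_loss)
--------------------------------------------------------------------------------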