├── .github
│   └── FUNDING.yml
├── .gitignore
├── Data
│   ├── README.md
│   ├── __init__.py
│   ├── dataset_preprocessor.py
│   ├── dataset_preprocessor_web.py
│   ├── preprocessor.py
│   ├── preprocessor_web.py
│   ├── preprocessors
│   │   ├── __init__.py
│   │   ├── detectron2_preprocessor.py
│   │   ├── edge_extractor.py
│   │   ├── face_alignment_preprocessor.py
│   │   └── human_parts_preprocessor.py
│   └── utils.py
├── LICENSE
├── README.md
├── conf
│   ├── img_config.yaml
│   ├── preprocess_data.yaml
│   ├── preprocess_data_web.yaml
│   ├── seg_config.yaml
│   └── show.yaml
├── log_utils.py
├── losses
│   ├── __init__.py
│   ├── discriminator.py
│   ├── face_loss.py
│   ├── loss_img.py
│   ├── loss_seg.py
│   ├── lpips.py
│   └── lpips_with_object.py
├── models
│   ├── __init__.py
│   ├── modules.py
│   ├── transformer.py
│   └── vqvae.py
├── train.py
└── utils.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: casualganpapers
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
14 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *.pyc
3 |
--------------------------------------------------------------------------------
/Data/README.md:
--------------------------------------------------------------------------------
1 | # Data Aggregation
2 |
3 |
4 |
5 |
6 | ## Segmentation Dataset
7 | "VQ-SEG and VQ-IMG are trained on CC12m, CC, and MS-COCO." -> For the segmentation
8 | branch, we first need to derive segmentation maps for all three datasets using
9 | the three models listed below:
10 |
11 | - Panoptic: https://github.com/facebookresearch/detectron2
12 | - Human Parts: https://github.com/PeikeLi/Self-Correction-Human-Parsing
13 | - Human Face: https://github.com/1adrianb/face-alignment
14 |
15 | These 3 models will be used to construct the dataset for the segmentation maps.
16 |
17 | VQ-SEG is trained with 158 segmentation categories:
18 | - 133 panoptic
19 | - 20 human parts
20 | - 5 human face (eye-brows, eyes, nose, outer-mouth, inner-mouth)
21 |
22 | ## Data Pipeline
23 | 1. Take an input image (or a whole dataset) of shape HxWx3
24 | 2. seg_panoptic = detectron2(x)
25 | 3. seg_human = human_parsing(x)
26 | 4. seg_face = human_face(x)
27 | 5. One-hot encode and concatenate along the channel axis (133 + 20 + 5 channels)
28 | 6. Add one channel for edges between objects
29 | 7. Return the segmentation map of shape HxWx159 (see the sketch below)
30 |
31 |
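32 | A minimal sketch of steps 5-7 above, given the outputs of steps 2-4 and mirroring
33 | the channel layout used in `dataset_preprocessor.py` (the inputs are assumed to be
34 | the numpy arrays saved by the three preprocessors; the function and argument names
35 | are illustrative):
36 |
37 | ```python
38 | import torch
39 | import torch.nn.functional as F
40 |
41 | def build_seg_map(seg_panoptic, seg_human, seg_face, edges_panoptic, edges_human):
42 |     # seg_*: (H, W) integer class maps, edges_*: (H, W) boolean edge maps
43 |     panoptic = F.one_hot(torch.from_numpy(seg_panoptic + 1).long(), 134)[..., 1:]  # 133 channels
44 |     human = F.one_hot(torch.from_numpy(seg_human + 1).long(), 21)[..., 1:]         # 20 channels
45 |     face = F.one_hot(torch.from_numpy(seg_face).long(), 6)[..., 1:]                # 5 channels
46 |     edges = torch.from_numpy(edges_panoptic) | torch.from_numpy(edges_human)       # 1 channel
47 |     return torch.cat(
48 |         [panoptic.float(), human.float(), face.float(), edges.unsqueeze(-1).float()],
49 |         dim=-1,
50 |     )  # (H, W, 159) = 133 panoptic + 20 human parts + 5 face + 1 edge channel
51 | ```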
--------------------------------------------------------------------------------
/Data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset_preprocessor import COCO2014Dataset
2 | from .dataset_preprocessor_web import PreprocessedWebDataset
3 | from warnings import warn
4 | try:
5 | from .preprocessor import BasePreprocessor
6 | from .preprocessor_web import WebPreprocessor
7 | except ModuleNotFoundError:
8 |     # Preprocessing-only dependencies (e.g. detectron2, face_alignment) are optional.
9 |     warn("Some dependencies missing for data preprocessing.")
10 |
--------------------------------------------------------------------------------
/Data/dataset_preprocessor.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import numpy as np
3 | import os
4 | import cv2
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.utils.data import Dataset
8 | from torch.utils.data import ConcatDataset as ConcatDataset_
9 | from tqdm import tqdm
10 | import torch.multiprocessing as mp
11 | import warnings
12 | from torchvision import transforms
13 | from hydra.utils import instantiate
14 | import albumentations as A
15 | from albumentations.pytorch import ToTensorV2
16 | from urllib.request import urlretrieve
17 | from urllib.error import HTTPError
18 |
19 | class PreprocessedDataset(Dataset):
20 | def __init__(
21 | self,
22 | root=None,
23 | image_dirs=None,
24 | preprocessed_folder=None,
25 | ):
26 | self.image_dirs = image_dirs
27 |
28 | self.preprocessed_folder = preprocessed_folder
29 | self.preprocessed_path = os.path.join(preprocessed_folder, "segmentations", "%s_%s.npz", )
30 |
31 | self.root = root
32 | self.transforms = A.Compose([
33 | A.SmallestMaxSize(256),
34 | A.RandomCrop(256, 256, always_apply=True),
35 | # A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
36 | ToTensorV2(transpose_mask=True)
37 | ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=["class_labels"], min_area=100, min_visibility=0.2),
38 | additional_targets={"bboxes0": "bboxes"})
39 |
40 | if not os.path.exists(preprocessed_folder):
41 | os.makedirs(preprocessed_folder)
42 | else:
43 | assert os.path.isdir(preprocessed_folder)
44 |
45 | self.img_names_path = os.path.join(preprocessed_folder, f"img_names_{self.name}.npz")
46 | if not os.path.exists(self.img_names_path):
47 | self.parse_image_names()
48 |
49 | img_list = np.load(self.img_names_path)
50 | self.img_names = img_list["img_names"]
51 | if "img_urls" in img_list:
52 | self.img_urls = img_list["img_urls"]
53 |
54 | def load_segmentation(self, idx):
55 | img_name = os.path.splitext(self.img_names[idx])[0]
56 |
57 | data_panoptic = np.load(self.preprocessed_path % (img_name, "panoptic"))
58 | data_human = np.load(self.preprocessed_path % (img_name, "human"))
59 | data_face = np.load(self.preprocessed_path % (img_name, "face"))
60 |
61 | # Panoptic
62 | seg_panoptic = F.one_hot(
63 | torch.from_numpy(data_panoptic["seg_panoptic"] + 1).to(torch.long), num_classes=134
64 | )[..., 1:]
65 | edges_panoptic = torch.from_numpy(data_panoptic["edges"]).unsqueeze(-1)
66 | box_thing = data_panoptic["box_things"]
67 |
68 | # Human parts
69 | seg_human = F.one_hot(
70 | torch.from_numpy(data_human["seg_human"] + 1).to(torch.long), num_classes=21
71 | )[..., 1:]
72 | edges_human = torch.from_numpy(data_human["edges"]).unsqueeze(-1)
73 |
74 | # Edges
75 | seg_edges = (edges_panoptic + edges_human).float()
76 |
77 | # Face
78 | seg_face = F.one_hot(
79 | torch.from_numpy(data_face["seg_face"]).to(torch.long), num_classes=6
80 | )[..., 1:]
81 | box_face = data_face["box_face"]
82 |
83 | # Concatenate masks
84 | seg_map = torch.cat(
85 | [seg_panoptic, seg_human, seg_face, seg_edges], dim=-1
86 | )
87 |
88 | return np.array(seg_map), box_thing, box_face
89 |
90 | def __getitem__(self, idx):
91 | segmentation, box_thing, box_face = self.load_segmentation(idx)
92 | image, _ = self.get_image(idx)
93 | data = self.transforms(image=image, mask=segmentation, bboxes=box_thing, bboxes0=box_face,
94 | class_labels=np.zeros(box_thing.shape[0]))
95 | return data["image"], data["mask"], data["bboxes"], data["bboxes0"], self.img_names[idx]
96 |
97 | def get_image(self, idx):
98 | raise NotImplementedError
99 |
100 | def parse_image_names(self):
101 | raise NotImplementedError
102 |
103 | def __len__(self):
104 | return len(self.img_names)
105 |
106 |
107 | class BaseCOCODataset(PreprocessedDataset):
108 | def get_image(self, idx):
109 | img_name = self.img_names[idx]
110 | path = os.path.join(self.root, img_name)
111 | image = cv2.imread(path)
112 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
113 | return image, img_name
114 |
115 | def parse_image_names(self):
116 | img_names = []
117 | for directory in self.image_dirs:
118 | for filename in os.listdir(os.path.join(self.root, directory)):
119 | if os.path.splitext(filename)[1] in [".jpg", ".png"]:
120 | img_names.append(os.path.join(directory, filename))
121 | np.savez(self.img_names_path, img_names=img_names)
122 |
123 |
124 | class COCO2014Dataset(BaseCOCODataset):
125 | name = "coco2014"
126 | image_dirs = "train2014"
127 |
128 | def __init__(self, root, preprocessed_folder, **kwargs):
129 | super().__init__(
130 | root=root,
131 | image_dirs=["train2014"],
132 | preprocessed_folder=preprocessed_folder,
133 | **kwargs,
134 | )
135 |
136 |
137 | class COCO2017Dataset(BaseCOCODataset):
138 | name = "coco2017"
139 | image_dirs = "train2017"
140 |
141 | def __init__(self, root, preprocessed_folder, **kwargs):
142 | super().__init__(
143 | root=root,
144 | image_dirs=["train2017"],
145 | preprocessed_folder=preprocessed_folder,
146 | **kwargs,
147 | )
148 |
149 |
150 | class Conceptual12mDataset(PreprocessedDataset):
151 | name = "cc12m"
152 |
153 | def __init__(self, root, preprocessed_folder, **kwargs):
154 | super().__init__(
155 |             root=root, preprocessed_folder=preprocessed_folder,
156 |             **kwargs,
157 | )
158 |
159 | def parse_image_names(self, listfile):
160 | img_names = []
161 | img_urls = []
162 | with open(listfile, "r") as urllist:
163 |             for line in urllist:
164 |                 url, caption = line.split("\t")
165 |                 img_names.append(caption.strip() + ".jpg")
166 |                 img_urls.append(url)
167 |         np.savez(self.img_names_path, img_names=img_names, img_urls=img_urls)
168 |
169 | def get_image(self, idx):
170 | img_name = self.img_names[idx]
171 | path = os.path.join(self.root, img_name)
172 | if not os.path.exists(path):
173 |             self.download_image(self.img_urls[idx], img_name)
174 | image = cv2.imread(path)
175 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
176 | return image, img_name
177 |
178 | def download_image(self, url, image_name):
179 | try:
180 | image_path = os.path.join(self.root, image_name)
181 | urlretrieve(url, image_path)
182 | return True
183 | except HTTPError:
184 | print("Failed to download the image: ", image_name)
185 | return False
186 |
187 |
188 | class ConcatDataset(ConcatDataset_):
189 | def get_true_idx(self, idx):
190 | if idx < 0:
191 | if -idx > len(self):
192 | raise ValueError("absolute value of index should not exceed dataset length")
193 | idx = len(self) + idx
194 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
195 | if dataset_idx == 0:
196 | sample_idx = idx
197 | else:
198 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
199 | return dataset_idx, sample_idx
200 |
201 | def get_image(self, idx):
202 | dataset_idx, sample_idx = self.get_true_idx(idx)
203 | return self.datasets[dataset_idx].get_image(sample_idx)
204 |
205 |
206 | if __name__ == "__main__":
207 | coco = COCO2014Dataset(
208 | "./mydb", "./mydb/preprocessed"
209 | )
210 | from torchvision.utils import draw_bounding_boxes
211 | import matplotlib.pyplot as plt
212 |
213 | img, _, ft, fb, _ = coco[0]
214 | plt.imshow(draw_bounding_boxes(img, torch.tensor(ft + fb)).permute(1, 2, 0))
215 | plt.show()
216 | print()
217 |
--------------------------------------------------------------------------------
/Data/dataset_preprocessor_web.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import numpy as np
3 | import os
4 | import cv2
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.utils.data import Dataset
8 | from torch.utils.data import ConcatDataset as ConcatDataset_
9 | from tqdm import tqdm
10 | import torch.multiprocessing as mp
11 | import warnings
12 | from torchvision import transforms
13 | from hydra.utils import instantiate
14 | import albumentations as A
15 | from albumentations.pytorch import ToTensorV2
16 | from .utils import check_bboxes
17 | from urllib.request import urlretrieve
18 | from webdataset import WebDataset
19 | from webdataset.shardlists import split_by_node
20 | from webdataset.handlers import warn_and_continue
21 | from itertools import islice
22 |
23 | def my_split_by_node(src, group=None):
24 |     rank, world_size = int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"])
25 | if world_size > 1:
26 | for s in islice(islice(src, (rank*2)//world_size, None, 2), rank%(world_size//2), None, world_size//2):
27 | yield s
28 | else:
29 | for s in src:
30 | yield s
31 |
32 |
33 | class PreprocessData:
34 | def __init__(self, ready_queue):
35 | self.transforms = A.Compose([
36 | A.SmallestMaxSize(512),
37 | A.CenterCrop(512, 512, always_apply=True),
38 | ToTensorV2(transpose_mask=True)
39 | ])
40 | self.lasttar = "no"
41 | self.ready_queue = ready_queue
42 |
43 | def __call__(self, data):
44 | result = self.transforms(image=data["jpg"])
45 | data["jpg"] = result["image"]
46 | data["tarname"] = os.path.basename(data["__url__"])
47 | if self.lasttar!=data["tarname"]:
48 | rank = os.environ["RANK"]
49 | proc_type = os.environ["TYPE"]
50 | worker_info = torch.utils.data.get_worker_info()
51 | if worker_info is not None:
52 | worker = worker_info.id
53 | else:
54 | worker = None
55 |
56 | if self.lasttar != "no":
57 | self.ready_queue.put("%s/%s/processed/%s" % (rank+"-"+proc_type, worker, self.lasttar))
58 | print(self.lasttar, "processed!")
59 | self.ready_queue.put("%s/%s/started/%s" % (rank+"-"+proc_type, worker, data["tarname"]))
60 | self.lasttar = data["tarname"]
61 | return data
62 |
63 | class UnprocessedWebDataset(WebDataset):
64 | def __init__(self, root, *args, is_dir=False, ready_queue=None, **kwargs):
65 | if is_dir:
66 | shards = [os.path.join(root, filename) for filename in os.listdir(root) if os.path.splitext(filename)[1]==".tar"]
67 | shards.sort()
68 | self.basedir = root
69 | else:
70 | shards = root
71 | self.basedir = os.path.dirname(root)
72 | super().__init__(shards, *args, nodesplitter=my_split_by_node, handler=warn_and_continue, **kwargs)
73 | self.decode("rgb")
74 | self.map(PreprocessData(ready_queue))
75 | self.to_tuple("__key__", "tarname", "jpg")
76 |
77 |
78 | class ProcessData:
79 | def __init__(self,):
80 | self.pretransforms = A.Compose([
81 | A.SmallestMaxSize(512),
82 | A.CenterCrop(512, 512, always_apply=True),
83 | ])
84 | self.transforms = A.Compose([
85 | #A.SmallestMaxSize(512),
86 | #A.CenterCrop(512, 512, always_apply=True),
87 | #A.RandomCrop(256, 256, always_apply=True),
88 | # A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
89 | ToTensorV2(transpose_mask=True)
90 | ], bbox_params=A.BboxParams(format='pascal_voc', min_area=100, min_visibility=0.2),
91 | additional_targets={"bboxes0": "bboxes"})
92 |
93 | def __call__(self, data):
94 | npz_data = data["npz"]
95 | # Panoptic
96 | seg_panoptic = F.one_hot(
97 | torch.from_numpy(npz_data["seg_panoptic"] + 1).to(torch.long), num_classes=134
98 | )[..., 1:]
99 | edges_panoptic = torch.from_numpy(npz_data["edge_panoptic"]).unsqueeze(-1)
100 | box_thing = npz_data["box_things"].tolist()
101 | for box in box_thing:
102 | box.append(0)
103 |
104 | # Human parts
105 | seg_human = F.one_hot(
106 | torch.from_numpy(npz_data["seg_human"] + 1).to(torch.long), num_classes=21
107 | )[..., 1:]
108 | edges_human = torch.from_numpy(npz_data["edge_human"]).unsqueeze(-1)
109 |
110 | # Edges
111 | seg_edges = (edges_panoptic + edges_human).float()
112 |
113 | # Face
114 | seg_face = F.one_hot(
115 | torch.from_numpy(npz_data["seg_face"]).to(torch.long), num_classes=6
116 | )[..., 1:]
117 | box_face = npz_data["box_face"].tolist()
118 | for box in box_face:
119 | box.append(0)
120 |
121 | # Concatenate masks
122 | seg_map = torch.cat(
123 | [seg_panoptic, seg_human, seg_face, seg_edges], dim=-1
124 | ).numpy()
125 |
126 | data["jpg"] = self.pretransforms(image=data["jpg"])["image"]
127 | box_thing = check_bboxes(box_thing)
128 | box_face = check_bboxes(box_face)
129 | transformed_data = self.transforms(image=data["jpg"], bboxes=box_thing, bboxes0=box_face,)
130 | data["jpg"] = transformed_data["image"]
131 | data["mask"] = seg_map
132 | data["box_things"] = transformed_data["bboxes"]
133 | data["box_face"] = transformed_data["bboxes0"]
134 | return data
135 |
136 |
137 | class PreprocessedWebDataset(WebDataset):
138 | def __init__(self, url, *args, **kwargs):
139 | super().__init__(url, *args, nodesplitter=split_by_node, handler=warn_and_continue, **kwargs)
140 | self.decode("rgb")
141 | #self.decode("npz")
142 | self.map(ProcessData(), handler=warn_and_continue)
143 | self.to_tuple("jpg", "mask", "box_things", "box_face", "txt", handler=warn_and_continue)
144 |
145 | class COCOWebDataset(PreprocessedWebDataset):
146 | def __init__(self, *args, **kwargs):
147 | super().__init__("pipe:aws s3 cp s3://s-mas/coco_processed/{00000..00010}.tar -", *args, **kwargs)
148 |
149 | class CC3MWebDataset(PreprocessedWebDataset):
150 | def __init__(self, *args, **kwargs):
151 | super().__init__("pipe:aws s3 cp s3://s-mas/cc3m_processed/{00000..00311}.tar -", *args, **kwargs)
152 |
153 | class S3ProcessedDataset(PreprocessedWebDataset):
154 | datasets = {
155 | "coco": "pipe:aws s3 cp s3://s-mas/coco_processed/{00000..00059}.tar -",
156 | "cc3m": "pipe:aws s3 cp s3://s-mas/cc3m_processed/{00000..00331}.tar -",
157 | "cc12m": "pipe:aws s3 cp s3://s-mas/cc12m_processed/{00000..01242}.tar -",
158 | "laion": "pipe:aws s3 cp s3://s-mas/laion_en_processed/{00000..01500}.tar -"
159 | }
160 | def __init__(self, names, *args, **kwargs):
161 | urls = []
162 | for name in names:
163 | assert name in self.datasets, f"There is no processed dataset {name}"
164 | urls.append(self.datasets[name])
165 | urls = "::".join(urls)
166 | super().__init__(urls, *args, **kwargs)
167 |
168 | if __name__ == "__main__":
169 |     # Smoke test: stream a single sample from the processed COCO shards (requires S3 access).
170 |     dataset = COCOWebDataset()
171 |     img, mask, box_things, box_face, txt = next(iter(dataset))
172 |     print(img.shape, mask.shape, len(box_things), len(box_face), txt)
173 |
--------------------------------------------------------------------------------
/Data/preprocessor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.multiprocessing as mp
4 | from .preprocessors import Detectron2
5 | from .preprocessors import HumanParts
6 | from .preprocessors import HumanFace
7 | from tqdm import tqdm
8 | import numpy as np
9 |
10 |
11 | class BasePreprocessor:
12 | proc_types = {"panoptic": Detectron2, "human": HumanParts, "face": HumanFace}
13 |
14 | def __init__(
15 | self,
16 | preprocessed_folder,
17 | proc_per_gpu=None,
18 | proc_per_cpu=None,
19 | devices=None,
20 | machine_idx=0,
21 | machines_total=1,
22 | ):
23 | self.idx = machine_idx
24 | self.machines_total = machines_total
25 | self.devices = list(devices)
26 | self.proc_per_gpu = proc_per_gpu
27 | self.proc_per_cpu = proc_per_cpu
28 | self.preprocessed_folder = preprocessed_folder
29 | self.preprocessed_path = os.path.join(
30 | preprocessed_folder,
31 | "segmentations",
32 | f"%s_%s.npz",
33 | )
34 | self.log_path = os.path.join(
35 | preprocessed_folder,
36 | f"%s_%s.log",
37 | )
38 |
39 | def __call__(self, dataset):
40 | self.dataset = dataset
41 | assert torch.cuda.is_available(), "GPU required for preprocessing"
42 | procs = []
43 | mp.set_start_method("spawn")
44 | for proc_type in self.proc_per_gpu:
45 | devices = self.devices * self.proc_per_gpu[proc_type]
46 | n_cpus = self.proc_per_cpu[proc_type]
47 | proc_per_machine = len(devices) + n_cpus
48 | proc_total = proc_per_machine * self.machines_total
49 | # GPUs
50 | for proc_id, dev_id in enumerate(devices):
51 | p = mp.Process(
52 | target=self.preprocess_single_process,
53 | args=(
54 | proc_type,
55 | self.idx * proc_per_machine + proc_id,
56 | dev_id,
57 | proc_total,
58 | ),
59 | )
60 | p.start()
61 | procs.append(p)
62 | # CPUs
63 | for proc_id in range(n_cpus):
64 | p = mp.Process(
65 | target=self.preprocess_single_process,
66 | args=(
67 | proc_type,
68 | self.idx * proc_per_machine + len(devices) + proc_id,
69 | "cpu",
70 | proc_total,
71 | ),
72 | )
73 | p.start()
74 | procs.append(p)
75 | for proc in procs:
76 | proc.join()
77 |
78 | def preprocess_single_process(self, proc_type, proc_id, dev_id, proc_total):
79 | correct_names = []
80 | if dev_id != "cpu":
81 | torch.cuda.set_device( # https://github.com/pytorch/pytorch/issues/21819#issuecomment-553310128
82 | dev_id
83 | )
84 | device = f"cuda:{dev_id}"
85 | else:
86 | device = "cpu"
87 | processor = self.proc_types[proc_type](device=device)
88 | log_path = self.log_path % (proc_id, proc_type)
89 | self.check_path(log_path)
90 | with open(log_path, "w") as logfile:
91 | for idx in tqdm(range(len(self.dataset)), file=logfile):
92 | if idx % proc_total != proc_id:
93 | continue
94 | image, image_name = self.dataset.get_image(idx)
95 | image_name = os.path.splitext(image_name)[0]
96 | data = processor(image)
97 | save_path = self.preprocessed_path % (image_name, proc_type)
98 | self.check_path(save_path)
99 | np.savez(save_path, **data)
100 |
101 | def check_path(self, path):
102 | dirname = os.path.dirname(path)
103 | if not os.path.exists(dirname):
104 | try:
105 | os.makedirs(dirname)
106 | except FileExistsError:
107 | pass
108 |
--------------------------------------------------------------------------------
/Data/preprocessor_web.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.multiprocessing as mp
4 | from torch.utils.data import DataLoader
5 | from .preprocessors import Detectron2
6 | from .preprocessors import HumanParts
7 | from .preprocessors import HumanFace
8 | from tqdm import tqdm
9 | import numpy as np
10 | import hydra
11 | from webdataset import WebDataset, TarWriter
12 | from webdataset.handlers import warn_and_continue
13 | from time import time, sleep
14 | import json
15 | import shutil
16 | import fsspec
17 | import queue
18 |
19 |
20 | class WebPreprocessor:
21 | proc_types = {"panoptic": Detectron2, "human": HumanParts, "face": HumanFace}
22 |
23 | def __init__(
24 | self,
25 | preprocessed_folder,
26 | output_folder,
27 | proc_per_gpu=None,
28 | proc_per_cpu=None,
29 | devices=None,
30 | machine_idx=0,
31 | machines_total=1,
32 | batch_size=5,
33 | num_workers=2
34 | ):
35 | self.idx = machine_idx
36 | self.machines_total = machines_total
37 | self.devices = list(devices)
38 | self.output_folder = output_folder
39 | self.proc_per_gpu = proc_per_gpu
40 | self.proc_per_cpu = proc_per_cpu
41 | self.batch_size = batch_size
42 | self.num_workers = num_workers
43 | self.preprocessed_folder = preprocessed_folder
44 | self.preprocessed_path = os.path.join(
45 | preprocessed_folder,
46 | "untars",
47 | f"%s/%s_%s.npz",
48 | )
49 | self.repacked_path = os.path.join(
50 | preprocessed_folder,
51 | "tars",
52 | )
53 | self.log_path = os.path.join(
54 | preprocessed_folder,
55 |             "%s_%s_%s.log",
56 | )
57 |
58 | def __call__(self, dataset):
59 | self.dataset = dataset
60 | assert torch.cuda.is_available(), "GPU required for preprocessing"
61 | procs = []
62 | mp.set_start_method("spawn")
63 | ready_queue = mp.Queue()
64 | proc_type_locks = {proc_type: mp.Value("b", 0) for proc_type in self.proc_types}
65 | proc_per_machine = 0
66 |         for value in self.proc_per_gpu.values():
67 |             proc_per_machine += len(value)
68 | for proc_type in self.proc_per_gpu:
69 | devices = len(self.proc_per_gpu[proc_type])
70 | n_cpus = self.proc_per_cpu[proc_type]
71 | proc_total_type = devices * self.machines_total
72 | # GPUs
73 | for proc_id, dev_id in enumerate(self.proc_per_gpu[proc_type]):
74 | p = mp.Process(
75 | target=self.preprocess_single_process,
76 | args=(
77 | proc_type,
78 | self.idx * devices + proc_id,
79 | dev_id,
80 | proc_total_type,
81 | ready_queue,
82 | proc_type_locks[proc_type]
83 | ),
84 | )
85 | p.start()
86 | procs.append(p)
87 | # CPUs
88 | for proc_id in range(n_cpus):
89 | p = mp.Process(
90 | target=self.preprocess_single_process,
91 | args=(
92 | proc_type,
93 |                         self.idx * proc_per_machine + devices + proc_id,
94 |                         "cpu",
95 |                         proc_total_type,
96 |                         ready_queue, proc_type_locks[proc_type],
97 |                     ),
98 | )
99 | p.start()
100 | procs.append(p)
101 |
102 | self.repacker_process(ready_queue, proc_per_machine, proc_type_locks)
103 | for proc in procs:
104 | proc.join()
105 |
106 | def preprocess_single_process(self, proc_type, proc_id, dev_id, proc_total, ready_queue, proc_type_lock):
107 | os.environ["RANK"] = str(proc_id)
108 | os.environ["TYPE"] = proc_type
109 | os.environ["WORLD_SIZE"] = str(proc_total)
110 | ready_queue.put("%s/%s/Init/%s" %(proc_id, proc_total, proc_type))
111 | dataset = hydra.utils.instantiate(self.dataset, ready_queue=ready_queue)
112 | dataloader = DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True)
113 | correct_names = []
114 | torch.set_num_threads(15)
115 | if dev_id != "cpu":
116 | torch.cuda.set_device( # https://github.com/pytorch/pytorch/issues/21819#issuecomment-553310128
117 | dev_id
118 | )
119 | device = f"cuda:{dev_id}"
120 | else:
121 | device = "cpu"
122 | processor = self.proc_types[proc_type](device=device)
123 | log_path = self.log_path % (proc_id, proc_total, proc_type)
124 | self.check_path(log_path)
125 | with open(log_path, "w") as logfile:
126 | x = 0
127 | t = time()
128 | times = []
129 | for batch in tqdm(dataloader, file=logfile):
130 | if proc_type_lock.value:
131 | print("Waiting slower preprocessors ", proc_type)
132 | while proc_type_lock.value:
133 | sleep(0.1)
134 | print("Waiting released", proc_type)
135 | times.append(str(time()-t))
136 | t = time()
137 | imgnames, tarnames, images = batch
138 | times.append(str(time()-t))
139 | t = time()
140 | batched_data = processor(images*255.)
141 | times.append(str(time()-t))
142 | t = time()
143 | #print(proc_id, proc_total, imgnames, tarnames)
144 | #print(batched_data["box_things"])
145 | for i in range(len(imgnames)):
146 | tarname, imgname = tarnames[i], imgnames[i]
147 | save_path = self.preprocessed_path % (tarname, imgname, proc_type)
148 | self.check_path(save_path)
149 | data = {key: batched_data[key][i] for key in batched_data}
150 | np.savez(save_path, **data)
151 | times.append(str(time()-t))
152 | t = time()
153 | #ready_queue.put("%s/%s/info/%s" % (proc_id, proc_type, ",".join(times)))
154 | times = []
155 | ready_queue.put("%s/%s/done/und" % (str(proc_id) +"-"+ proc_type, proc_type))
156 |
157 | def repacker_process(self, ready_queue, proc_per_machine, proc_type_locks):
158 | proc_done = 0
159 | procs = []
160 | max_repackings = 20
161 | repackings_queue = queue.Queue()
162 | repacking_done = mp.Queue()
163 | info = {"repackings": 0}
164 | progress = {"panoptic": 0, "human": 0, "face": 0}
165 | while proc_done < proc_per_machine:
166 | command = ready_queue.get()
167 | print("Got", command)
168 | proc_id, worker, state, tarname = command.split("/")
169 | if state == "done":
170 | proc_done += 1
171 | for worker in info[proc_id]:
172 | tarname = info[proc_id][worker]
173 | info[tarname] = 0 if tarname not in info else info[tarname]
174 | info[tarname] += 1
175 | if info[tarname] == 3:
176 | repackings_queue.put(tarname)
177 | #self.repack_single_tar(tarname)
178 | elif state == "started":
179 | if proc_id not in info:
180 | info[proc_id] = {}
181 | info[proc_id][worker] = tarname
182 | elif state == "processed":
183 | info[tarname] = 0 if tarname not in info else info[tarname]
184 | info[tarname] += 1
185 | proc_type = proc_id.split("-")[1]
186 | progress[proc_type] += 1
187 | if info[tarname] == 3:
188 | repackings_queue.put(tarname)
189 | #self.repack_single_tar(tarname)
190 | if progress[proc_type] - np.min(list(progress.values())) > 30:
191 | proc_type_locks[proc_type].value = 1
192 | if progress[proc_type] == np.min(list(progress.values())):
193 | for t in proc_type_locks:
194 | proc_type_locks[t].value = 0
195 |
196 |
197 | elif state== "info":
198 | continue
199 |
200 | while repacking_done.qsize() > 0:
201 | msg = repacking_done.get()
202 | with open("info.log", "a") as f:
203 | f.write(msg+"\n")
204 | info["repackings"] -= 1
205 |
206 | # Repack original and processed data to new tar
207 | while info["repackings"] < max_repackings and repackings_queue.qsize() > 0:
208 | tarname = repackings_queue.get()
209 | print("Started repacking", tarname)
210 | p = mp.Process(
211 | target=self.repack_single_tar,
212 | args=(
213 | tarname,
214 | repacking_done,
215 | ),
216 | )
217 | p.start()
218 | procs.append(p)
219 | info["repackings"] += 1
220 |
221 |
222 | with open("info.state", "w") as f:
223 | line = json.dumps(info, indent=3, sort_keys=True)
224 | f.write(str(line))
225 | with open("info.log", "a") as f:
226 | f.write(command+"\n")
227 |
228 |
229 |
230 | for p in procs:
231 | p.join()
232 | print("Processed all data!")
233 |
234 | def repack_single_tar(self, tarname, repacking_done):
235 | if os.path.isdir(self.dataset.root):
236 |             root = self.dataset.root
237 | else:
238 | root = os.path.dirname(self.dataset.root)
239 | old_data = WebDataset(os.path.join(root, tarname), handler=warn_and_continue)
240 | output_folder = "s3://s-mas/" + self.output_folder
241 | fs, output_path = fsspec.core.url_to_fs(output_folder)
242 | tar_fd = fs.open(f"{output_path}/{tarname.split(' ')[0]}", "wb")
243 | new_data = TarWriter(tar_fd)
244 | for sample in old_data:
245 | new_sample = {}
246 | imgname = new_sample["__key__"] = sample["__key__"]
247 | new_sample["jpg"] = sample["jpg"]
248 | new_sample["txt"] = sample["txt"]
249 |
250 | data_face = np.load(self.preprocessed_path % (tarname, imgname, "face"), allow_pickle=True)
251 | data_human = np.load(self.preprocessed_path % (tarname, imgname, "human"))
252 | data_panoptic = np.load(self.preprocessed_path % (tarname, imgname, "panoptic"))
253 |
254 | data_merged = {}
255 | data_merged["seg_panoptic"] = data_panoptic["seg_panoptic"]
256 | data_merged["edge_panoptic"] = data_panoptic["edges"]
257 | data_merged["box_things"] = data_panoptic["box_things"]
258 | data_merged["seg_human"] = data_human["seg_human"]
259 | data_merged["edge_human"] = data_human["edges"]
260 | data_merged["seg_face"] = data_face["seg_face"]
261 | data_merged["box_face"] = data_face["box_face"]
262 | new_sample["npz"] = data_merged
263 |
264 | new_data.write(new_sample)
265 | print("Finished repacking", tarname)
266 | shutil.rmtree(os.path.join(self.preprocessed_folder, "untars", tarname))
267 | new_data.close()
268 | repacking_done.put("Finished repacking "+ tarname)
269 |
270 |
271 |
272 | def check_path(self, path):
273 | dirname = os.path.dirname(path)
274 | if not os.path.exists(dirname):
275 | try:
276 | os.makedirs(dirname)
277 | except FileExistsError:
278 | pass
279 |
--------------------------------------------------------------------------------
/Data/preprocessors/__init__.py:
--------------------------------------------------------------------------------
1 | from .detectron2_preprocessor import PanopticPreprocessor as Detectron2
2 | from .human_parts_preprocessor import HumanPartsPreprocessor as HumanParts
3 | from .face_alignment_preprocessor import FaceAlignmentPreprocessor as HumanFace
4 | from .edge_extractor import get_edges
5 |
--------------------------------------------------------------------------------
/Data/preprocessors/detectron2_preprocessor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import torch.nn.functional as F
5 | # from torchvision.ops import masks_to_boxes
6 | import numpy as np
7 | import detectron2
8 | from detectron2.projects.panoptic_deeplab import add_panoptic_deeplab_config
9 | from detectron2.engine import DefaultPredictor
10 | from detectron2.checkpoint import DetectionCheckpointer
11 | from detectron2.config import get_cfg
12 | from detectron2.modeling import build_model
13 | from .edge_extractor import get_edges
14 |
15 |
16 | def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
17 | if masks.numel() == 0:
18 | return torch.zeros((0, 4), device=masks.device, dtype=torch.float)
19 | n = masks.shape[0]
20 | bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float)
21 | for index, mask in enumerate(masks):
22 | y, x = torch.where(mask != 0)
23 |
24 | bounding_boxes[index, 0] = torch.min(x)
25 | bounding_boxes[index, 1] = torch.min(y)
26 | bounding_boxes[index, 2] = torch.max(x)
27 | bounding_boxes[index, 3] = torch.max(y)
28 |
29 | return bounding_boxes
30 |
31 |
32 | class Predictor:
33 | def __init__(self, cfg):
34 | self.cfg = cfg.clone()
35 | self.model = build_model(self.cfg)
36 | self.model.eval()
37 | checkpointer = DetectionCheckpointer(self.model)
38 | checkpointer.load(cfg.MODEL.WEIGHTS)
39 | self.input_format = cfg.INPUT.FORMAT
40 |
41 |     def __call__(self, imgs: torch.Tensor):
42 |         # imgs: float tensor of shape (B, C, H, W), values in 0-255
43 | with torch.no_grad():
44 | #imgs = torch.as_tensor(imgs.astype("float32"))
45 | if self.input_format == "RGB":
46 | # whether the model expects BGR inputs or RGB
47 | imgs = imgs.flip([1])
48 |
49 | height, width = imgs.shape[2:]
50 |
51 | inputs = [{"image": image, "height": height, "width": width} for image in imgs]
52 | predictions = self.model(inputs)
53 | return predictions
54 |
55 |
56 | class PanopticPreprocessor:
57 | proc_type = "panoptic"
58 | def __init__(
59 | self,
60 | config="/home/ubuntu/anaconda3/envs/schp/lib/python3.8/site-packages/detectron2/projects/panoptic_deeplab/configs/COCO-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_200k_bs64_crop_640_640_coco_dsconv.yaml",
61 | num_classes=133,
62 | model_weights="/home/ubuntu/anaconda3/envs/schp/lib/python3.8/site-packages/detectron2/projects/panoptic_deeplab/configs/COCO-PanopticSegmentation/model_final_5e6da2.pkl",
63 | device=None,
64 | ):
65 | self.num_classes = num_classes
66 |
67 | cfg = get_cfg()
68 | print(cfg.MODEL.DEVICE)
69 | opts = ["MODEL.WEIGHTS", model_weights]
70 | add_panoptic_deeplab_config(cfg)
71 | cfg.merge_from_file(config)
72 | cfg.merge_from_list(opts)
73 | if device is not None:
74 | cfg.MODEL.DEVICE = device
75 | cfg.freeze()
76 | print("Building model")
77 | self.predictor = Predictor(cfg)
78 |
79 | def bounding_boxes(self, panoptics):
80 | all_boxes = []
81 | for panoptic in panoptics:
82 | #panoptic = torch.Tensor(panoptic)
83 | obj_ids = torch.unique(panoptic)
84 | thing_ids = obj_ids[obj_ids/1000 < 80] # according to panopticapi first 80 classes are "things"
85 | binary_masks = panoptic == thing_ids[:, None, None]
86 | boxes = masks_to_boxes(binary_masks)
87 | all_boxes.append(boxes.cpu().numpy())
88 | return all_boxes
89 |
90 |     def __call__(self, imgs: torch.Tensor):
91 |         # imgs: float tensor of shape (B, C, H, W), values in 0-255
92 | data = {}
93 | # Returns tensor of shape [H, W] with values equal to 1000*class_id + instance_idx
94 | # panoptic = self.predictor(imgs)["panoptic_seg"][0].cpu()
95 | panoptic = self.predictor(imgs)
96 | panoptic = list(map(lambda pan: pan["panoptic_seg"][0], panoptic))
97 | bounding_boxes = self.bounding_boxes(panoptic)
98 | panoptic = list(map(lambda pan: pan.cpu().numpy(), panoptic))
99 | panoptic = np.array(panoptic)
100 | edges = get_edges(panoptic)
101 | data["seg_panoptic"] = np.array(panoptic // 1000, dtype=np.uint8)
102 | data["box_things"] = bounding_boxes
103 | data["edges"] = edges.astype(bool)
104 | return data
105 |
106 |
107 | if __name__ == "__main__":
108 | # img = cv2.imread("humans.jpg")
109 |     img = torch.randint(0, 255, (5, 3, 300, 300)).float()
110 |     detectron2 = PanopticPreprocessor()
111 | panoptic = detectron2(img)
112 | print(panoptic)
113 | # torch.save(panoptic, "test.pth")
114 |
--------------------------------------------------------------------------------
/Data/preprocessors/edge_extractor.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | THICKNESS = 1
5 |
6 |
7 | def get_edges(masks):
8 | #!face contours are not used
9 | all_edges = np.zeros(masks.shape)
10 | for i, mask in enumerate(masks):
11 | edges = np.zeros(masks[0].shape)
12 |         contours, _ = cv2.findContours(
13 |             mask.astype(np.int32), cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_SIMPLE  # RETR_FLOODFILL needs a 32-bit signed image
14 |         )
15 | edges = cv2.drawContours(edges, contours, -1, 1, THICKNESS)
16 | all_edges[i] = edges
17 | return all_edges
18 |
--------------------------------------------------------------------------------
/Data/preprocessors/face_alignment_preprocessor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 | import face_alignment
5 | import ssl
6 | import matplotlib.pyplot as plt
7 | ssl._create_default_https_context = ssl._create_unverified_context
8 | from time import time
9 |
10 | BEARD = 0
11 | BROW = 1
12 | NOSE = 2
13 | EYE = 3
14 | MOUTH = 4
15 |
16 |
17 | class FaceAlignmentPreprocessor:
18 | proc_type = "face"
19 | last_beard = 17
20 | last_brow = 27
21 | last_nose = 36
22 | last_eye = 48
23 | last_mouth = 68
24 |
25 | def __init__(self, n_classes=5, face_confidence=0.95, device="cuda"):
26 | self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, face_detector_kwargs={"filter_threshold":face_confidence}, device=device)
27 | self.face_confidence = face_confidence
28 | self.n_classes = n_classes
29 | self.class_idxs = {
30 | BEARD: (0, self.last_beard),
31 | BROW: (self.last_beard, self.last_brow),
32 | NOSE: (self.last_brow, self.last_nose),
33 | EYE: (self.last_nose, self.last_eye),
34 | MOUTH: (self.last_eye, self.last_mouth),
35 | }
36 | #self.class_idxs = {
37 | # BEARD: torch.arange(0, self.last_beard),
38 | # BROW: torch.arange(self.last_beard, self.last_brow),
39 | # NOSE: torch.arange(self.last_brow, self.last_nose),
40 | # EYE: torch.arange(self.last_nose, self.last_eye),
41 | # MOUTH: torch.arange(self.last_eye, self.last_mouth),
42 | #}
43 |
44 |
45 | def process_image(self, img):
46 |         img = img[:, :, ::-1]  # face_alignment works with BGR colorspace
47 | points = self.fa.get_landmarks(img)
48 | seg_mask = torch.zeros(*img.shape[:-1])
49 | if points is None:
50 | return seg_mask
51 | for face in points:
52 | face = face.astype(int)
53 | for class_id in range(self.n_classes):
54 |                 for point_id in range(*self.class_idxs[class_id]):
55 | try:
56 | seg_mask[face[point_id, 1], face[point_id, 0]] = class_id + 1
57 | except IndexError:
58 | # Probably only part of the face on the image
59 | pass
60 | return seg_mask # F.one_hot(seg_mask.to(torch.long), num_classes=6)[..., 1:].permute(2, 0, 1)
61 |
62 | def interpolate_face(self, face):
63 | interpolation = []
64 | for class_id in range(self.n_classes):
65 | part_interpolation = []
66 |             part = face[slice(*self.class_idxs[class_id])]
67 | for idx, (i, j) in enumerate(zip(part, part[1:])):
68 |                 if self.class_idxs[class_id][0] + idx in (21, 41):  # avoid connecting the two eyes (or the two brows)
69 | continue
70 | # print(self.class_idxs[class_id][idx])
71 | part_interpolation.extend(
72 | list(np.round(np.linspace(i, j, 100)).astype(np.int32)) +
73 | list(np.round(np.linspace(i, j, 100)).astype(np.int32) + [0, 1]) +
74 | list(np.round(np.linspace(i, j, 100)).astype(np.int32) + [0, -1]) +
75 | list(np.round(np.linspace(i, j, 100)).astype(np.int32) + [1, 0]) +
76 | list(np.round(np.linspace(i, j, 100)).astype(np.int32) + [-1, 0])
77 | )
78 | interpolation.append(part_interpolation)
79 | return interpolation
80 |     def process_image_interpolated(self, imgs: torch.Tensor):
81 |         # imgs: float tensor of shape (B, C, H, W), values in 0-255
82 | imgs = imgs.flip([1]) # face_alignment works with BGR colorspace
83 | faces = self.fa.face_detector.detect_from_batch(imgs)
84 | # faces = list(filter(lambda face: face[-1] > self.face_confidence, faces))
85 | faces = list(map(lambda img: list(filter(lambda face: face[-1] > self.face_confidence, img)), faces))
86 | batched_points = self.fa.get_landmarks_from_batch(imgs, detected_faces=faces)
87 | seg_mask = np.zeros((imgs.shape[0], *imgs.shape[2:]))
88 | if batched_points is None:
89 | return seg_mask
90 | for i, points in enumerate(batched_points):
91 | for face in points:
92 | face = self.interpolate_face(face.astype(int))
93 | for class_id in range(self.n_classes):
94 | for point in face[class_id]:
95 | try:
96 | seg_mask[i, point[1], point[0]] = class_id + 1
97 | except IndexError as e:
98 | # Probably only part of the face on the image
99 | pass
100 | boxes = [[face[:-1] for face in faces_in_image] for faces_in_image in faces]
101 | return seg_mask, boxes # F.one_hot(seg_mask.to(torch.long), num_classes=6)[..., 1:].permute(2, 0, 1)
102 |
103 | def draw_interpolated_face(self,mask, face):
104 | for class_id in range(self.n_classes):
105 | start, stop = self.class_idxs[class_id]
106 | if class_id not in (EYE, BROW):
107 | cv2.drawContours(mask, [face[start: stop]], 0 , (class_id+1), 1)
108 | else:
109 | step = (stop-start)//2
110 | cv2.drawContours(mask, [face[start:start+step]], 0, (class_id+1), 1)
111 | cv2.drawContours(mask, [face[start+step: stop]], 0, (class_id+1), 1)
112 | return mask
113 |
114 |     def process_image_interpolated_fast(self, imgs: torch.Tensor):
115 |         # imgs: float tensor of shape (B, C, H, W), values in 0-255
116 | t = time()
117 | times = []
118 | imgs = imgs.flip([1]) # face_alignment works with BGR colorspace
119 | faces = self.fa.face_detector.detect_from_batch(imgs)
120 | times.append([time()-t])
121 | t = time()
122 | batched_points = self.fa.get_landmarks_from_batch(imgs, detected_faces=faces)
123 | seg_mask = np.zeros((imgs.shape[0], *imgs.shape[2:]))
124 | times.append([time()-t])
125 | t = time()
126 | for image_idx, points in enumerate(batched_points):
127 | for face in points:
128 | seg_mask[image_idx] = self.draw_interpolated_face(seg_mask[image_idx], face.astype(int))
129 | boxes = [[face[:-1] for face in faces_in_image] for faces_in_image in faces]
130 | times.append([time()-t])
131 | t = time()
132 | #print("HERE!!!", times)
133 | return seg_mask, boxes # F.one_hot(seg_mask.to(torch.long), num_classes=6)[..., 1:].permute(2, 0, 1)
134 |
135 | def process_image_interpolated_old(self, img):
136 |         img = img[:, :, ::-1]  # face_alignment works with BGR colorspace
137 | faces = self.fa.face_detector.detect_from_image(img.copy())
138 | faces = list(filter(lambda face: face[-1] > self.face_confidence, faces))
139 | points = self.fa.get_landmarks(img, detected_faces=faces)
140 | seg_mask = np.zeros(img.shape[:-1])
141 | if points is None:
142 | return seg_mask
143 | for face in points:
144 | face = self.interpolate_face(face.astype(int))
145 | for class_id in range(self.n_classes):
146 | for point in face[class_id]:
147 | try:
148 | seg_mask[point[1], point[0]] = class_id + 1
149 | except IndexError as e:
150 | # Probably only part of the face on the image
151 | pass
152 | boxes = [face[:-1] for face in faces]
153 | return seg_mask, boxes # F.one_hot(seg_mask.to(torch.long), num_classes=6)[..., 1:].permute(2, 0, 1)
154 |
155 | def plot_face(self, seg_mask: torch.Tensor):
156 | plt.imshow(seg_mask.clamp(0, 1).detach().cpu().numpy(), cmap="gray")
157 | plt.show()
158 |
159 | def __call__(self, img):
160 | data = {}
161 | mask, boxes = self.process_image_interpolated_fast(img)
162 | # mask, boxes = self.process_image_interpolated_old(img)
163 | data["seg_face"] = mask.astype(np.uint8)
164 | data["box_face"] = boxes
165 | return data
166 |
167 |
168 | if __name__ == "__main__":
169 | face_alignment_preprocessor = FaceAlignmentPreprocessor()
170 |     img = cv2.imread("humans.jpg")  # cv2 loads images as BGR with shape (H, W, 3)
171 |     print(img.shape)
172 |     img = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0).repeat(5, 1, 1, 1)
173 | data = face_alignment_preprocessor(img)
174 | #face_alignment_preprocessor.plot_face(alignment)
175 | print(data["seg_face"].shape)
176 | print(len(data["box_face"]))
177 | print(data["box_face"])
178 | # torch.save(alignment, "alignment.pth")
179 |
--------------------------------------------------------------------------------
/Data/preprocessors/human_parts_preprocessor.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import cv2
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | import torchvision.transforms as transforms
7 | from torchgeometry import warp_affine
8 |
9 | HUMAN_PARSER_DIR = "/home/ubuntu/MakeAScene/Self-Correction-Human-Parsing"
10 | sys.path.append(HUMAN_PARSER_DIR)
11 | import networks
12 | from simple_extractor import get_palette, dataset_settings
13 | from collections import OrderedDict
14 | from utils.transforms import get_affine_transform
15 | from .edge_extractor import get_edges
16 | import torchvision
17 |
18 |
19 | def transform_logits(logits, center, scale, width, height, input_size):
20 | trans = torch.Tensor(get_affine_transform(center, scale, 0, input_size, inv=1)).expand(logits.shape[0], -1, -1)
21 | target_logits = warp_affine(logits, trans, (int(width), int(height)))
22 | return target_logits
23 |
24 |
25 | class HumanPartsPreprocessor:
26 | proc_type = "human"
27 | def __init__(
28 | self,
29 | weights=HUMAN_PARSER_DIR + "/checkpoints/final.pth",
30 | device="cuda",
31 | ):
32 | self.device = device
33 |
34 | dataset_info = dataset_settings["lip"]
35 | self.num_classes = dataset_info["num_classes"]
36 | self.input_size = dataset_info["input_size"]
37 | self.model = networks.init_model(
38 | "resnet101", num_classes=self.num_classes, pretrained=None
39 | )
40 | self.aspect_ratio = self.input_size[1] * 1.0 / self.input_size[0]
41 |
42 | state_dict = torch.load(weights)["state_dict"]
43 | new_state_dict = OrderedDict()
44 | for k, v in state_dict.items():
45 | name = k[7:] # remove `module.`
46 | new_state_dict[name] = v
47 | self.model.load_state_dict(new_state_dict)
48 | self.model.eval()
49 | self.model.to(device)
50 |
51 | self.transform = transforms.Compose(
52 | [
53 | # transforms.ToTensor(),
54 | transforms.Normalize(
55 | mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]
56 | ),
57 | ]
58 | )
59 | self.upsample = torch.nn.Upsample(
60 | size=self.input_size, mode="bilinear", align_corners=True
61 | )
62 |
63 | def _box2cs(self, box):
64 | x, y, w, h = box[:4]
65 | return self._xywh2cs(x, y, w, h)
66 |
67 | def _xywh2cs(self, x, y, w, h):
68 | center = np.zeros((2), dtype=np.float32)
69 | center[0] = x + w * 0.5
70 | center[1] = y + h * 0.5
71 | if w > self.aspect_ratio * h:
72 | h = w * 1.0 / self.aspect_ratio
73 | elif w < self.aspect_ratio * h:
74 | w = h * self.aspect_ratio
75 | scale = np.array([w, h], dtype=np.float32)
76 | return center, scale
77 |
78 |     def segment_image(self, imgs: torch.Tensor):
79 |         # imgs: float tensor of shape (B, 3, H, W), values in 0-255
80 | #imgs = torch.Tensor(imgs)
81 | b, _, h, w = imgs.shape # we can assume b x 3 x 512 x 512
82 | # check if h, w in correct order
83 | # Get person center and scale
84 | #person_center, scale = self._box2cs([0, 0, w - 1, h - 1])
85 | #c = person_center
86 | #s = scale
87 | #r = 0
88 | #trans = torch.Tensor(get_affine_transform(person_center, s, r, self.input_size)).expand(b, -1, -1)
89 | #imgs = warp_affine(imgs, trans, (int(self.input_size[1]), int(self.input_size[0])))
90 | imgs = torchvision.transforms.functional.resize(imgs, [self.input_size[1], self.input_size[0]])
91 |
92 | imgs = self.transform(imgs/255.).to(self.device)
93 |
94 | with torch.no_grad():
95 | output = self.model(imgs)
96 | upsample_output = self.upsample(output[0][-1]) # reshapes from 1, 20, 119, 119 to 1, 20, 473, 473
97 |
98 | #logits_result = transform_logits(upsample_output.cpu(), c, s, w, h, input_size=self.input_size)
99 | logits_result = torchvision.transforms.functional.resize(upsample_output, [h, w])
100 | mask = logits_result.argmax(dim=1)
101 | return mask.cpu().numpy()
102 |
103 | def __call__(self, imgs):
104 | data = {}
105 | mask = self.segment_image(imgs)
106 | edges = get_edges(mask)
107 | data["seg_human"] = mask.astype(np.uint8)
108 | data["edges"] = edges.astype(bool)
109 | return data
110 |
111 |
112 | if __name__ == "__main__":
113 | human_processor = HumanPartsPreprocessor()
114 | img = cv2.imread("humans.jpg")
115 |     img = torch.randint_like(torch.Tensor(img), 0, 255, dtype=torch.float).permute(2, 0, 1).expand(5, -1, -1, -1)
116 | masks = human_processor(img)
117 | print(masks)
118 |
--------------------------------------------------------------------------------
/Data/utils.py:
--------------------------------------------------------------------------------
1 |
2 | def check_bbox(bbox):
3 |     """Clip a bbox to the 512x512 canvas and drop degenerate or tiny boxes."""
4 |     x_min, y_min, x_max, y_max = bbox[:4]
5 |     bbox[0] = x_min = max(x_min, 0)
6 |     bbox[1] = y_min = max(y_min, 0)
7 |     bbox[2] = x_max = min(x_max, 511)
8 |     bbox[3] = y_max = min(y_max, 511)
9 |     if x_max <= x_min:
10 |         return None
11 |     if y_max <= y_min:
12 |         return None
13 |     if y_max - y_min < 16 or x_max - x_min < 16:
14 |         # drop boxes narrower than 16 px in either dimension
15 |         return None
16 |
17 |     return bbox
18 |
19 |
20 | def check_bboxes(bboxes):
21 |     """Clip bboxes to the canvas and keep only the valid ones."""
22 |     _bboxes = []
23 |     for bbox in bboxes:
24 |         check = check_bbox(bbox)
25 |         if check is not None:
26 |             _bboxes.append(check)
27 |         # else:
28 |         #     print(f"Removing bbox: {bbox}")
29 |     return _bboxes
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Casual GAN Papers
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Make-A-Scene - PyTorch
2 | Pytorch implementation (unofficial) of Make-A-Scene: Scene-Based Text-to-Image Generation with Human Priors
3 |
4 |
5 |
6 | Figure 1. from paper
7 |
8 |
9 | ## Note: this is work in progress.
10 | We are at the training stage! Progress can be followed in the Discord channel on the LAION Discord: https://discord.gg/DghvZDKu.
11 | Data preprocessing is finished and VQ-SEG has been trained; we are currently training VQ-IMG. Training checkpoints will be released soon, together with demos.
12 | The transformer implementation is in progress and will hopefully start training as soon as VQ-IMG finishes.
13 |
14 | ## Demo
15 | VQ-IMG: https://colab.research.google.com/drive/1SPyQ-epTsAOAu8BEohUokN4-b5RM_TnE?usp=sharing
16 |
17 | ## Paper Description
18 | Make-A-Scene modifies the VQGAN framework. It makes heavy use of semantic segmentation maps as extra conditioning, which gives more control over the generation process, and it additionally conditions on text. The main improvements are the following (an illustrative sketch of the weighted segmentation loss is given at the end of this README):
19 | 1. Segmentation conditioning: a separate VQ-VAE (VQ-SEG) is trained for segmentation maps, with its loss modified to a weighted binary cross-entropy (3.4)
20 | 2. VQGAN training (VQ-IMG) is extended with a face loss and an object loss (3.3 & 3.5)
21 | 3. Classifier guidance for the autoregressive transformer (3.7)
22 |
23 | ## Training Pipeline
24 |
25 |
26 | Figure 6. from paper
27 |
28 |
29 | ## What needs to be done?
30 | Refer to the different folders to see details.
31 | - [X] [VQ-SEG](https://github.com/CasualGANPapers/Make-A-Scene/tree/main/VQ-SEG)
32 | - [X] [VQ-IMG](https://github.com/CasualGANPapers/Make-A-Scene/tree/main/VQ-IMG)
33 | - [ ] [Transformer]()
34 | - [X] [Data Aggregation](https://github.com/CasualGANPapers/Make-A-Scene/tree/main/Data)
35 |
36 | ## Citation
37 | ```bibtex
38 | @misc{https://doi.org/10.48550/arxiv.2203.13131,
39 | doi = {10.48550/ARXIV.2203.13131},
40 | url = {https://arxiv.org/abs/2203.13131},
41 | author = {Gafni, Oran and Polyak, Adam and Ashual, Oron and Sheynin, Shelly and Parikh, Devi and Taigman, Yaniv},
42 | title = {Make-A-Scene: Scene-Based Text-to-Image Generation with Human Priors},
43 | publisher = {arXiv},
44 | year = {2022},
45 | copyright = {arXiv.org perpetual, non-exclusive license}
46 | }
47 | ```
48 |
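49 | ## Weighted segmentation loss (illustrative sketch)
50 | A minimal sketch of the weighted binary cross-entropy used to train VQ-SEG
51 | (improvement 1 above). This is not the implementation in `losses/loss_seg.py`;
52 | the channel layout follows `Data/README.md`, and the function name and weight
53 | values are illustrative assumptions:
54 |
55 | ```python
56 | import torch
57 | import torch.nn.functional as F
58 |
59 | def weighted_seg_bce(logits, target, face_weight=5.0, edge_weight=2.0):
60 |     # logits, target: (B, 159, H, W); channels 0-132 panoptic, 133-152 human
61 |     # parts, 153-157 face, 158 edges. Face and edge channels are up-weighted.
62 |     weights = torch.ones(logits.shape[1], device=logits.device)
63 |     weights[153:158] = face_weight  # emphasize the 5 face-part channels
64 |     weights[158] = edge_weight      # emphasize the edge channel
65 |     loss = F.binary_cross_entropy_with_logits(logits, target.float(), reduction="none")
66 |     return (loss * weights.view(1, -1, 1, 1)).mean()
67 | ```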
--------------------------------------------------------------------------------
/conf/img_config.yaml:
--------------------------------------------------------------------------------
1 | mode: pretrain_image
2 | devices:
3 | - 0
4 | - 1
5 | - 2
6 | - 3
7 | - 4
8 | - 5
9 | - 6
10 | - 7
11 |
12 | total_steps: 800000
13 | accumulate_grad: 8
14 | resume: False
15 | checkpoint: ./outputs/2022-04-03/19-00-36/checkpoint.pt
16 | log_period: 50
17 | batch_size: 2 # 192 for 256 model and 128 for 512 model
18 |
19 | model:
20 | _target_: models.VQBASE
21 | embed_dim: 256
22 | n_embed: 8192
23 | init_steps: 3000
24 | reservoir_size: 12500 # 2e5 / 8
25 | ddconfig:
26 | z_channels: 256
27 | in_channels: 3
28 | out_channels: 3
29 | channels: [128, 128, 128, 256, 512, 512] # [1, 1, 2, 4, 4]
30 | num_res_blocks: 2
31 | resolution: 512
32 | attn_resolutions:
33 | - 32
34 | dropout: 0.0
35 |
36 | optimizer:
37 | vq:
38 | lr: 5e-6
39 | betas:
40 | - 0.5
41 | - 0.9
42 | disc:
43 | lr: 4.5e-6
44 | betas:
45 | - 0.5
46 | - 0.9
47 |
48 | dataset:
49 | _target_: Data.dataset_preprocessor_web.S3ProcessedDataset
50 | resampled: True
51 | names:
52 | - cc3m
53 | - cc12m
54 | # path: file:D:/PycharmProjects/Make-A-Scene/server/Make-A-Scene/dataset/coco/{00000..00004}.tar
55 | # path: file:D:/PycharmProjects/Make-A-Scene/server/Make-A-Scene/dataset/coco/great_dataset.tar
56 |
57 | loss:
58 | #_target_: losses.VQVAEWithBCELoss
59 | _target_: losses.loss_img.VQLPIPSWithDiscriminator
60 | disc_start: 250001
61 | disc_weight: 0.8
62 | codebook_weight: 1.0
63 |
64 | dataloader:
65 | batch_size: ${batch_size}
66 | num_workers: 8
67 | pin_memory: True
68 |
69 | hydra:
70 | job:
71 | chdir: True
72 | run:
73 | dir: ./outputs/${mode}/${now:%Y-%m-%d}/${now:%H-%M-%S}
74 |
--------------------------------------------------------------------------------
/conf/preprocess_data.yaml:
--------------------------------------------------------------------------------
1 | mode: preprocess_dataset
2 | root: "/home/silent/hdd/nets/db/mydb"
3 | preprocessed: "/home/silent/hdd/nets/db/mydb/preprocessed"
4 |
5 | preprocessor:
6 | _target_: Data.BasePreprocessor
7 | preprocessed_folder: ${preprocessed}
8 | devices:
9 | - 0
10 | proc_per_gpu:
11 | panoptic: 0
12 | human: 0
13 | face: 0
14 | proc_per_cpu:
15 | panoptic: 1
16 | human: 1
17 | face: 1
18 |
19 | dataset:
20 | _target_: Data.dataset_preprocessor.ConcatDataset
21 | datasets:
22 | - ${coco2014}
23 | - ${coco2017}
24 |
25 | coco2014:
26 | _target_: Data.dataset_preprocessor.COCO2014Dataset
27 | root: ${root}
28 | preprocessed_folder: ${preprocessed}
29 |
30 | coco2017:
31 | _target_: Data.dataset_preprocessor.COCO2017Dataset
32 | root: ${root}
33 | preprocessed_folder: ${preprocessed}
34 |
35 |
36 | hydra:
37 | run:
38 | dir: ./outputs/${mode}/${now:%Y-%m-%d}/${now:%H-%M-%S}
39 |
--------------------------------------------------------------------------------
/conf/preprocess_data_web.yaml:
--------------------------------------------------------------------------------
1 | mode: preprocess_dataset
2 | #root: "/home/ubuntu/MakeAScene/data/coco/mscoco"
3 | root: "pipe:aws s3 cp s3://s-mas/laion-high-resolution/{00000..01500}.tar -"
4 | preprocessed: "/home/ubuntu/data/laion_en_tmp/"
5 | output_folder: "laion_en_processed"
6 |
7 | preprocessor:
8 | _target_: Data.WebPreprocessor
9 | preprocessed_folder: ${preprocessed}
10 | output_folder: ${output_folder}
11 | batch_size: 32
12 | num_workers: 2
13 | machines_total: 2
14 | machine_idx: 0
15 | devices:
16 | - 9
17 | proc_per_gpu:
18 | panoptic:
19 | - 4
20 | - 5
21 | - 6
22 | - 7
23 | human:
24 | - 2
25 | - 3
26 | face:
27 | - 0
28 | - 1
29 | proc_per_cpu:
30 | panoptic: 0
31 | human: 0
32 | face: 0
33 |
34 | dataset:
35 | _target_: Data.dataset_preprocessor_web.UnprocessedWebDataset
36 | root: ${root}
37 |
38 |
39 | hydra:
40 | run:
41 | dir: ./outputs/${mode}/${now:%Y-%m-%d}/${now:%H-%M-%S}
42 |
--------------------------------------------------------------------------------
/conf/seg_config.yaml:
--------------------------------------------------------------------------------
1 | mode: pretrain_segmentation
2 | devices:
3 | - 0
4 |
5 | total_steps: 6000000
6 | accumulate_grad: 3
7 | resume: False
8 | checkpoint: ./outputs/2022-04-03/19-00-36/checkpoint.pt
9 | log_period: 50
10 | batch_size: 2
11 |
12 | model:
13 | _target_: models.VQBASE
14 | embed_dim: 256
15 | n_embed: 1024
16 | ddconfig:
17 | double_z: false
18 | z_channels: 256
19 | resolution: 256
20 | in_channels: 159
21 | out_ch: 159
22 | ch: 128
23 | ch_mult:
24 | - 1
25 | - 1
26 | - 2
27 | - 2
28 | - 4
29 | num_res_blocks: 2
30 | attn_resolutions:
31 | - 16
32 | dropout: 0.0
33 |
34 | optimizer:
35 | lr: 4.5e-6
36 | betas:
37 | - 0.5
38 | - 0.9
39 |
40 | dataset:
41 | _target_: Data.dataset_preprocessor.COCO2014Dataset
42 | root: "D:\\PycharmProjects\\Make-A-Scene\\Data\\coco\\tmpdb_2\\"
43 | preprocessed_folder: "D:\\PycharmProjects\\Make-A-Scene\\Data\\coco\\tmpdb_2\\preprocessed_folder"
44 | force_preprocessing: False
45 |
46 | loss:
47 | #_target_: losses.VQVAEWithBCELoss
48 | _target_: losses.BCELossWithQuant
49 |
50 | dataloader:
51 | batch_size: 2
52 | num_workers: 8
53 | shuffle: True
54 | pin_memory: True
55 |
56 | hydra:
57 | run:
58 | dir: ./outputs/${mode}/${now:%Y-%m-%d}/${now:%H-%M-%S}
59 |
--------------------------------------------------------------------------------
/conf/show.yaml:
--------------------------------------------------------------------------------
1 | mode: show_segmentation
2 | devices:
3 | - 0
4 |
5 | checkpoint: ./outputs/pretrain_segmentation/2022-04-04/16-40-10/checkpoint.pt
6 |
7 | model:
8 | _target_: models.VQBASE
9 | embed_dim: 256
10 | n_embed: 1024
11 | image_key: "segmentation"
12 | #n_labels: 182
13 | ddconfig:
14 | double_z: false
15 | z_channels: 256
16 | resolution: 256
17 | in_channels: 159
18 | out_ch: 159
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | num_res_blocks: 2
27 | attn_resolutions:
28 | - 16
29 | dropout: 0.0
30 |
31 |
32 | dataset:
33 | _target_: Data.dataset_preprocessor.COCO2014Dataset
34 | root: "/path_to_coco"
35 | preprocessed_folder: "/path_to_preprocessed_folder"
36 | force_preprocessing: False
37 |
38 | hydra:
39 | run:
40 | dir: ./outputs/${mode}/${now:%Y-%m-%d}/${now:%H-%M-%S}
41 |
--------------------------------------------------------------------------------
/log_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.utils.tensorboard import SummaryWriter
5 | import torchvision.utils as vutils
6 | import os
7 |
8 |
9 | class Logger:
10 | def __init__(self, proc_id, log_dir=".", device="cuda"):
11 | self.proc_id = proc_id
12 | if proc_id != 0:
13 | return
14 | self.writer = SummaryWriter(log_dir)
15 | self.step = 0
16 |         os.makedirs("./results", exist_ok=True)
17 |
18 | def log(self, step=None, img=None, img_rec=None, **kwargs):
19 | if self.proc_id != 0:
20 | return
21 | self.step = step if step is not None else self.step + 1
22 | for key in kwargs:
23 | self.writer.add_scalar(key, kwargs[key].detach().cpu().item(), self.step)
24 | if img is not None and img_rec is not None and self.step%500==0:
25 | img = img.detach().cpu()
26 | img_rec = img_rec.detach().cpu()
27 | pairs = torch.cat([img,img_rec]).detach().cpu()
28 | img_grid = vutils.make_grid(pairs)
29 |             self.writer.add_image('samples', img_grid.detach().cpu(), global_step=self.step)
30 |
31 |
32 | class Visualizer:
33 | dims = {
34 | "panoptic": [0, 133],
35 | "human": [133, 153],
36 | "face": [153, 158],
37 | "edge": [158, 159]
38 | }
39 |
40 | def __init__(self, log_dir=".", device="cuda"):
41 | self.weights = {}
42 | for key in self.dims:
43 | size = self.dims[key][1] - self.dims[key][0]
44 | weight = torch.randn([3, size, 1, 1]).to(device)
45 | self.weights[key] = weight
46 |         os.makedirs("./results", exist_ok=True)
47 |
48 |     def log_images(self, seg, seg_rec, step=0):
49 |         seg = self.colorize(seg)
50 |         seg_rec = self.colorize(seg_rec, logits=True)
51 |         both = torch.cat(seg + seg_rec)  # colorize returns a list of RGB panels
52 |         grid = vutils.make_grid(both, nrow=2)
53 |         vutils.save_image(grid, f"./results/{step}.jpg")
54 |
55 | def colorize(self, seg, logits=False):
56 | results = []
57 | for key in self.dims:
58 | seg_key = seg[:, self.dims[key][0]: self.dims[key][1]]
59 | if logits:
60 | n_classes = seg_key.shape[1]
61 | if "face" in key or "edge" in key:
62 | mask = seg_key.sigmoid() > 0.2
63 | seg_key = torch.argmax(seg_key, dim=1, keepdim=False)
64 | seg_key = F.one_hot(seg_key, num_classes=n_classes)
65 | seg_key = seg_key.permute(0, 3, 1, 2).float()
66 | if "face" in key or "edge" in key:
67 | seg_key *= mask
68 |
69 | weight = self.weights[key]
70 | with torch.no_grad():
71 | x = F.conv2d(seg_key, weight)
72 | x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
73 | results.append(x)
74 | return results
75 |         # NOTE: unreachable legacy code below (kept for reference); `x` would be undefined here
76 |         # seg = seg[:, :133]
77 |         # if logits:
78 |         #     # seg = (seg.sigmoid()>0.35).to(torch.float)
79 |         #     seg = torch.argmax(seg, dim=1, keepdim=False)
80 |         #     seg = F.one_hot(seg, num_classes=133)
81 |         #     seg = seg.squeeze(1).permute(0, 3, 1, 2).float()
82 |         # return x
83 |
84 | def __call__(self, step, image=None, seg=None, seg_rec=None):
85 | results = [image]
86 | if seg is not None:
87 | results.extend(self.colorize(seg))
88 | results.append(torch.zeros_like(image))
89 | if seg_rec is not None:
90 | results.extend(self.colorize(seg_rec, logits=True))
91 | results = torch.cat(results)
92 | vutils.save_image(results, f"./results/result_{step}.jpg", nrow=5)
93 |
--------------------------------------------------------------------------------
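The Visualizer's channel groups mirror the layout of the 159-channel segmentation map used throughout the repo (133 panoptic + 20 human-part + 5 face + 1 edge channels = 159, matching in_channels: 159 in seg_config.yaml). A small sketch of how colorize projects each group to an RGB panel via a fixed random 1x1 convolution (tensor shapes here are illustrative only):

    import torch
    from log_utils import Visualizer

    vis = Visualizer(device="cpu")
    seg = torch.rand(2, 159, 64, 64)         # fake batch of segmentation maps
    panels = vis.colorize(seg)               # one RGB panel per channel group
    print([tuple(p.shape) for p in panels])  # [(2, 3, 64, 64)] * 4
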
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .loss_seg import BCELossWithQuant, VQVAEWithBCELoss
2 | from .loss_img import VQLPIPSWithDiscriminator
3 |
--------------------------------------------------------------------------------
/losses/discriminator.py:
--------------------------------------------------------------------------------
1 | """
2 | PatchGAN Discriminator (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538)
3 | """
4 |
5 | import torch.nn as nn
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find('Conv') != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find('BatchNorm') != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class Discriminator(nn.Module):
18 | def __init__(self, in_channels=3, num_filters_last=64, n_layers=3):
19 | super(Discriminator, self).__init__()
20 |
21 | layers = [nn.Conv2d(in_channels, num_filters_last, 4, 2, 1), nn.LeakyReLU(0.2)]
22 | num_filters_mult = 1
23 |
24 | for i in range(1, n_layers + 1):
25 | num_filters_mult_last = num_filters_mult
26 | num_filters_mult = min(2 ** i, 8)
27 | layers += [
28 | nn.Conv2d(num_filters_last * num_filters_mult_last, num_filters_last * num_filters_mult, 4,
29 | 2 if i < n_layers else 1, 1, bias=False),
30 | nn.BatchNorm2d(num_filters_last * num_filters_mult),
31 | nn.LeakyReLU(0.2, True)
32 | ]
33 |
34 | layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1))
35 | self.model = nn.Sequential(*layers)
36 |
37 | def forward(self, x):
38 | return self.model(x)
--------------------------------------------------------------------------------
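The discriminator is a PatchGAN: instead of a single real/fake score it returns a grid of logits, one per overlapping image patch, and hinge_d_loss / the generator term in loss_img.py simply average over that grid. A quick shape check (input size chosen for illustration):

    import torch
    from losses.discriminator import Discriminator

    disc = Discriminator(in_channels=3, num_filters_last=64, n_layers=3)
    logits = disc(torch.randn(1, 3, 256, 256))
    print(logits.shape)  # torch.Size([1, 1, 30, 30]) -- one logit per patch
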
/losses/face_loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Code taken from and modified https://github.com/cydonia999/VGGFace2-pytorch
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import transforms
8 | from torchvision.transforms.functional import crop
9 |
10 | __all__ = ['FaceLoss']
11 |
12 |
13 | def conv3x3(in_planes, out_planes, stride=1):
14 | """3x3 convolution with padding"""
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=1, bias=False)
17 |
18 |
19 | class Bottleneck(nn.Module):
20 | expansion = 4
21 |
22 | def __init__(self, inplanes, planes, stride=1, downsample=None):
23 | super(Bottleneck, self).__init__()
24 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
25 | self.bn1 = nn.BatchNorm2d(planes)
26 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
27 | self.bn2 = nn.BatchNorm2d(planes)
28 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
29 | self.bn3 = nn.BatchNorm2d(planes * 4)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.downsample = downsample
32 | self.stride = stride
33 |
34 | def forward(self, x):
35 | residual = x
36 |
37 | out = self.conv1(x)
38 | out = self.bn1(out)
39 | out = self.relu(out)
40 |
41 | out = self.conv2(out)
42 | out = self.bn2(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv3(out)
46 | out = self.bn3(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | out = self.relu(out)
53 |
54 | return out
55 |
56 |
57 | class FaceLoss(nn.Module):
58 | def __init__(self):
59 | super(FaceLoss, self).__init__()
60 | layers = [3, 4, 6, 3]
61 | self.inplanes = 64
62 | self.alphas = [0.1, 0.25 * 0.01, 0.25 * 0.1, 0.25 * 0.2, 0.25 * 0.02]
63 |
64 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
65 | self.bn1 = nn.BatchNorm2d(64)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
68 |
69 | self.layer1 = self._make_layer(Bottleneck, 64, layers[0])
70 | self.layer2 = self._make_layer(Bottleneck, 128, layers[1], stride=2)
71 | self.layer3 = self._make_layer(Bottleneck, 256, layers[2], stride=2)
72 | self.layer4 = self._make_layer(Bottleneck, 512, layers[3], stride=2)
73 |
74 | self.channels = [64, 256, 512, 1024, 2048]
75 |
76 | self.load_state_dict(torch.load("/home/ubuntu/Make-A-Scene/losses/face_loss_weights.pt", map_location="cpu"), strict=False)
77 |
78 | for param in self.parameters():
79 | param.requires_grad = False
80 |
81 | self.face_transforms = nn.Sequential(
82 | transforms.Resize(256),
83 | transforms.CenterCrop(254)
84 | )
85 | self.eval()
86 |
87 | def _make_layer(self, block, planes, blocks, stride=1):
88 | downsample = None
89 | if stride != 1 or self.inplanes != planes * block.expansion:
90 | downsample = nn.Sequential(
91 | nn.Conv2d(self.inplanes, planes * block.expansion,
92 | kernel_size=1, stride=stride, bias=False),
93 | nn.BatchNorm2d(planes * block.expansion),
94 | )
95 |
96 | layers = []
97 | layers.append(block(self.inplanes, planes, stride, downsample))
98 | self.inplanes = planes * block.expansion
99 | for i in range(1, blocks):
100 | layers.append(block(self.inplanes, planes))
101 |
102 | return nn.Sequential(*layers)
103 |
104 | def _forward(self, x): # 224x224
105 | features = []
106 | x = self.conv1(x) # 112x112
107 | features.append(x)
108 | x = self.bn1(x)
109 | x = self.relu(x)
110 | x = self.maxpool(x) # 56x56 ignore
111 |
112 | x = self.layer1(x) # 56x56
113 | features.append(x)
114 | x = self.layer2(x) # 28x28
115 | features.append(x)
116 | x = self.layer3(x) # 14x14 ignore (maybe not)
117 | features.append(x)
118 | x = self.layer4(x) # 7x7
119 | features.append(x)
120 |
121 | return features
122 |
123 | def forward(self, img, rec, bbox):
124 | """
125 |         Feeds the original and reconstructed images through the face network, takes the difference
126 |         between the feature maps at each resolution, and scales each term by alpha_{i}.
127 |         Feature normalization and spatial averaging were adopted from LPIPS and are not mentioned in the paper.
128 | """
129 | faces = self.prepare_faces(img, rec, bbox)
130 | if faces is None:
131 | return img.new_tensor(0)
132 | faces = faces[:6] # otherwise it may cause oom error.
133 | features = self._forward(faces)
134 | features = [f.chunk(2) for f in features]
135 | # diffs = [a * torch.abs(p[0] - p[1]).sum() for a, p in zip(self.alphas, features)]
136 | diffs = [a * torch.abs(p[0] - p[1]).sum(dim=0).mean() for a, p in zip(self.alphas, features)]
137 | # diffs = [a*torch.abs(self.norm_tensor(tf) - self.norm_tensor(rf)) for a, tf, rf in zip(self.alphas, true_features, rec_features)]
138 |
139 | # diffs = [a * torch.mean(torch.abs(tf - rf)) for a, tf, rf in zip(self.alphas, features)]
140 | return sum(diffs)
141 | # return sum(diffs) / len(diffs)
142 |
143 | def prepare_faces(self, imgs, recs, bboxes):
144 | faces_gt = []
145 | faces_gen = []
146 |         for img, rec, img_bboxes in zip(imgs, recs, bboxes):
147 |             for bbox in img_bboxes:
148 | top = bbox[1]
149 | left = bbox[0]
150 | height = bbox[3] - bbox[1]
151 | width = bbox[2] - bbox[0]
152 | crop_img = crop(img, top, left, height, width)
153 | faces_gt.append(self.face_transforms(crop_img))
154 | crop_rec = crop(rec, top, left, height, width)
155 | faces_gen.append(self.face_transforms(crop_rec))
156 | if len(faces_gt) == 0:
157 | return None
158 | faces_gt = torch.stack(faces_gt, dim=0)
159 | faces_gen = torch.stack(faces_gen, dim=0)
160 | return torch.cat([faces_gt, faces_gen], dim=0)
161 |
162 |
163 |
164 | if __name__ == '__main__':
165 | model = FaceLoss()
166 | # x = torch.randn(1, 3, 256, 256)
167 | # x_rec = torch.randn(1, 3, 256, 256)
168 | x = torch.randn(2, 3, 101, 101)
169 | x_rec = torch.randn(2, 3, 101, 101)
170 |     bbox = [[[0, 0, 101, 101]], [[0, 0, 101, 101]]]  # one full-image face box per sample
171 |     print(model.forward(x, x_rec, bbox))
172 |
--------------------------------------------------------------------------------
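Ignoring the commented-out variants, the quantity returned by FaceLoss.forward is a weighted sum of absolute feature differences between the ground-truth and reconstructed face crops; in the notation of the code:

    \mathcal{L}_{face} = \sum_{i=0}^{4} \alpha_i \, \operatorname{mean}_{c,h,w}\!\Big( \sum_{b} \big| f_i(\text{faces\_gt})_b - f_i(\text{faces\_gen})_b \big| \Big)

where f_0..f_4 are the conv1 and layer1..layer4 activations of the frozen VGGFace2 ResNet-50 trunk and alpha = [0.1, 0.0025, 0.025, 0.05, 0.005].
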
/losses/loss_img.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .lpips import LPIPS
5 | from .lpips_with_object import LPIPSWithObject
6 | from .discriminator import Discriminator, weights_init
7 | from .face_loss import FaceLoss
8 | from torchvision.transforms.functional import crop
9 |
10 |
11 | def adopt_weight(weight, global_step, threshold=0, value=0.0):
12 | if global_step < threshold:
13 | weight = value
14 | return weight
15 |
16 |
17 | def hinge_d_loss(logits_real, logits_fake):
18 | loss_real = torch.mean(F.relu(1.0 - logits_real))
19 | loss_fake = torch.mean(F.relu(1.0 + logits_fake))
20 | d_loss = 0.5 * (loss_real + loss_fake)
21 | return d_loss
22 |
23 |
24 | def vanilla_d_loss(logits_real, logits_fake):
25 | d_loss = 0.5 * (
26 | torch.mean(torch.nn.functional.softplus(-logits_real))
27 | + torch.mean(torch.nn.functional.softplus(logits_fake))
28 | )
29 | return d_loss
30 |
31 |
32 | class VQLPIPSWithDiscriminator(nn.Module):
33 | def __init__(
34 | self,
35 | disc_start,
36 | codebook_weight=1.0,
37 | pixelloss_weight=1.0,
38 | disc_factor=1.0,
39 | disc_weight=1.0,
40 | perceptual_weight=1.0,
41 | ):
42 | super().__init__()
43 | self.codebook_weight = codebook_weight
44 | self.pixel_weight = pixelloss_weight
45 | self.perceptual_loss = LPIPSWithObject().eval()
46 | self.perceptual_weight = perceptual_weight
47 |
48 | self.face_loss = FaceLoss()
49 | #self.object_loss = self.perceptual_loss
50 |
51 | self.discriminator = Discriminator().apply(weights_init)
52 | self.discriminator_iter_start = disc_start
53 | self.disc_factor = disc_factor
54 | self.discriminator_weight = disc_weight
55 |
56 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
57 | nll_grads = torch.autograd.grad(nll_loss, last_layer.weight, retain_graph=True)[
58 | 0
59 | ]
60 | g_grads = torch.autograd.grad(g_loss, last_layer.weight, retain_graph=True)[0]
61 |
62 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
63 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
64 | d_weight = d_weight * self.discriminator_weight
65 | return d_weight
66 |
67 | def forward(
68 | self,
69 | optimizer_idx,
70 | global_step,
71 | images,
72 | reconstructions,
73 | codebook_loss=None,
74 | bbox_obj=None,
75 | bbox_face=None,
76 | last_layer=None,
77 | ):
78 | if optimizer_idx == 0: # vqvae loss
79 | rec_loss = torch.abs(images.contiguous() - reconstructions.contiguous())
80 | p_loss = self.perceptual_loss(
81 | images.contiguous(), reconstructions.contiguous(), bbox_obj
82 | )
83 | rec_loss = rec_loss + self.perceptual_weight * p_loss
84 |
85 | nll_loss = rec_loss
86 | nll_loss = torch.mean(nll_loss)
87 |
88 | face_loss = self.face_loss(images, reconstructions, bbox_face)
89 |
90 | object_loss = images.new_tensor(0)
91 | #for img, rec, bboxes in zip(images, reconstructions, bbox_obj):
92 | # img_object_loss = img.new_tensor(0)
93 | # for bbox in bboxes:
94 | # # xmin, ymin, xmax, ymax
95 | # top = bbox[1]
96 | # left = bbox[0]
97 | # height = bbox[3] - bbox[1]
98 | # width = bbox[2] - bbox[0]
99 | # crop_img = crop(img, top, left, height, width).unsqueeze(
100 | # 0
101 | # ) # bbox needs to be [x, y, height, width]
102 | # crop_rec = crop(rec, top, left, height, width).unsqueeze(0)
103 | # img_object_loss += self.object_loss(
104 | # crop_img.contiguous(), crop_rec.contiguous()
105 | # ).mean() # TODO: check if crops are actually correct
106 | # object_loss += img_object_loss/(len(bboxes)+1)
107 |
108 | logits_fake = self.discriminator(
109 | reconstructions.contiguous()
110 | ) # cont not necessary
111 | g_loss = -torch.mean(logits_fake)
112 |
113 | d_weight = self.calculate_adaptive_weight(
114 | nll_loss, g_loss, last_layer=last_layer
115 | )
116 |
117 | disc_factor = adopt_weight(
118 | self.disc_factor, global_step, threshold=self.discriminator_iter_start
119 | )
120 |
121 | loss = (
122 | nll_loss
123 | + d_weight * disc_factor * g_loss
124 | + self.codebook_weight * codebook_loss.mean()
125 | + face_loss
126 | + object_loss
127 | )
128 | # 0.001 * face_loss
129 | return loss, (nll_loss, object_loss, face_loss)
130 | # return loss, (nll_loss, images.new_tensor(0), images.new_tensor(0))
131 |
132 | if optimizer_idx == 1: # gan loss
133 | disc_factor = adopt_weight(
134 | self.disc_factor, global_step, threshold=self.discriminator_iter_start
135 | )
136 |
137 | logits_real = self.discriminator(images.contiguous().detach())
138 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
139 |
140 | d_loss = disc_factor * hinge_d_loss(logits_real, logits_fake)
141 | return d_loss
142 |
--------------------------------------------------------------------------------
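For reference, the generator-side objective assembled in forward (optimizer_idx == 0) is, in the code's notation:

    \mathcal{L}_G = \operatorname{mean}\big(|x-\hat{x}| + \lambda_{p}\,\mathrm{LPIPS}(x,\hat{x})\big)
                    + d_w \cdot d_f \cdot \big({-}\operatorname{mean} D(\hat{x})\big)
                    + \lambda_{cb}\,\mathcal{L}_{codebook} + \mathcal{L}_{face} + \mathcal{L}_{object}

with the adaptive discriminator weight

    d_w = \texttt{disc\_weight} \cdot \operatorname{clamp}\!\left(\frac{\lVert\nabla_{W}\,\mathcal{L}_{nll}\rVert}{\lVert\nabla_{W}\,\mathcal{L}_{adv}\rVert + 10^{-4}},\;0,\;10^{4}\right)

where W is the weight of the decoder's last layer, and d_f (disc_factor) stays 0 until global_step reaches disc_start (250001 in img_config.yaml). The second branch (optimizer_idx == 1) trains the discriminator with the hinge loss on the same schedule.
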
/losses/loss_seg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class BCELossWithQuant(nn.Module):
7 | def __init__(self, image_channels=159, codebook_weight=1.0):
8 | super().__init__()
9 | self.codebook_weight = codebook_weight
10 | self.register_buffer(
11 | "weight",
12 | torch.ones(image_channels).index_fill(0, torch.arange(153, 158), 20),
13 | )
14 |
15 | def forward(self, qloss, target, prediction):
16 | bce_loss = F.binary_cross_entropy_with_logits(
17 | prediction.permute(0, 2, 3, 1),
18 | target.permute(0, 2, 3, 1),
19 | pos_weight=self.weight,
20 | )
21 | loss = bce_loss + self.codebook_weight * qloss
22 | return loss
23 |
24 |
25 | class VQVAEWithBCELoss(nn.Module):
26 | def __init__(self, image_channels=159, codebook_weight=1.0):
27 | super().__init__()
28 | self.codebook_weight = codebook_weight
29 | self.register_buffer(
30 | "weight",
31 | torch.ones(image_channels).index_fill(0, torch.arange(153, 158), 20),
32 | )
33 |
34 | def forward(self, qloss, target, prediction):
35 | bce_mse_loss = F.mse_loss(prediction.sigmoid(), target) + F.binary_cross_entropy_with_logits(
36 | prediction.permute(0, 2, 3, 1),
37 | target.permute(0, 2, 3, 1),
38 | pos_weight=self.weight,
39 | )
40 | loss = bce_mse_loss + self.codebook_weight * qloss
41 | return loss
42 |
--------------------------------------------------------------------------------
/losses/lpips.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torchvision.models import vgg16
5 | from collections import namedtuple
6 | import requests
7 | from tqdm import tqdm
8 |
9 |
10 | URL_MAP = {
11 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
12 | }
13 |
14 | CKPT_MAP = {
15 | "vgg_lpips": "/home/ubuntu/Make-A-Scene/weights/vgg.pth"
16 | }
17 |
18 |
19 | def download(url, local_path, chunk_size=1024):
20 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
21 | with requests.get(url, stream=True) as r:
22 | total_size = int(r.headers.get("content-length", 0))
23 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
24 | with open(local_path, "wb") as f:
25 | for data in r.iter_content(chunk_size=chunk_size):
26 | if data:
27 | f.write(data)
28 |                     pbar.update(len(data))
29 |
30 |
31 | def get_ckpt_path(name, root):
32 | assert name in URL_MAP
33 | #path = os.path.join(root, CKPT_MAP[name])
34 | path = os.path.join(CKPT_MAP[name])
35 | if not os.path.exists(path):
36 | print(f"Downloading {name} model from {URL_MAP[name]} to {path}")
37 | download(URL_MAP[name], path)
38 | return path
39 |
40 |
41 | class LPIPS(nn.Module):
42 | def __init__(self):
43 | super(LPIPS, self).__init__()
44 | self.scaling_layer = ScalingLayer()
45 | self.channels = [64, 128, 256, 512, 512]
46 | self.vgg = VGG16()
47 | self.lin0 = NetLinLayer(self.channels[0])
48 | self.lin1 = NetLinLayer(self.channels[1])
49 | self.lin2 = NetLinLayer(self.channels[2])
50 | self.lin3 = NetLinLayer(self.channels[3])
51 | self.lin4 = NetLinLayer(self.channels[4])
52 | self.load_from_pretrained()
53 | self.lins = [
54 | self.lin0,
55 | self.lin1,
56 | self.lin2,
57 | self.lin3,
58 | self.lin4
59 | ]
60 |
61 | for param in self.parameters():
62 | param.requires_grad = False
63 |
64 | def load_from_pretrained(self, name="vgg_lpips"):
65 | ckpt = get_ckpt_path(name, "vgg_lpips")
66 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
67 |
68 | def forward(self, real_x, fake_x):
69 | features_real = self.vgg(self.scaling_layer(real_x))
70 | features_fake = self.vgg(self.scaling_layer(fake_x))
71 | diffs = {}
72 |
73 | for i in range(len(self.channels)):
74 | diffs[i] = (norm_tensor(features_real[i]) - norm_tensor(features_fake[i])) ** 2
75 |
76 | return sum([spatial_average(self.lins[i].model(diffs[i])) for i in range(len(self.channels))])
77 |
78 |
79 | class ScalingLayer(nn.Module):
80 | def __init__(self):
81 | super(ScalingLayer, self).__init__()
82 | self.register_buffer("shift", torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
83 | self.register_buffer("scale", torch.Tensor([.458, .448, .450])[None, :, None, None])
84 |
85 | def forward(self, x):
86 | return (x - self.shift) / self.scale
87 |
88 |
89 | class NetLinLayer(nn.Module):
90 | def __init__(self, in_channels, out_channels=1):
91 | super(NetLinLayer, self).__init__()
92 | self.model = nn.Sequential(
93 | nn.Dropout(),
94 | nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)
95 | )
96 |
97 |
98 | class VGG16(nn.Module):
99 | def __init__(self):
100 | super(VGG16, self).__init__()
101 | vgg_pretrained_features = vgg16(pretrained=True).features
102 | slices = [vgg_pretrained_features[i] for i in range(30)]
103 | self.slice1 = nn.Sequential(*slices[0:4])
104 | self.slice2 = nn.Sequential(*slices[4:9])
105 | self.slice3 = nn.Sequential(*slices[9:16])
106 | self.slice4 = nn.Sequential(*slices[16:23])
107 | self.slice5 = nn.Sequential(*slices[23:30])
108 |
109 | for param in self.parameters():
110 | param.requires_grad = False
111 |
112 | def forward(self, x):
113 | h = self.slice1(x)
114 | h_relu1 = h
115 | h = self.slice2(h)
116 | h_relu2 = h
117 | h = self.slice3(h)
118 | h_relu3 = h
119 | h = self.slice4(h)
120 | h_relu4 = h
121 | h = self.slice5(h)
122 | h_relu5 = h
123 | vgg_outputs = namedtuple("VGGOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
124 | return vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
125 |
126 |
127 | def norm_tensor(x):
128 | """
129 |     Normalize feature maps along the channel dimension to unit length.
130 |     :param x: batch of feature maps (B x C x H x W)
131 |     :return: channel-normalized feature maps
132 | """
133 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
134 | return x / (norm_factor + 1e-10)
135 |
136 |
137 | def spatial_average(x):
138 | """
139 |     Average a B x C x H x W tensor over its spatial dimensions (H, W), keeping those dims.
140 |     :param x: batch of feature maps
141 |     :return: tensor of shape B x C x 1 x 1
142 | """
143 | return x.mean([2, 3], keepdim=True)
144 |
145 |
--------------------------------------------------------------------------------
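The forward pass implements the usual LPIPS distance: inputs are shifted and scaled by ScalingLayer, VGG16 features are taken at five depths and channel-normalized to unit length, squared differences are reduced over channels by the learned 1x1 convolutions (NetLinLayer), averaged spatially, and summed over levels:

    \mathrm{LPIPS}(x,y) = \sum_{l=1}^{5} \operatorname{mean}_{h,w}\Big( w_l^{\top} \big(\hat{f}_l(x) - \hat{f}_l(y)\big)^{2} \Big)

where \hat{f}_l denotes the unit-normalized activations of relu{1_2, 2_2, 3_3, 4_3, 5_3} and the square is taken elementwise.
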
/losses/lpips_with_object.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Function
5 | from torchvision.models import vgg16
6 | from collections import namedtuple
7 | import requests
8 | from tqdm import tqdm
9 | from .lpips import LPIPS
10 |
11 |
12 | class WeightThingGrad(Function):
13 | def forward(ctx, input, bboxes):
14 | weights = torch.ones_like(input)
15 | for weight, img_bboxes in zip(weights, bboxes):
16 | for x_min, y_min, x_max, y_max in img_bboxes:
17 |                 weight[:, y_min:y_max, x_min:x_max] *= 2  # up-weight object regions; the original line was a no-op and the factor here is an assumption
18 |
19 | ctx.save_for_backward(weights)
20 | return input
21 |
22 | def backward(ctx, grad_output):
23 | weights = ctx.saved_tensors[0]
24 | return grad_output*weights, None
25 |
26 | weight_thing_grad = WeightThingGrad.apply
27 |
28 | class LPIPSWithObject(LPIPS):
29 | def forward(self, real_x, fake_x, object_boxes):
30 | fake_x = weight_thing_grad(fake_x, object_boxes)
31 | return super().forward(real_x, fake_x)
32 |
33 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vqvae import VQBASE
2 |
--------------------------------------------------------------------------------
/models/modules.py:
--------------------------------------------------------------------------------
1 | # Taken from https://github.com/CompVis/taming-transformers
2 | # pytorch_diffusion + derived encoder decoder
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | from fast_pytorch_kmeans import KMeans
9 | from torch import einsum
10 | import torch.distributed as dist
11 | from einops import rearrange
12 |
13 |
14 | def get_timestep_embedding(timesteps, embedding_dim):
15 | """
16 | This matches the implementation in Denoising Diffusion Probabilistic Models:
17 | From Fairseq.
18 | Build sinusoidal embeddings.
19 | This matches the implementation in tensor2tensor, but differs slightly
20 | from the description in Section 3.5 of "Attention Is All You Need".
21 | """
22 | assert len(timesteps.shape) == 1
23 |
24 | half_dim = embedding_dim // 2
25 | emb = math.log(10000) / (half_dim - 1)
26 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
27 | emb = emb.to(device=timesteps.device)
28 | emb = timesteps.float()[:, None] * emb[None, :]
29 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
30 | if embedding_dim % 2 == 1: # zero pad
31 | emb = torch.nn.functional.pad(emb, (0,1,0,0))
32 | return emb
33 |
34 |
35 | def nonlinearity(x):
36 | # swish
37 | return x*torch.sigmoid(x)
38 |
39 |
40 | def Normalize(in_channels):
41 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
42 |
43 |
44 | class Upsample(nn.Module):
45 | def __init__(self, in_channels, with_conv):
46 | super().__init__()
47 | self.with_conv = with_conv
48 | if self.with_conv:
49 | self.conv = torch.nn.Conv2d(in_channels,
50 | in_channels,
51 | kernel_size=3,
52 | stride=1,
53 | padding=1)
54 |
55 | def forward(self, x):
56 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
57 | if self.with_conv:
58 | x = self.conv(x)
59 | return x
60 |
61 |
62 | class Downsample(nn.Module):
63 | def __init__(self, in_channels, with_conv):
64 | super().__init__()
65 | self.with_conv = with_conv
66 | if self.with_conv:
67 | # no asymmetric padding in torch conv, must do it ourselves
68 | self.conv = torch.nn.Conv2d(in_channels,
69 | in_channels,
70 | kernel_size=3,
71 | stride=2,
72 | padding=0)
73 |
74 | def forward(self, x):
75 | if self.with_conv:
76 | pad = (0,1,0,1)
77 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
78 | x = self.conv(x)
79 | else:
80 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
81 | return x
82 |
83 |
84 | class ResnetBlock(nn.Module):
85 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
86 | super().__init__()
87 | self.in_channels = in_channels
88 | out_channels = in_channels if out_channels is None else out_channels
89 | self.out_channels = out_channels
90 | self.use_conv_shortcut = conv_shortcut
91 |
92 | self.norm1 = Normalize(in_channels)
93 | self.conv1 = torch.nn.Conv2d(in_channels,
94 | out_channels,
95 | kernel_size=3,
96 | stride=1,
97 | padding=1)
98 | self.norm2 = Normalize(out_channels)
99 | self.dropout = torch.nn.Dropout(dropout)
100 | self.conv2 = torch.nn.Conv2d(out_channels,
101 | out_channels,
102 | kernel_size=3,
103 | stride=1,
104 | padding=1)
105 | if self.in_channels != self.out_channels:
106 | if self.use_conv_shortcut:
107 | self.conv_shortcut = torch.nn.Conv2d(in_channels,
108 | out_channels,
109 | kernel_size=3,
110 | stride=1,
111 | padding=1)
112 | else:
113 | self.nin_shortcut = torch.nn.Conv2d(in_channels,
114 | out_channels,
115 | kernel_size=1,
116 | stride=1,
117 | padding=0)
118 |
119 | def forward(self, x):
120 | h = x
121 | h = self.norm1(h)
122 | h = nonlinearity(h)
123 | h = self.conv1(h)
124 |
125 | h = self.norm2(h)
126 | h = nonlinearity(h)
127 | h = self.dropout(h)
128 | h = self.conv2(h)
129 |
130 | if self.in_channels != self.out_channels:
131 | if self.use_conv_shortcut:
132 | x = self.conv_shortcut(x)
133 | else:
134 | x = self.nin_shortcut(x)
135 |
136 | return x+h
137 |
138 |
139 | class AttnBlock(nn.Module):
140 | def __init__(self, in_channels):
141 | super().__init__()
142 | self.in_channels = in_channels
143 |
144 | self.norm = Normalize(in_channels)
145 | self.q = torch.nn.Conv2d(in_channels,
146 | in_channels,
147 | kernel_size=1,
148 | stride=1,
149 | padding=0)
150 | self.k = torch.nn.Conv2d(in_channels,
151 | in_channels,
152 | kernel_size=1,
153 | stride=1,
154 | padding=0)
155 | self.v = torch.nn.Conv2d(in_channels,
156 | in_channels,
157 | kernel_size=1,
158 | stride=1,
159 | padding=0)
160 | self.proj_out = torch.nn.Conv2d(in_channels,
161 | in_channels,
162 | kernel_size=1,
163 | stride=1,
164 | padding=0)
165 |
166 |
167 | def forward(self, x):
168 | h_ = x
169 | h_ = self.norm(h_)
170 | q = self.q(h_)
171 | k = self.k(h_)
172 | v = self.v(h_)
173 |
174 | # compute attention
175 | b,c,h,w = q.shape
176 | q = q.reshape(b,c,h*w)
177 | q = q.permute(0,2,1) # b,hw,c
178 | k = k.reshape(b,c,h*w) # b,c,hw
179 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
180 | w_ = w_ * (int(c)**(-0.5))
181 | w_ = torch.nn.functional.softmax(w_, dim=2)
182 |
183 | # attend to values
184 | v = v.reshape(b,c,h*w)
185 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
186 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
187 | h_ = h_.reshape(b,c,h,w)
188 |
189 | h_ = self.proj_out(h_)
190 |
191 | return x+h_
192 |
193 |
194 | class Swish(nn.Module):
195 | def forward(self, x):
196 | return x * torch.sigmoid(x)
197 |
198 |
199 | class Encoder(nn.Module):
200 | """
201 | Encoder of VQ-GAN to map input batch of images to latent space.
202 | Dimension Transformations:
203 | 3x256x256 --Conv2d--> 32x256x256
204 | for loop:
205 | --ResBlock--> 64x256x256 --DownBlock--> 64x128x128
206 | --ResBlock--> 128x128x128 --DownBlock--> 128x64x64
207 | --ResBlock--> 256x64x64 --DownBlock--> 256x32x32
208 | --ResBlock--> 512x32x32
209 | --ResBlock--> 512x32x32
210 | --NonLocalBlock--> 512x32x32
211 | --ResBlock--> 512x32x32
212 | --GroupNorm-->
213 | --Swish-->
214 | --Conv2d-> 256x32x32
215 | """
216 |
217 | def __init__(self, in_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32], resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs):
218 | super(Encoder, self).__init__()
219 | layers = [nn.Conv2d(in_channels, channels[0], 3, 1, 1)]
220 | for i in range(len(channels) - 1):
221 | in_channels = channels[i]
222 | out_channels = channels[i + 1]
223 | for j in range(num_res_blocks):
224 |                 layers.append(ResnetBlock(in_channels=in_channels, out_channels=out_channels, dropout=dropout))
225 | in_channels = out_channels
226 | if resolution in attn_resolutions:
227 | layers.append(AttnBlock(in_channels))
228 | if i < len(channels) - 2:
229 | layers.append(Downsample(channels[i + 1], with_conv=True))
230 | resolution //= 2
231 |         layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=dropout))
232 |         layers.append(AttnBlock(channels[-1]))
233 |         layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=dropout))
234 | layers.append(Normalize(channels[-1]))
235 | layers.append(Swish())
236 | layers.append(nn.Conv2d(channels[-1], z_channels, 3, 1, 1))
237 | self.model = nn.Sequential(*layers)
238 |
239 | def forward(self, x):
240 | return self.model(x)
241 |
242 |
243 | # class Encoder(nn.Module):
244 | # def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
245 | # attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
246 | # resolution, z_channels, double_z=True, **ignore_kwargs):
247 | # super().__init__()
248 | # self.ch = ch
249 | # self.temb_ch = 0
250 | # self.num_resolutions = len(ch_mult)
251 | # self.num_res_blocks = num_res_blocks
252 | # self.resolution = resolution
253 | # self.in_channels = in_channels
254 | #
255 | # # downsampling
256 | # self.conv_in = torch.nn.Conv2d(in_channels,
257 | # self.ch,
258 | # kernel_size=3,
259 | # stride=1,
260 | # padding=1)
261 | #
262 | # curr_res = resolution
263 | # in_ch_mult = (1,)+tuple(ch_mult)
264 | # self.down = nn.ModuleList()
265 | # for i_level in range(self.num_resolutions):
266 | # block = nn.ModuleList()
267 | # attn = nn.ModuleList()
268 | # block_in = ch*in_ch_mult[i_level]
269 | # block_out = ch*ch_mult[i_level]
270 | # for i_block in range(self.num_res_blocks):
271 | # block.append(ResnetBlock(in_channels=block_in,
272 | # out_channels=block_out,
273 | # temb_channels=self.temb_ch,
274 | # dropout=dropout))
275 | # block_in = block_out
276 | # if curr_res in attn_resolutions:
277 | # attn.append(AttnBlock(block_in))
278 | # down = nn.Module()
279 | # down.block = block
280 | # down.attn = attn
281 | # if i_level != self.num_resolutions-1:
282 | # down.downsample = Downsample(block_in, resamp_with_conv)
283 | # curr_res = curr_res // 2
284 | # self.down.append(down)
285 | #
286 | # # middle
287 | # self.mid = nn.Module()
288 | # self.mid.block_1 = ResnetBlock(in_channels=block_in,
289 | # out_channels=block_in,
290 | # temb_channels=self.temb_ch,
291 | # dropout=dropout)
292 | # self.mid.attn_1 = AttnBlock(block_in)
293 | # self.mid.block_2 = ResnetBlock(in_channels=block_in,
294 | # out_channels=block_in,
295 | # temb_channels=self.temb_ch,
296 | # dropout=dropout)
297 | #
298 | # # end
299 | # self.norm_out = Normalize(block_in)
300 | # self.conv_out = torch.nn.Conv2d(block_in,
301 | # 2*z_channels if double_z else z_channels,
302 | # kernel_size=3,
303 | # stride=1,
304 | # padding=1)
305 | #
306 | #
307 | # def forward(self, x):
308 | # #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
309 | #
310 | # # timestep embedding
311 | # temb = None
312 | #
313 | # # downsampling
314 | # hs = [self.conv_in(x)]
315 | # for i_level in range(self.num_resolutions):
316 | # for i_block in range(self.num_res_blocks):
317 | # h = self.down[i_level].block[i_block](hs[-1], temb)
318 | # if len(self.down[i_level].attn) > 0:
319 | # h = self.down[i_level].attn[i_block](h)
320 | # hs.append(h)
321 | # if i_level != self.num_resolutions-1:
322 | # hs.append(self.down[i_level].downsample(hs[-1]))
323 | #
324 | # # middle
325 | # h = hs[-1]
326 | # h = self.mid.block_1(h, temb)
327 | # h = self.mid.attn_1(h)
328 | # h = self.mid.block_2(h, temb)
329 | #
330 | # # end
331 | # h = self.norm_out(h)
332 | # h = nonlinearity(h)
333 | # h = self.conv_out(h)
334 | # return h
335 |
336 |
337 | class Decoder(nn.Module):
338 | def __init__(self, out_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32], resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs):
339 | super(Decoder, self).__init__()
340 | ch_mult = channels[1:]
341 | num_resolutions = len(ch_mult)
342 | block_in = ch_mult[num_resolutions - 1]
343 | curr_res = resolution// 2 ** (num_resolutions - 1)
344 |
345 | layers = [nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1),
346 |                   ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout),
347 |                   AttnBlock(block_in),
348 |                   ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
349 | ]
350 |
351 | for i in reversed(range(num_resolutions)):
352 | block_out = ch_mult[i]
353 | for i_block in range(num_res_blocks+1):
354 |                 layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
355 | block_in = block_out
356 | if curr_res in attn_resolutions:
357 | layers.append(AttnBlock(block_in))
358 | if i > 0:
359 | layers.append(Upsample(block_in, with_conv=True))
360 | curr_res = curr_res * 2
361 |
362 | layers.append(Normalize(block_in))
363 | layers.append(Swish())
364 | layers.append(nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1))
365 |
366 | self.model = nn.Sequential(*layers)
367 |
368 | def forward(self, x):
369 | return self.model(x)
370 |
371 |
372 | # class Decoder(nn.Module):
373 | # def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
374 | # attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
375 | # resolution, z_channels, **ignorekwargs):
376 | # super().__init__()
377 | # self.temb_ch = 0
378 | # self.num_resolutions = len(ch_mult)
379 | # self.num_res_blocks = num_res_blocks
380 | # self.resolution = resolution
381 | # self.in_channels = in_channels
382 | #
383 | # block_in = ch*ch_mult[self.num_resolutions-1]
384 | # curr_res = resolution // 2**(self.num_resolutions-1)
385 | # self.z_shape = (1,z_channels,curr_res,curr_res)
386 | #
387 | # # z to block_in
388 | # self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
389 | #
390 | # # middle
391 | # self.mid = nn.Module()
392 | # self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
393 | # self.mid.attn_1 = AttnBlock(block_in)
394 | # self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout)
395 | #
396 | # # upsampling
397 | # self.up = nn.ModuleList()
398 | # for i_level in reversed(range(self.num_resolutions)):
399 | # block = nn.ModuleList()
400 | # attn = nn.ModuleList()
401 | # block_out = ch*ch_mult[i_level]
402 | # for i_block in range(self.num_res_blocks+1):
403 | # block.append(ResnetBlock(in_channels=block_in,
404 | # out_channels=block_out,
405 | # temb_channels=self.temb_ch,
406 | # dropout=dropout))
407 | # block_in = block_out
408 | # if curr_res in attn_resolutions:
409 | # attn.append(AttnBlock(block_in))
410 | # up = nn.Module()
411 | # up.block = block
412 | # up.attn = attn
413 | # if i_level != 0:
414 | # up.upsample = Upsample(block_in, resamp_with_conv)
415 | # curr_res = curr_res * 2
416 | # self.up.insert(0, up) # prepend to get consistent order
417 | #
418 | # # end
419 | # self.norm_out = Normalize(block_in)
420 | # self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
421 | #
422 | # def forward(self, z):
423 | # self.last_z_shape = z.shape
424 | #
425 | # # timestep embedding
426 | # temb = None
427 | #
428 | # # z to block_in
429 | # h = self.conv_in(z)
430 | #
431 | # # middle
432 | # h = self.mid.block_1(h, temb)
433 | # h = self.mid.attn_1(h)
434 | # h = self.mid.block_2(h, temb)
435 | #
436 | # # upsampling
437 | # for i_level in reversed(range(self.num_resolutions)):
438 | # for i_block in range(self.num_res_blocks+1):
439 | # h = self.up[i_level].block[i_block](h, temb)
440 | # if len(self.up[i_level].attn) > 0:
441 | # h = self.up[i_level].attn[i_block](h)
442 | # if i_level != 0:
443 | # h = self.up[i_level].upsample(h)
444 | #
445 | # h = self.norm_out(h)
446 | # h = nonlinearity(h)
447 | # h = self.conv_out(h)
448 | # return h
449 |
450 |
451 | class Codebook(nn.Module):
452 | """
453 |     Vector-quantization codebook with data-driven initialization: encoder outputs are collected
454 |     in a reservoir and the centroids are periodically re-fitted with k-means early in training.
455 | """
456 | def __init__(self, codebook_size, codebook_dim, beta, init_steps=2000, reservoir_size=2e5):
457 | super().__init__()
458 | self.codebook_size = codebook_size
459 | self.codebook_dim = codebook_dim
460 | self.beta = beta
461 |
462 | self.embedding = nn.Embedding(self.codebook_size, self.codebook_dim)
463 | self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
464 |
465 | self.q_start_collect, self.q_init, self.q_re_end, self.q_re_step = init_steps, init_steps * 3, init_steps * 30, init_steps // 2
466 | self.q_counter = 0
467 | self.reservoir_size = int(reservoir_size)
468 | self.reservoir = None
469 |
470 | def forward(self, z):
471 | z = rearrange(z, 'b c h w -> b h w c').contiguous()
472 | batch_size = z.size(0)
473 | z_flattened = z.view(-1, self.codebook_dim)
474 | if self.training:
475 | self.q_counter += 1
476 | # x_flat = x.permute(0, 2, 3, 1).reshape(-1, z.shape(1))
477 | if self.q_counter > self.q_start_collect:
478 | z_new = z_flattened.clone().detach().view(batch_size, -1, self.codebook_dim)
479 | z_new = z_new[:, torch.randperm(z_new.size(1))][:, :10].reshape(-1, self.codebook_dim)
480 | self.reservoir = z_new if self.reservoir is None else torch.cat([self.reservoir, z_new], dim=0)
481 | self.reservoir = self.reservoir[torch.randperm(self.reservoir.size(0))[:self.reservoir_size]].detach()
482 | if self.q_counter < self.q_init:
483 | z_q = rearrange(z, 'b h w c -> b c h w').contiguous()
484 | return z_q, z_q.new_tensor(0), None # z_q, loss, min_encoding_indices
485 | else:
486 | # if self.q_counter < self.q_init + self.q_re_end:
487 | if self.q_init <= self.q_counter < self.q_re_end:
488 | if (self.q_counter - self.q_init) % self.q_re_step == 0 or self.q_counter == self.q_init + self.q_re_end - 1:
489 | kmeans = KMeans(n_clusters=self.codebook_size)
490 | world_size = dist.get_world_size()
491 | print("Updating codebook from reservoir.")
492 | if world_size > 1:
493 | global_reservoir = [torch.zeros_like(self.reservoir) for _ in range(world_size)]
494 | dist.all_gather(global_reservoir, self.reservoir.clone())
495 | global_reservoir = torch.cat(global_reservoir, dim=0)
496 | else:
497 | global_reservoir = self.reservoir
498 |                         kmeans.fit_predict(global_reservoir)  # fit k-means on the gathered reservoir of encoded latents
499 | self.embedding.weight.data = kmeans.centroids.detach()
500 |
501 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
502 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
503 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
504 |
505 | min_encoding_indices = torch.argmin(d, dim=1)
506 | z_q = self.embedding(min_encoding_indices).view(z.shape)
507 |
508 | # compute loss for embedding
509 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
510 |
511 | # preserve gradients
512 | z_q = z + (z_q - z).detach()
513 |
514 | # reshape back to match original input shape
515 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
516 |
517 | return z_q, loss, min_encoding_indices
518 |
519 | def get_codebook_entry(self, indices, shape):
520 | # get quantized latent vectors
521 | z_q = self.embedding(indices)
522 |
523 | if shape is not None:
524 | z_q = z_q.view(shape)
525 | # reshape back to match original input shape
526 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
527 |
528 | return z_q
529 |
530 |
531 | if __name__ == '__main__':
532 | enc = Encoder()
533 | dec = Decoder()
534 | print(sum([p.numel() for p in enc.parameters()]))
535 | print(sum([p.numel() for p in dec.parameters()]))
536 | x = torch.randn(1, 3, 512, 512)
537 | res = enc(x)
538 | print(res.shape)
539 | res = dec(res)
540 | print(res.shape)
541 |
--------------------------------------------------------------------------------
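The Codebook schedule is derived from a single init_steps value. A worked example with init_steps = 3000 and n_embed = 8192, as in img_config.yaml (steps counted per process):

    q_start_collect = init_steps      = 3000   -> begin filling the reservoir with encoder outputs
    q_init          = 3 * init_steps  = 9000   -> until this step the quantizer is bypassed (z is returned unquantized)
    q_re_step       = init_steps // 2 = 1500   -> the 8192 centroids are re-fitted with k-means every 1500 steps
    q_re_end        = 30 * init_steps = 90000  -> after this step the codebook is no longer re-initialized

The reservoir holds at most reservoir_size vectors per process and is gathered across ranks before each k-means fit.
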
/models/transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | taken from: https://github.com/ai-forever/ru-dalle/blob/master/rudalle/dalle/transformer.py slightly modified
3 | """
4 |
5 | import math
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import functional as F
9 |
10 |
11 | @torch.jit.script
12 | def gelu(x):
13 | """OpenAI's gelu implementation."""
14 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
15 |
16 |
17 | class SelfAttention(nn.Module):
18 | def __init__(self,
19 | hidden_dim,
20 | num_attn_heads,
21 | attn_dropout_prob,
22 | out_dropout_prob,
23 | cogview_pb_relax=True,
24 | rudalle_relax=False
25 | ):
26 | super(SelfAttention, self).__init__()
27 |
28 | self.hidden_dim = hidden_dim
29 | self.num_attn_heads = num_attn_heads
30 | self.d = math.sqrt(self.hidden_dim // self.num_attn_heads)
31 | self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim)
32 | self.attn_drop = nn.Dropout(attn_dropout_prob)
33 |
34 | self.out_proj = nn.Linear(hidden_dim, hidden_dim)
35 | self.out_drop = nn.Dropout(out_dropout_prob)
36 |
37 | self.cogview_pb_relax = cogview_pb_relax
38 | self.rudalle_relax = rudalle_relax
39 |
40 | def split_heads(self, x):
41 | new_shape = x.size()[:-1] + (self.num_attn_heads, self.hidden_dim // self.num_attn_heads)
42 | return x.view(*new_shape).permute(0, 2, 1, 3)
43 |
44 | def calculate_attention(self, q, k, mask):
45 | k_t = k.transpose(-1, -2)
46 | mask_value = 10000.
47 | if self.cogview_pb_relax:
48 | if self.rudalle_relax:
49 | sigma = k_t.std()
50 | attn_scores = torch.matmul(q / self.d, k_t / sigma)
51 | attn_scores_max = attn_scores.detach().max(dim=-1)[0]
52 | attn_scores_min = (attn_scores.detach() + 65504).min(dim=-1)[0]
53 | shift = torch.min(attn_scores_min, attn_scores_max).unsqueeze(-1).expand_as(attn_scores) / 2
54 | attn_scores = (attn_scores - shift) / sigma
55 | mask_value = 65504
56 | else:
57 | attn_scores = torch.matmul(q / self.d, k_t)
58 | else:
59 | attn_scores = torch.matmul(q, k_t) / self.d
60 |
61 | mask = mask[:, :, -attn_scores.shape[-2]:]
62 | attn_scores = mask * attn_scores - (1. - mask) * mask_value
63 | if self.cogview_pb_relax and not self.rudalle_relax:
64 | alpha = 32
65 | attn_scores_scaled = attn_scores / alpha
66 | attn_scores_scaled_max, _ = attn_scores_scaled.detach().view(
67 | [attn_scores.shape[0], attn_scores.shape[1], -1]).max(dim=-1)
68 | attn_scores_scaled_max = attn_scores_scaled_max[..., None, None].expand(
69 | [-1, -1, attn_scores.size(2), attn_scores.size(3)])
70 | attn_scores = (attn_scores_scaled - attn_scores_scaled_max) * alpha
71 | return attn_scores
72 |
73 | def forward(self, x, mask, use_cache=False, cache=None):
74 | if use_cache and cache is not None:
75 | qkv = self.qkv(x[:, cache[0].shape[-2]:, :])
76 | else:
77 | qkv = self.qkv(x)
78 |
79 | q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1) # probably use different dim
80 | q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v)
81 |
82 | if use_cache and cache is not None:
83 | past_k, past_v, past_output = cache
84 | k = torch.cat([past_k, k], dim=-2)
85 | v = torch.cat([past_v, v], dim=-2)
86 | attn_scores = self.calculate_attention(q, k, mask)
87 | else:
88 | attn_scores = self.calculate_attention(q, k, mask)
89 |
90 | attn_probs = nn.Softmax(dim=-1)(attn_scores) # [b, np, s, s]
91 | attn_probs = self.attn_drop(attn_probs)
92 |
93 | if self.rudalle_relax:
94 | scale = v.detach().max().item()
95 | context = torch.matmul(attn_probs, v / scale)
96 | else:
97 | context = torch.matmul(attn_probs, v)
98 |
99 | context = context.permute(0, 2, 1, 3).contiguous()
100 | context = context.view(context.shape[0], context.shape[1],
101 | context.shape[2] * context.shape[3])
102 |
103 | if self.rudalle_relax:
104 | scale = context.detach().max().item()
105 | context /= scale
106 |
107 | out = self.out_proj(context)
108 | if use_cache and cache is not None:
109 | out = torch.concat([past_output, out], dim=-2)
110 |
111 | if use_cache:
112 | cache = k, v, out
113 |
114 | out = self.out_drop(out)
115 | return out, cache
116 |
117 |
118 | class MLP(nn.Module):
119 | def __init__(self,
120 | hidden_dim,
121 | dropout_prob,
122 | rudalle_relax=False
123 | ):
124 | super(MLP, self).__init__()
125 | self.lin1 = nn.Linear(hidden_dim, 4 * hidden_dim)
126 | self.lin2 = nn.Linear(4 * hidden_dim, hidden_dim)
127 | self.dropout = nn.Dropout(dropout_prob)
128 | self.rudalle_relax = rudalle_relax
129 |
130 | def forward(self, x):
131 | x = self.lin1(x)
132 | x = gelu(x)
133 | if self.rudalle_relax:
134 | scale = x.detach().max().item() / 4
135 | x = self.lin2(x / scale)
136 | x = (x / x.detach().max(dim=-1)[0].unsqueeze(-1)) * scale
137 | else:
138 | x = self.lin2(x)
139 | return self.dropout(x)
140 |
141 |
142 | class TransformerLayer(nn.Module):
143 | def __init__(self,
144 | hidden_dim,
145 | num_attn_heads,
146 | attn_dropout_prop,
147 | out_dropout_prob,
148 | cogview_pb_relax=True,
149 | cogview_sandwich_layernorm=True,
150 | cogview_layernorm_prescale=False,
151 | rudalle_relax=False
152 | ):
153 | super().__init__()
154 | self.cogview_pb_relax = cogview_pb_relax
155 | self.cogview_sandwich_layernorm = cogview_sandwich_layernorm
156 | self.cogview_layernorm_prescale = cogview_layernorm_prescale
157 | self.rudalle_relax = rudalle_relax
158 |
159 | self.ln_in = nn.LayerNorm(hidden_dim, eps=1e-5)
160 | self.ln_out = nn.LayerNorm(hidden_dim, eps=1e-5)
161 | if cogview_sandwich_layernorm:
162 | self.first_ln_sandwich = nn.LayerNorm(hidden_dim, eps=1e-5)
163 | self.second_ln_sandwich = nn.LayerNorm(hidden_dim, eps=1e-5)
164 |
165 | self.attn = SelfAttention(hidden_dim=hidden_dim,
166 | num_attn_heads=num_attn_heads,
167 | attn_dropout_prob=attn_dropout_prop,
168 | out_dropout_prob=out_dropout_prob,
169 | cogview_pb_relax=cogview_pb_relax,
170 | rudalle_relax=rudalle_relax)
171 |
172 | self.mlp = MLP(hidden_dim=hidden_dim,
173 | dropout_prob=out_dropout_prob,
174 | rudalle_relax=rudalle_relax)
175 |
176 | def forward(self, x, mask, cache=None, use_cache=False, mlp_cache=False):
177 | if self.cogview_layernorm_prescale:
178 | ln_in = self.ln_in(x / x.detach().max(dim=-1)[0].unsqueeze(-1))
179 | else:
180 | ln_in = self.ln_in(x)
181 |         attn_out, new_cache = self.attn(ln_in, mask, use_cache=use_cache, cache=cache)
182 |
183 | if self.cogview_sandwich_layernorm:
184 | if self.cogview_layernorm_prescale:
185 | attn_out = self.first_ln_sandwich(attn_out / attn_out.detach().max(dim=-1)[0].unsqueeze(-1))
186 | else:
187 | attn_out = self.first_ln_sandwich(attn_out)
188 |
189 | x = x + attn_out
190 | cached = 0 if cache is None else cache[0].shape[2]
191 |
192 | if self.cogview_layernorm_prescale:
193 | ln_out = self.ln_out(x / x.detach().max(dim=-1)[0].unsqueeze(-1))
194 | else:
195 | ln_out = self.ln_out(x)
196 |
197 | if use_cache and cached:
198 | mlp_out = torch.cat(
199 |                 (cache[-1] if mlp_cache else ln_out[..., :cached, :], self.mlp(ln_out[..., cached:, :])), dim=-2)
200 | if mlp_cache:
201 | new_cache = new_cache + (mlp_out,)
202 | else:
203 | mlp_out = self.mlp(ln_out)
204 |
205 | if self.cogview_sandwich_layernorm:
206 | mlp_out = self.second_ln_sandwich(mlp_out)
207 |
208 | x = x + mlp_out
209 |
210 | return x, new_cache
211 |
212 |
213 | class Transformer(nn.Module):
214 | def __init__(self,
215 | num_layers,
216 | hidden_dim,
217 | num_attn_heads,
218 | image_tokens_per_dim,
219 | seg_tokens_per_dim,
220 | text_length,
221 | attn_dropout_prop=0,
222 | out_dropout_prob=0,
223 | cogview_pb_relax=True,
224 | cogview_sandwich_layernorm=True,
225 | cogview_layernorm_prescale=False,
226 | rudalle_relax=False
227 | ):
228 | super(Transformer, self).__init__()
229 | self.num_layers = num_layers
230 | self.cogview_pb_relax = cogview_pb_relax
231 | self.rudalle_relax = rudalle_relax
232 |
233 | self.layers = nn.ModuleList([
234 | TransformerLayer(
235 | hidden_dim,
236 | num_attn_heads,
237 | attn_dropout_prop,
238 | out_dropout_prob,
239 | cogview_pb_relax,
240 | cogview_sandwich_layernorm,
241 | cogview_layernorm_prescale,
242 | rudalle_relax
243 | ) for _ in range(num_layers)
244 | ])
245 |
246 | self.register_buffer("mask", self._create_mask(text_length, seg_tokens_per_dim, image_tokens_per_dim))
247 | self.final_ln = nn.LayerNorm(hidden_dim, eps=1e-5)
248 |
249 | def _create_mask(self, text_length, seg_tokens_per_dim, image_tokens_per_dim):
250 | size = text_length + seg_tokens_per_dim ** 2 + image_tokens_per_dim ** 2
251 | return torch.tril(torch.ones(size, size, dtype=torch.float32))
252 |
253 | def get_block_size(self):
254 | return self.block_size
255 |
256 | def forward(self, x, attn_mask, cache=None, use_cache=None):
257 | if cache is None:
258 | cache = {}
259 |
260 | for i, layer in enumerate(self.layers):
261 | mask = attn_mask
262 | layer_mask = self.mask[:mask.size(2), :mask.size(3)]
263 | mask = torch.mul(attn_mask, layer_mask)
264 | x, layer_cache = layer(x, mask, cache.get(i), mlp_cache=i == len(self.layers) - 1, use_cache=use_cache)
265 | cache[i] = layer_cache
266 |
267 | if self.rudalle_relax:
268 | ln_out = self.final_ln(x / x.detach().max(dim=-1)[0].unsqueeze(-1))
269 | else:
270 | ln_out = self.final_ln(x)
271 |
272 | return ln_out, cache
273 |
274 |
275 | class MakeAScene(nn.Module):
276 | def __init__(self,
277 | num_layers,
278 | hidden_dim,
279 | num_attn_heads,
280 | image_vocab_size,
281 | seg_vocab_size,
282 | text_vocab_size,
283 | image_tokens_per_dim,
284 | seg_tokens_per_dim,
285 | text_length
286 | ):
287 | super(MakeAScene, self).__init__()
288 | self.image_tokens_per_dim = image_tokens_per_dim
289 | self.seg_tokens_per_dim = seg_tokens_per_dim
290 | self.image_length = image_tokens_per_dim ** 2
291 | self.seg_length = seg_tokens_per_dim ** 2
292 | self.text_length = text_length
293 | self.total_length = self.text_length + self.seg_length + self.image_length
294 | self.text_vocab_size = text_vocab_size
295 |
296 | self.transformer = Transformer(num_layers, hidden_dim, num_attn_heads,
297 | image_tokens_per_dim, seg_tokens_per_dim,
298 | text_length)
299 |
300 | self.image_token_embedding = nn.Embedding(image_vocab_size, hidden_dim)
301 | self.seg_token_embedding = nn.Embedding(seg_vocab_size, hidden_dim)
302 | self.text_token_embedding = nn.Embedding(text_vocab_size, hidden_dim)
303 |
304 | self.text_pos_embeddings = torch.nn.Embedding(text_length, hidden_dim)
305 | self.seg_row_embeddings = torch.nn.Embedding(seg_tokens_per_dim, hidden_dim)
306 | self.seg_col_embeddings = torch.nn.Embedding(seg_tokens_per_dim, hidden_dim)
307 | self.image_row_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_dim)
308 | self.image_col_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_dim)
309 | self._init_weights(self.text_pos_embeddings)
310 | self._init_weights(self.seg_row_embeddings)
311 | self._init_weights(self.seg_col_embeddings)
312 | self._init_weights(self.image_row_embeddings)
313 | self._init_weights(self.image_col_embeddings)
314 |
315 | self.to_logits = torch.nn.Sequential(
316 | torch.nn.LayerNorm(hidden_dim), # TODO: check if this is redundant
317 | torch.nn.Linear(hidden_dim, image_vocab_size),
318 | )
319 |
320 | def _init_weights(self, module):
321 | if isinstance(module, (nn.Linear, nn.Embedding)):
322 | module.weight.data.normal_(mean=0.0, std=0.02)
323 | if isinstance(module, nn.Linear) and module.bias is not None:
324 | module.bias.data.zero_()
325 | elif isinstance(module, nn.LayerNorm):
326 | module.bias.data.zero_()
327 | module.weight.data.fill_(1.0)
328 |
329 | def get_seg_pos_embeddings(self, seg_input_ids):
330 | input_shape = seg_input_ids.size()
331 | row_ids = torch.arange(input_shape[-1],
332 | dtype=torch.long, device=self.device) // self.seg_tokens_per_dim
333 | row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1])
334 | col_ids = torch.arange(input_shape[-1],
335 | dtype=torch.long, device=self.device) % self.seg_tokens_per_dim
336 | col_ids = col_ids.unsqueeze(0).view(-1, input_shape[-1])
337 | return self.seg_row_embeddings(row_ids) + self.seg_col_embeddings(col_ids)
338 |
339 | def get_image_pos_embeddings(self, image_input_ids, past_length=0):
340 | input_shape = image_input_ids.size()
341 | row_ids = torch.arange(past_length, input_shape[-1] + past_length,
342 | dtype=torch.long, device=self.device) // self.image_tokens_per_dim
343 | row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1])
344 | col_ids = torch.arange(past_length, input_shape[-1] + past_length,
345 | dtype=torch.long, device=self.device) % self.image_tokens_per_dim
346 | col_ids = col_ids.unsqueeze(0).view(-1, input_shape[-1])
347 | return self.image_row_embeddings(row_ids) + self.image_col_embeddings(col_ids)
348 |
349 | def forward(self, text_tokens, seg_tokens, img_tokens):
350 | text_range = torch.arange(self.text_length)
351 | text_range += (self.text_vocab_size - self.text_length)
352 | text_range = text_range.to(self.device)
353 |         text_tokens = torch.where(text_tokens == 0, text_range, text_tokens)  # map zero tokens to unique per-position ids at the end of the text vocab
354 | text_pos = self.text_pos_embeddings(torch.arange(text_tokens.shape[1], device=self.device))
355 | text_embeddings = self.text_token_embedding(text_tokens) + text_pos
356 |
357 | seg_pos = self.get_seg_pos_embeddings(seg_tokens)
358 | seg_embeddings = self.seg_token_embedding(seg_tokens) + seg_pos
359 |
360 | embeddings = torch.cat((text_embeddings, seg_embeddings), dim=1)
361 | if img_tokens is not None:
362 | img_pos = self.get_image_pos_embeddings(img_tokens)
363 | image_embeddings = self.image_token_embedding(img_tokens) + img_pos
364 | embeddings = torch.cat((embeddings, image_embeddings), dim=1)
365 |
366 | attention_mask = torch.tril(
367 | torch.ones((embeddings.shape[0], 1, self.total_length, self.total_length), device=self.device)
368 | )
369 |         attention_mask[:, :, :-self.image_length, :-self.image_length] = 1  # text+seg prefix attends bidirectionally; only image tokens stay causal
370 | attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]]
371 |
372 | transformer_output, present_cache = self.transformer(
373 | embeddings, attention_mask,
374 | cache=None, use_cache=False
375 | )
376 |
377 | logits = self.to_logits(transformer_output)
378 |         return logits[:, -self.image_length - 1:-1, :]  # logits at the positions that predict each image token
379 |
380 |
381 | if __name__ == '__main__':
382 |     # quick smoke test for the MakeAScene transformer
383 | batch_size = 2
384 |
385 | num_layers = 2
386 | hidden_dim = 64
387 | num_attn_heads = 8
388 |
389 | text_length = 128
390 | seg_per_dim = 16
391 | image_per_dim = 32
392 |
393 | text_vocab_size = 128
394 | seg_vocab_size = 128
395 | image_vocab_size = 128
396 |
397 | model = MakeAScene(num_layers, hidden_dim, num_attn_heads, image_vocab_size, seg_vocab_size, text_vocab_size+text_length, image_per_dim, seg_per_dim, text_length).to('cuda')
398 |     model.device = torch.device('cuda')  # forward() and the pos-embedding helpers read self.device, which is set here
399 | pred_logits = model(torch.randint(high=text_vocab_size+text_length, size=(batch_size, text_length)).cuda(),
400 | torch.randint(high=seg_vocab_size, size=(batch_size, seg_per_dim**2)).cuda(),
401 | torch.randint(high=image_vocab_size, size=(batch_size, image_per_dim**2)).cuda())
402 |
403 | assert pred_logits.shape == torch.Size([batch_size, image_per_dim ** 2, image_vocab_size])
404 |
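A minimal standalone sketch (not part of the repository) of the attention-mask layout that MakeAScene.forward builds above: the text and segmentation prefix attends bidirectionally, while the image tokens remain causal. The sizes below are made up for illustration.

    import torch

    text_length, seg_length, image_length = 2, 2, 3   # toy sizes
    total_length = text_length + seg_length + image_length

    mask = torch.tril(torch.ones(1, 1, total_length, total_length))
    # un-mask the text+seg prefix; only the image block stays lower-triangular
    mask[:, :, :-image_length, :-image_length] = 1
    print(mask[0, 0])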
--------------------------------------------------------------------------------
/models/vqvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from .modules import Encoder, Decoder
5 | from .modules import Codebook
6 |
7 |
8 | class VQBASE(nn.Module):
9 | def __init__(self, ddconfig, n_embed, embed_dim, init_steps, reservoir_size):
10 | super(VQBASE, self).__init__()
11 | self.encoder = Encoder(**ddconfig)
12 | self.decoder = Decoder(**ddconfig)
13 | self.quantize = Codebook(n_embed, embed_dim, beta=0.25, init_steps=init_steps, reservoir_size=reservoir_size) # TODO: change length_one_epoch
14 | self.quant_conv = nn.Sequential(
15 | nn.Conv2d(ddconfig["z_channels"], embed_dim, 1),
16 | nn.SyncBatchNorm(embed_dim)
17 | )
18 | self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
19 |
20 | def encode(self, x):
21 | h = self.encoder(x)
22 | h = self.quant_conv(h)
23 | quant, emb_loss, info = self.quantize(h)
24 | return quant, emb_loss
25 |
26 | def decode(self, quant):
27 | quant = self.post_quant_conv(quant)
28 | dec = self.decoder(quant)
29 | return dec
30 |
31 | def decode_code(self, code_b):
32 | quant_b = self.quantize.embed_code(code_b)
33 | dec = self.decode(quant_b)
34 | return dec
35 |
36 | def forward(self, input):
37 | quant, diff = self.encode(input)
38 | dec = self.decode(quant)
39 | return dec, diff
40 |
41 |
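For intuition, here is a standalone sketch of what the quantize step does conceptually, in the spirit of standard VQ-VAE codebooks rather than the repository's Codebook class: nearest-codebook-entry lookup plus a straight-through estimator, producing the discrete code ids that the transformer stage consumes.

    import torch

    n_embed, embed_dim = 8, 4
    codebook = torch.randn(n_embed, embed_dim)            # learned embeddings in practice
    z = torch.randn(2, embed_dim, 16, 16)                 # encoder output (B, C, H, W)
    flat = z.permute(0, 2, 3, 1).reshape(-1, embed_dim)
    codes = torch.cdist(flat, codebook).argmin(dim=1)     # nearest code id per spatial position
    quant = codebook[codes].view(2, 16, 16, embed_dim).permute(0, 3, 1, 2)
    quant = z + (quant - z).detach()                      # straight-through gradient to the encoder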
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from random import random
2 | import torch.nn.functional as F
3 | import torch
4 | import torch.distributed as dist
5 | from torch.utils.data import DataLoader
6 | from torch.nn.parallel import DistributedDataParallel
7 | import torchvision
8 | import torch.multiprocessing as mp
9 | from tqdm import tqdm
10 | import hydra
11 | from log_utils import Logger, Visualizer
12 | from utils import collate_fn, change_requires_grad
13 | import os
14 | from time import time
15 | import traceback
16 |
17 |
18 | def train(proc_id, cfg):
19 | print(proc_id)
20 | parallel = len(cfg.devices) > 1
21 | if parallel:
22 | torch.cuda.set_device(proc_id)
23 | torch.backends.cudnn.benchmark = True
24 | dist.init_process_group(backend="nccl", init_method="env://", world_size=len(cfg.devices), rank=proc_id)
25 | device = torch.device(proc_id)
26 | dataset = hydra.utils.instantiate(cfg.dataset, _recursive_=False)
27 | print(cfg.dataset)
28 | dataloader = DataLoader(dataset, **cfg.dataloader, collate_fn=collate_fn)
29 | model = hydra.utils.instantiate(cfg.model).to(device)
30 | loss_fn = hydra.utils.instantiate(cfg.loss).to(device)
31 | if parallel:
32 |         model = DistributedDataParallel(model, device_ids=[device], output_device=device)  # optionally: find_unused_parameters=True
33 | if "discriminator" in loss_fn._modules.keys():
34 |             loss_fn = DistributedDataParallel(loss_fn, device_ids=[device], output_device=device)
35 | logger = Logger(proc_id, device=device)
36 |
37 | if cfg.mode == "pretrain_segmentation":
38 | optim = torch.optim.Adam(model.parameters(), **cfg.optimizer)
39 |
40 |         # iterate the dataloader directly; only the segmentation maps are used in this stage
41 |         for step, data in enumerate(dataloader):
42 |             _, seg = data
43 | seg = seg.to(device)
44 | seg_rec, q_loss = model(seg)
45 | loss = loss_fn(q_loss, seg, seg_rec)
46 |
47 | if step % cfg.log_period == 0:
48 | logger.log(loss, q_loss, seg, seg_rec, step)
49 | torch.save(model.state_dict(), "checkpoint.pt")
50 |
51 | loss.backward()
52 | if step % cfg.accumulate_grad == 0:
53 | optim.step()
54 | optim.zero_grad()
55 |
56 | if step == cfg.total_steps:
57 | torch.save(model.state_dict(), "final.pt")
58 | return
59 |
60 | elif cfg.mode == "pretrain_image":
61 | vq_optim = torch.optim.Adam(model.parameters(), **cfg.optimizer.vq)
62 | for param_group in vq_optim.param_groups:
63 |             param_group["lr"] /= cfg.accumulate_grad
64 | disc_optim = torch.optim.Adam(loss_fn.module.discriminator.parameters(), **cfg.optimizer.disc)
65 | for param_group in disc_optim.param_groups:
66 |             param_group["lr"] /= cfg.accumulate_grad
67 |
68 | start = 0
69 | if cfg.resume:
70 | checkpoint = torch.load(cfg.checkpoint, map_location=device)
71 | model.module.load_state_dict(checkpoint["model"])
72 | loss_fn.module.discriminator.load_state_dict(checkpoint["discriminator"])
73 | vq_optim.load_state_dict(checkpoint["optim"])
74 | disc_optim.load_state_dict(checkpoint["disc_optim"])
75 | start = checkpoint["step"]
76 | model.module.quantize.q_counter = start
77 |
78 | pbar = tqdm(enumerate(dataloader, start=start), total=cfg.total_steps, initial=start) if proc_id == 0 else enumerate(dataloader, start=start)
79 | try:
80 | for step, data in pbar:
81 | img, _, bbox_objects, bbox_faces, _ = data
82 | img = img.to(device)
83 | bbox_objects = bbox_objects.to(device).to(torch.float32)
84 | img_rec, q_loss = model(img)
85 |
86 | change_requires_grad(model, False)
87 |                 d_loss, (d_loss_ema,) = loss_fn(optimizer_idx=1, global_step=step, images=img, reconstructions=img_rec)
88 | d_loss.backward()
89 | change_requires_grad(model, True)
90 |
91 | change_requires_grad(loss_fn.module.discriminator, False)
92 | loss, (nll_loss, face_loss, g_loss) = loss_fn(optimizer_idx=0, global_step=step, images=img,
93 | reconstructions=img_rec,
94 | codebook_loss=q_loss, bbox_obj=bbox_objects,
95 | bbox_face=bbox_faces,
96 | last_layer=model.module.decoder.model[-1])
97 | loss.backward()
98 | change_requires_grad(loss_fn.module.discriminator, True)
99 | if step % cfg.accumulate_grad == 0:
100 | disc_optim.step()
101 | disc_optim.zero_grad()
102 | vq_optim.step()
103 | vq_optim.zero_grad()
104 |
105 | if step % cfg.log_period == 0 and proc_id == 0:
106 | logger.log(loss=loss, q_loss=q_loss, img=img, img_rec=img_rec, d_loss=d_loss, nll_loss=nll_loss,
107 | face_loss=face_loss, g_loss=g_loss, d_loss_ema=d_loss_ema, step=step)
108 | if step % cfg.save_period == 0 and proc_id == 0:
109 | state = {
110 | "model": model.module.state_dict(),
111 | "discriminator": loss_fn.module.discriminator.state_dict(),
112 | "optim": vq_optim.state_dict(),
113 | "disc_optim": disc_optim.state_dict(),
114 | "step": step
115 | }
116 |                     torch.save(state, f"checkpoint_{int(step // 5e4)}.pt")
117 |
118 | if step == cfg.total_steps:
119 | state = {
120 | "model": model.module.state_dict(),
121 | "discriminator": loss_fn.module.discriminator.state_dict(),
122 | "optim": vq_optim.state_dict(),
123 | "disc_optim": disc_optim.state_dict(),
124 | "step": step
125 | }
126 | torch.save(state, "final.pt")
127 | return
128 | except Exception as e:
129 |             print('Caught exception in worker thread (proc_id = %d):' % proc_id)
130 | # This prints the type, value, and stack trace of the
131 | # current exception being handled.
132 | with open("error.log", "a") as f:
133 | traceback.print_exc(file=f)
134 | raise e
135 |
136 | elif cfg.mode == "train_transformer":
137 | optim = torch.optim.Adam(model.parameters(), **cfg.optimizer.stage2)
138 |
139 | pbar = tqdm(enumerate(dataloader), total=cfg.total_steps) if proc_id == 0 else enumerate(dataloader)
140 | try:
141 | for step, data in pbar:
142 | img_token, seg_token, _, _, text_token = data
143 | img_token = img_token.to(device)
144 | seg_token = seg_token.to(device)
145 | text_token = text_token.to(device)
146 |
147 | if step >= cfg.start_uncond and random() < cfg.uncond_p:
148 | text_token *= 0
149 |
150 | pred_logit = model(text_token, seg_token, img_token)
151 |
152 | loss = F.cross_entropy(pred_logit.view(-1, pred_logit.shape[-1]), img_token.view(-1))
153 | loss.backward()
154 | if step % cfg.accumulate_grad == 0:
155 | optim.step()
156 | optim.zero_grad()
157 |
158 | ### LOGGING PART
159 | if step % cfg.log_period == 0:
160 | logger.log(loss=loss, step=step)
161 | torch.save(model.module.state_dict(), "checkpoint.pt")
162 |
163 | if step == cfg.total_steps:
164 | torch.save(model.module.state_dict(), "final.pt")
165 | return
166 | except Exception as e:
167 |             print('Caught exception in worker thread (proc_id = %d):' % proc_id)
168 |
169 | # This prints the type, value, and stack trace of the
170 | # current exception being handled.
171 | with open("error.log", "a") as f:
172 | traceback.print_exc(file=f)
173 | raise e
174 |
175 | def visualize(cfg):
176 | device = torch.device(cfg.devices[0])
177 | model = torch.nn.DataParallel(hydra.utils.instantiate(cfg.model)).to(device)
178 | checkpoint = hydra.utils.to_absolute_path(cfg.checkpoint)
179 | state_dict = torch.load(checkpoint, map_location=device)
180 | model.load_state_dict(state_dict)
181 | model = model.module
182 | model.eval()
183 |     # segmentation input has 159 channels (133 panoptic + 20 human parts + 5 face + 1 edge)
184 | dataset = hydra.utils.instantiate(cfg.dataset)
185 | visualizer = Visualizer(device=device)
186 | print("Processing...")
187 | for i, data in enumerate(dataset):
188 | if i == 40:
189 | break
190 | print("Processing image ", i)
191 | img, seg = data
192 | img = img.to(device).unsqueeze(0)
193 | seg = seg.to(device).unsqueeze(0)
194 | seg_rec, _ = model(seg)
195 |         visualizer(i, image=img, seg=seg, seg_rec=seg_rec)
196 |
197 | print(model(dataset[0][1].unsqueeze(0).to(device))[0].shape)
198 |
199 |
200 | def preprocess_dataset(cfg):
201 | # dataset = hydra.utils.instantiate(cfg.dataset,)
202 | dataset = cfg.dataset
203 | preprocessor = hydra.utils.instantiate(cfg.preprocessor)
204 | preprocessor(dataset)
205 |
206 |
207 | @hydra.main(config_path="conf", config_name="img_config", version_base="1.2")
208 | def launch(cfg):
209 | if "pretrain" in cfg.mode:
210 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(d) for d in cfg.devices])
211 | cfg.checkpoint = hydra.utils.to_absolute_path(cfg.checkpoint)
212 | if len(cfg.devices) == 1:
213 | train(0, cfg)
214 | else:
215 | os.environ["MASTER_ADDR"] = "localhost"
216 | os.environ["MASTER_PORT"] = "33751"
217 |         mp.spawn(train, nprocs=len(cfg.devices), args=(cfg,))
218 | elif "show" in cfg.mode:
219 | visualize(cfg)
220 | elif "preprocess_dataset" in cfg.mode:
221 | preprocess_dataset(cfg)
222 |
223 |
224 | if __name__ == "__main__":
225 | launch()
226 |
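train() above dispatches on cfg.mode (pretrain_segmentation, pretrain_image, train_transformer), with devices, optimizer settings, and step counts coming from the Hydra configs in conf/. All three loops share the same gradient-accumulation pattern; below is a standalone sketch of that pattern with made-up sizes (pretrain_image additionally divides the learning rate by the accumulation factor, mirrored here).

    import torch
    import torch.nn.functional as F

    accumulate_grad = 4                                   # plays the role of cfg.accumulate_grad
    model = torch.nn.Linear(8, 1)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3 / accumulate_grad)

    for step in range(1, 17):
        x, y = torch.randn(2, 8), torch.randn(2, 1)
        loss = F.mse_loss(model(x), y)
        loss.backward()                                   # gradients keep accumulating
        if step % accumulate_grad == 0:
            optim.step()
            optim.zero_grad()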
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def collate_fn(batch, need_seg=False):  # batch items: (image, seg map HxWxC, object bboxes, face bboxes, caption)
5 | images = torch.stack([i[0] for i in batch], dim=0)
6 | if need_seg:
7 | segmentation_maps = torch.stack([torch.Tensor(i[1]).permute(2, 0, 1) for i in batch], dim=0)
8 | else:
9 | segmentation_maps = []
10 | object_boxes = [list(map(lambda bbox: list(map(int, bbox))[:-1], i[2])) for i in batch]
11 | face_boxes = [list(map(lambda bbox: list(map(int, bbox))[:-1], i[3])) for i in batch]
12 | captions = [i[4] for i in batch]
13 | return [images, segmentation_maps, object_boxes, face_boxes, captions]
14 |
15 |
16 | def collate_fn_(batch):
17 |     # variant of collate_fn that always stacks the segmentation maps
18 |     images = [i[0] for i in batch]
19 |     segmentation_maps = [torch.Tensor(i[1]).permute(2, 0, 1) for i in batch]
20 |     object_boxes = [i[2][:-1] for i in batch]
21 |     face_boxes = [i[3][:-1] for i in batch]
22 |     captions = [i[4] for i in batch]
23 |     return [torch.stack(images, dim=0), torch.stack(segmentation_maps, dim=0),
24 |             object_boxes, face_boxes, captions]
25 |
26 |
27 | def change_requires_grad(model, state):
28 | for parameter in model.parameters():
29 | parameter.requires_grad = state
30 |
31 |
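A hedged usage sketch for collate_fn with a toy dataset; the 5-tuple layout and the need_seg flag mirror the function above, but the field contents (channel counts, box format, caption) are illustrative assumptions.

    import torch
    from functools import partial
    from torch.utils.data import DataLoader, Dataset
    from utils import collate_fn

    class ToyDataset(Dataset):
        # hypothetical samples: (image, seg map HxWxC, object bboxes, face bboxes, caption)
        def __len__(self):
            return 4
        def __getitem__(self, idx):
            img = torch.rand(3, 64, 64)
            seg = torch.zeros(64, 64, 159)
            boxes = [[0, 0, 32, 32, 1.0]]                 # collate_fn drops the last entry of each box
            faces = [[8, 8, 24, 24, 1.0]]
            return img, seg, boxes, faces, "a caption"

    loader = DataLoader(ToyDataset(), batch_size=2,
                        collate_fn=partial(collate_fn, need_seg=True))
    images, segs, obj_boxes, face_boxes, captions = next(iter(loader))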
--------------------------------------------------------------------------------