├── .github └── FUNDING.yml ├── .gitignore ├── Data ├── README.md ├── __init__.py ├── dataset_preprocessor.py ├── dataset_preprocessor_web.py ├── preprocessor.py ├── preprocessor_web.py ├── preprocessors │ ├── __init__.py │ ├── detectron2_preprocessor.py │ ├── edge_extractor.py │ ├── face_alignment_preprocessor.py │ └── human_parts_preprocessor.py └── utils.py ├── LICENSE ├── README.md ├── conf ├── img_config.yaml ├── preprocess_data.yaml ├── preprocess_data_web.yaml ├── seg_config.yaml └── show.yaml ├── log_utils.py ├── losses ├── __init__.py ├── discriminator.py ├── face_loss.py ├── loss_img.py ├── loss_seg.py ├── lpips.py └── lpips_with_object.py ├── models ├── __init__.py ├── modules.py ├── transformer.py └── vqvae.py ├── train.py └── utils.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: casualganpapers 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | -------------------------------------------------------------------------------- /Data/README.md: -------------------------------------------------------------------------------- 1 | # Data Aggregation 2 |

3 | [figure: results] 4 |

5 | 6 | ## Segmentation Dataset 7 | "VQ-SEG and VQ-IMG are trained on CC12m, CC, and MS-COCO." -> For the segmentation 8 | process we first need to convert all the 3 datasets to segmentation datasets using 9 | the 3 models described below: 10 | 11 | - Panoptic: https://github.com/facebookresearch/detectron2 12 | - Human Parts: https://github.com/PeikeLi/Self-Correction-Human-Parsing 13 | - Human Face: https://github.com/1adrianb/face-alignment 14 | 15 | These 3 models will be used to construct the dataset for the segmentation maps. 16 | 17 | VQ-SEG was trained to have 158 categories 18 | - 133 panoptic 19 | - 20 human parts 20 | - 5 human face (eye-brows, eyes, nose, outer-mouth, inner-mouth) 21 | 22 | ## Data Pipeline: 23 | 1. Take in an image or dataset (HxWx3) 24 | 2. seg_panoptic = detectron2(x) 25 | 3. seg_human = human_parsing(x) 26 | 4. seg_face = human_face(x) 27 | 5. Concatenate along channel axis 28 | 6. Add one channel for edges between objects 29 | 7. return segmentation map (HxWx159) 30 | 31 | -------------------------------------------------------------------------------- /Data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_preprocessor import COCO2014Dataset 2 | from .dataset_preprocessor_web import PreprocessedWebDataset 3 | from warnings import warn 4 | try: 5 | from .preprocessor import BasePreprocessor 6 | from .preprocessor_web import WebPreprocessor 7 | except ModuleNotFoundError: 8 | #warn("Some dependencies missing for data preprocessing.") 9 | pass 10 | -------------------------------------------------------------------------------- /Data/dataset_preprocessor.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import os 4 | import cv2 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import Dataset 8 | from torch.utils.data import ConcatDataset as ConcatDataset_ 9 | from tqdm import tqdm 10 | import torch.multiprocessing as mp 11 | import warnings 12 | from torchvision import transforms 13 | from hydra.utils import instantiate 14 | import albumentations as A 15 | from albumentations.pytorch import ToTensorV2 16 | from urllib.request import urlretrieve 17 | 18 | 19 | class PreprocessedDataset(Dataset): 20 | def __init__( 21 | self, 22 | root=None, 23 | image_dirs=None, 24 | preprocessed_folder=None, 25 | ): 26 | self.image_dirs = image_dirs 27 | 28 | self.preprocessed_folder = preprocessed_folder 29 | self.preprocessed_path = os.path.join(preprocessed_folder, "segmentations", "%s_%s.npz", ) 30 | 31 | self.root = root 32 | self.transforms = A.Compose([ 33 | A.SmallestMaxSize(256), 34 | A.RandomCrop(256, 256, always_apply=True), 35 | # A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 36 | ToTensorV2(transpose_mask=True) 37 | ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=["class_labels"], min_area=100, min_visibility=0.2), 38 | additional_targets={"bboxes0": "bboxes"}) 39 | 40 | if not os.path.exists(preprocessed_folder): 41 | os.makedirs(preprocessed_folder) 42 | else: 43 | assert os.path.isdir(preprocessed_folder) 44 | 45 | self.img_names_path = os.path.join(preprocessed_folder, f"img_names_{self.name}.npz") 46 | if not os.path.exists(self.img_names_path): 47 | self.parse_image_names() 48 | 49 | img_list = np.load(self.img_names_path) 50 | self.img_names = img_list["img_names"] 51 | if "img_urls" in img_list: 52 | self.img_urls = img_list["img_urls"] 53 | 54 | def 
load_segmentation(self, idx): 55 | img_name = os.path.splitext(self.img_names[idx])[0] 56 | 57 | data_panoptic = np.load(self.preprocessed_path % (img_name, "panoptic")) 58 | data_human = np.load(self.preprocessed_path % (img_name, "human")) 59 | data_face = np.load(self.preprocessed_path % (img_name, "face")) 60 | 61 | # Panoptic 62 | seg_panoptic = F.one_hot( 63 | torch.from_numpy(data_panoptic["seg_panoptic"] + 1).to(torch.long), num_classes=134 64 | )[..., 1:] 65 | edges_panoptic = torch.from_numpy(data_panoptic["edges"]).unsqueeze(-1) 66 | box_thing = data_panoptic["box_things"] 67 | 68 | # Human parts 69 | seg_human = F.one_hot( 70 | torch.from_numpy(data_human["seg_human"] + 1).to(torch.long), num_classes=21 71 | )[..., 1:] 72 | edges_human = torch.from_numpy(data_human["edges"]).unsqueeze(-1) 73 | 74 | # Edges 75 | seg_edges = (edges_panoptic + edges_human).float() 76 | 77 | # Face 78 | seg_face = F.one_hot( 79 | torch.from_numpy(data_face["seg_face"]).to(torch.long), num_classes=6 80 | )[..., 1:] 81 | box_face = data_face["box_face"] 82 | 83 | # Concatenate masks 84 | seg_map = torch.cat( 85 | [seg_panoptic, seg_human, seg_face, seg_edges], dim=-1 86 | ) 87 | 88 | return np.array(seg_map), box_thing, box_face 89 | 90 | def __getitem__(self, idx): 91 | segmentation, box_thing, box_face = self.load_segmentation(idx) 92 | image, _ = self.get_image(idx) 93 | data = self.transforms(image=image, mask=segmentation, bboxes=box_thing, bboxes0=box_face, 94 | class_labels=np.zeros(box_thing.shape[0])) 95 | return data["image"], data["mask"], data["bboxes"], data["bboxes0"], self.img_names[idx] 96 | 97 | def get_image(self, idx): 98 | raise NotImplementedError 99 | 100 | def parse_image_names(self): 101 | raise NotImplementedError 102 | 103 | def __len__(self): 104 | return len(self.img_names) 105 | 106 | 107 | class BaseCOCODataset(PreprocessedDataset): 108 | def get_image(self, idx): 109 | img_name = self.img_names[idx] 110 | path = os.path.join(self.root, img_name) 111 | image = cv2.imread(path) 112 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 113 | return image, img_name 114 | 115 | def parse_image_names(self): 116 | img_names = [] 117 | for directory in self.image_dirs: 118 | for filename in os.listdir(os.path.join(self.root, directory)): 119 | if os.path.splitext(filename)[1] in [".jpg", ".png"]: 120 | img_names.append(os.path.join(directory, filename)) 121 | np.savez(self.img_names_path, img_names=img_names) 122 | 123 | 124 | class COCO2014Dataset(BaseCOCODataset): 125 | name = "coco2014" 126 | image_dirs = "train2014" 127 | 128 | def __init__(self, root, preprocessed_folder, **kwargs): 129 | super().__init__( 130 | root=root, 131 | image_dirs=["train2014"], 132 | preprocessed_folder=preprocessed_folder, 133 | **kwargs, 134 | ) 135 | 136 | 137 | class COCO2017Dataset(BaseCOCODataset): 138 | name = "coco2017" 139 | image_dirs = "train2017" 140 | 141 | def __init__(self, root, preprocessed_folder, **kwargs): 142 | super().__init__( 143 | root=root, 144 | image_dirs=["train2017"], 145 | preprocessed_folder=preprocessed_folder, 146 | **kwargs, 147 | ) 148 | 149 | 150 | class Conceptual12mDataset(PreprocessedDataset): 151 | name = "cc12m" 152 | 153 | def __init__(self, root, preprocessed_folder, **kwargs): 154 | super().__init__( 155 | root=root, 156 | **kwargs, 157 | ) 158 | 159 | def parse_image_names(self, listfile): 160 | img_names = [] 161 | img_urls = [] 162 | with open(listfile, "r") as urllist: 163 | for i, line in enumerate(urllist): 164 | url, caption = line.split("\t") 165 | 
caption = caption.strip() 166 | img_names.append(caption + ".jpg") 167 | np.savez(self.img_names_path, img_names=img_names, img_urls=img_urls) 168 | 169 | def get_image(self, idx): 170 | img_name = self.img_names[idx] 171 | path = os.path.join(self.root, img_name) 172 | if not os.path.exists(path): 173 | self.download_image(self.url[idx], img_name) 174 | image = cv2.imread(path) 175 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 176 | return image, img_name 177 | 178 | def download_image(self, url, image_name): 179 | try: 180 | image_path = os.path.join(self.root, image_name) 181 | urlretrieve(url, image_path) 182 | return True 183 | except HTTPError: 184 | print("Failed to download the image: ", image_name) 185 | return False 186 | 187 | 188 | class ConcatDataset(ConcatDataset_): 189 | def get_true_idx(self, idx): 190 | if idx < 0: 191 | if -idx > len(self): 192 | raise ValueError("absolute value of index should not exceed dataset length") 193 | idx = len(self) + idx 194 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 195 | if dataset_idx == 0: 196 | sample_idx = idx 197 | else: 198 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 199 | return dataset_idx, sample_idx 200 | 201 | def get_image(self, idx): 202 | dataset_idx, sample_idx = self.get_true_idx(idx) 203 | return self.datasets[dataset_idx].get_image(sample_idx) 204 | 205 | 206 | if __name__ == "__main__": 207 | coco = COCO2014Dataset( 208 | "./mydb", "./mydb/preprocessed" 209 | ) 210 | from torchvision.utils import draw_bounding_boxes 211 | import matplotlib.pyplot as plt 212 | 213 | img, _, ft, fb, _ = coco[0] 214 | plt.imshow(draw_bounding_boxes(img, torch.tensor(ft + fb)).permute(1, 2, 0)) 215 | plt.show() 216 | print() 217 | -------------------------------------------------------------------------------- /Data/dataset_preprocessor_web.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import os 4 | import cv2 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import Dataset 8 | from torch.utils.data import ConcatDataset as ConcatDataset_ 9 | from tqdm import tqdm 10 | import torch.multiprocessing as mp 11 | import warnings 12 | from torchvision import transforms 13 | from hydra.utils import instantiate 14 | import albumentations as A 15 | from albumentations.pytorch import ToTensorV2 16 | from .utils import check_bboxes 17 | from urllib.request import urlretrieve 18 | from webdataset import WebDataset 19 | from webdataset.shardlists import split_by_node 20 | from webdataset.handlers import warn_and_continue 21 | from itertools import islice 22 | 23 | def my_split_by_node(src, group=None): 24 | rank, world_size, = int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]) 25 | if world_size > 1: 26 | for s in islice(islice(src, (rank*2)//world_size, None, 2), rank%(world_size//2), None, world_size//2): 27 | yield s 28 | else: 29 | for s in src: 30 | yield s 31 | 32 | 33 | class PreprocessData: 34 | def __init__(self, ready_queue): 35 | self.transforms = A.Compose([ 36 | A.SmallestMaxSize(512), 37 | A.CenterCrop(512, 512, always_apply=True), 38 | ToTensorV2(transpose_mask=True) 39 | ]) 40 | self.lasttar = "no" 41 | self.ready_queue = ready_queue 42 | 43 | def __call__(self, data): 44 | result = self.transforms(image=data["jpg"]) 45 | data["jpg"] = result["image"] 46 | data["tarname"] = os.path.basename(data["__url__"]) 47 | if self.lasttar!=data["tarname"]: 48 | rank = os.environ["RANK"] 49 | 
proc_type = os.environ["TYPE"] 50 | worker_info = torch.utils.data.get_worker_info() 51 | if worker_info is not None: 52 | worker = worker_info.id 53 | else: 54 | worker = None 55 | 56 | if self.lasttar != "no": 57 | self.ready_queue.put("%s/%s/processed/%s" % (rank+"-"+proc_type, worker, self.lasttar)) 58 | print(self.lasttar, "processed!") 59 | self.ready_queue.put("%s/%s/started/%s" % (rank+"-"+proc_type, worker, data["tarname"])) 60 | self.lasttar = data["tarname"] 61 | return data 62 | 63 | class UnprocessedWebDataset(WebDataset): 64 | def __init__(self, root, *args, is_dir=False, ready_queue=None, **kwargs): 65 | if is_dir: 66 | shards = [os.path.join(root, filename) for filename in os.listdir(root) if os.path.splitext(filename)[1]==".tar"] 67 | shards.sort() 68 | self.basedir = root 69 | else: 70 | shards = root 71 | self.basedir = os.path.dirname(root) 72 | super().__init__(shards, *args, nodesplitter=my_split_by_node, handler=warn_and_continue, **kwargs) 73 | self.decode("rgb") 74 | self.map(PreprocessData(ready_queue)) 75 | self.to_tuple("__key__", "tarname", "jpg") 76 | 77 | 78 | class ProcessData: 79 | def __init__(self,): 80 | self.pretransforms = A.Compose([ 81 | A.SmallestMaxSize(512), 82 | A.CenterCrop(512, 512, always_apply=True), 83 | ]) 84 | self.transforms = A.Compose([ 85 | #A.SmallestMaxSize(512), 86 | #A.CenterCrop(512, 512, always_apply=True), 87 | #A.RandomCrop(256, 256, always_apply=True), 88 | # A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 89 | ToTensorV2(transpose_mask=True) 90 | ], bbox_params=A.BboxParams(format='pascal_voc', min_area=100, min_visibility=0.2), 91 | additional_targets={"bboxes0": "bboxes"}) 92 | 93 | def __call__(self, data): 94 | npz_data = data["npz"] 95 | # Panoptic 96 | seg_panoptic = F.one_hot( 97 | torch.from_numpy(npz_data["seg_panoptic"] + 1).to(torch.long), num_classes=134 98 | )[..., 1:] 99 | edges_panoptic = torch.from_numpy(npz_data["edge_panoptic"]).unsqueeze(-1) 100 | box_thing = npz_data["box_things"].tolist() 101 | for box in box_thing: 102 | box.append(0) 103 | 104 | # Human parts 105 | seg_human = F.one_hot( 106 | torch.from_numpy(npz_data["seg_human"] + 1).to(torch.long), num_classes=21 107 | )[..., 1:] 108 | edges_human = torch.from_numpy(npz_data["edge_human"]).unsqueeze(-1) 109 | 110 | # Edges 111 | seg_edges = (edges_panoptic + edges_human).float() 112 | 113 | # Face 114 | seg_face = F.one_hot( 115 | torch.from_numpy(npz_data["seg_face"]).to(torch.long), num_classes=6 116 | )[..., 1:] 117 | box_face = npz_data["box_face"].tolist() 118 | for box in box_face: 119 | box.append(0) 120 | 121 | # Concatenate masks 122 | seg_map = torch.cat( 123 | [seg_panoptic, seg_human, seg_face, seg_edges], dim=-1 124 | ).numpy() 125 | 126 | data["jpg"] = self.pretransforms(image=data["jpg"])["image"] 127 | box_thing = check_bboxes(box_thing) 128 | box_face = check_bboxes(box_face) 129 | transformed_data = self.transforms(image=data["jpg"], bboxes=box_thing, bboxes0=box_face,) 130 | data["jpg"] = transformed_data["image"] 131 | data["mask"] = seg_map 132 | data["box_things"] = transformed_data["bboxes"] 133 | data["box_face"] = transformed_data["bboxes0"] 134 | return data 135 | 136 | 137 | class PreprocessedWebDataset(WebDataset): 138 | def __init__(self, url, *args, **kwargs): 139 | super().__init__(url, *args, nodesplitter=split_by_node, handler=warn_and_continue, **kwargs) 140 | self.decode("rgb") 141 | #self.decode("npz") 142 | self.map(ProcessData(), handler=warn_and_continue) 143 | self.to_tuple("jpg", "mask", 
"box_things", "box_face", "txt", handler=warn_and_continue) 144 | 145 | class COCOWebDataset(PreprocessedWebDataset): 146 | def __init__(self, *args, **kwargs): 147 | super().__init__("pipe:aws s3 cp s3://s-mas/coco_processed/{00000..00010}.tar -", *args, **kwargs) 148 | 149 | class CC3MWebDataset(PreprocessedWebDataset): 150 | def __init__(self, *args, **kwargs): 151 | super().__init__("pipe:aws s3 cp s3://s-mas/cc3m_processed/{00000..00311}.tar -", *args, **kwargs) 152 | 153 | class S3ProcessedDataset(PreprocessedWebDataset): 154 | datasets = { 155 | "coco": "pipe:aws s3 cp s3://s-mas/coco_processed/{00000..00059}.tar -", 156 | "cc3m": "pipe:aws s3 cp s3://s-mas/cc3m_processed/{00000..00331}.tar -", 157 | "cc12m": "pipe:aws s3 cp s3://s-mas/cc12m_processed/{00000..01242}.tar -", 158 | "laion": "pipe:aws s3 cp s3://s-mas/laion_en_processed/{00000..01500}.tar -" 159 | } 160 | def __init__(self, names, *args, **kwargs): 161 | urls = [] 162 | for name in names: 163 | assert name in self.datasets, f"There is no processed dataset {name}" 164 | urls.append(self.datasets[name]) 165 | urls = "::".join(urls) 166 | super().__init__(urls, *args, **kwargs) 167 | 168 | if __name__ == "__main__": 169 | coco = COCO2014Dataset( 170 | "./mydb", "./mydb/preprocessed" 171 | ) 172 | from torchvision.utils import draw_bounding_boxes 173 | import matplotlib.pyplot as plt 174 | 175 | img, _, ft, fb, _ = coco[0] 176 | plt.imshow(draw_bounding_boxes(img, torch.tensor(ft + fb)).permute(1, 2, 0)) 177 | plt.show() 178 | print() 179 | 180 | -------------------------------------------------------------------------------- /Data/preprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.multiprocessing as mp 4 | from .preprocessors import Detectron2 5 | from .preprocessors import HumanParts 6 | from .preprocessors import HumanFace 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | 11 | class BasePreprocessor: 12 | proc_types = {"panoptic": Detectron2, "human": HumanParts, "face": HumanFace} 13 | 14 | def __init__( 15 | self, 16 | preprocessed_folder, 17 | proc_per_gpu=None, 18 | proc_per_cpu=None, 19 | devices=None, 20 | machine_idx=0, 21 | machines_total=1, 22 | ): 23 | self.idx = machine_idx 24 | self.machines_total = machines_total 25 | self.devices = list(devices) 26 | self.proc_per_gpu = proc_per_gpu 27 | self.proc_per_cpu = proc_per_cpu 28 | self.preprocessed_folder = preprocessed_folder 29 | self.preprocessed_path = os.path.join( 30 | preprocessed_folder, 31 | "segmentations", 32 | f"%s_%s.npz", 33 | ) 34 | self.log_path = os.path.join( 35 | preprocessed_folder, 36 | f"%s_%s.log", 37 | ) 38 | 39 | def __call__(self, dataset): 40 | self.dataset = dataset 41 | assert torch.cuda.is_available(), "GPU required for preprocessing" 42 | procs = [] 43 | mp.set_start_method("spawn") 44 | for proc_type in self.proc_per_gpu: 45 | devices = self.devices * self.proc_per_gpu[proc_type] 46 | n_cpus = self.proc_per_cpu[proc_type] 47 | proc_per_machine = len(devices) + n_cpus 48 | proc_total = proc_per_machine * self.machines_total 49 | # GPUs 50 | for proc_id, dev_id in enumerate(devices): 51 | p = mp.Process( 52 | target=self.preprocess_single_process, 53 | args=( 54 | proc_type, 55 | self.idx * proc_per_machine + proc_id, 56 | dev_id, 57 | proc_total, 58 | ), 59 | ) 60 | p.start() 61 | procs.append(p) 62 | # CPUs 63 | for proc_id in range(n_cpus): 64 | p = mp.Process( 65 | target=self.preprocess_single_process, 66 | args=( 67 | 
proc_type, 68 | self.idx * proc_per_machine + len(devices) + proc_id, 69 | "cpu", 70 | proc_total, 71 | ), 72 | ) 73 | p.start() 74 | procs.append(p) 75 | for proc in procs: 76 | proc.join() 77 | 78 | def preprocess_single_process(self, proc_type, proc_id, dev_id, proc_total): 79 | correct_names = [] 80 | if dev_id != "cpu": 81 | torch.cuda.set_device( # https://github.com/pytorch/pytorch/issues/21819#issuecomment-553310128 82 | dev_id 83 | ) 84 | device = f"cuda:{dev_id}" 85 | else: 86 | device = "cpu" 87 | processor = self.proc_types[proc_type](device=device) 88 | log_path = self.log_path % (proc_id, proc_type) 89 | self.check_path(log_path) 90 | with open(log_path, "w") as logfile: 91 | for idx in tqdm(range(len(self.dataset)), file=logfile): 92 | if idx % proc_total != proc_id: 93 | continue 94 | image, image_name = self.dataset.get_image(idx) 95 | image_name = os.path.splitext(image_name)[0] 96 | data = processor(image) 97 | save_path = self.preprocessed_path % (image_name, proc_type) 98 | self.check_path(save_path) 99 | np.savez(save_path, **data) 100 | 101 | def check_path(self, path): 102 | dirname = os.path.dirname(path) 103 | if not os.path.exists(dirname): 104 | try: 105 | os.makedirs(dirname) 106 | except FileExistsError: 107 | pass 108 | -------------------------------------------------------------------------------- /Data/preprocessor_web.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.multiprocessing as mp 4 | from torch.utils.data import DataLoader 5 | from .preprocessors import Detectron2 6 | from .preprocessors import HumanParts 7 | from .preprocessors import HumanFace 8 | from tqdm import tqdm 9 | import numpy as np 10 | import hydra 11 | from webdataset import WebDataset, TarWriter 12 | from webdataset.handlers import warn_and_continue 13 | from time import time, sleep 14 | import json 15 | import shutil 16 | import fsspec 17 | import queue 18 | 19 | 20 | class WebPreprocessor: 21 | proc_types = {"panoptic": Detectron2, "human": HumanParts, "face": HumanFace} 22 | 23 | def __init__( 24 | self, 25 | preprocessed_folder, 26 | output_folder, 27 | proc_per_gpu=None, 28 | proc_per_cpu=None, 29 | devices=None, 30 | machine_idx=0, 31 | machines_total=1, 32 | batch_size=5, 33 | num_workers=2 34 | ): 35 | self.idx = machine_idx 36 | self.machines_total = machines_total 37 | self.devices = list(devices) 38 | self.output_folder = output_folder 39 | self.proc_per_gpu = proc_per_gpu 40 | self.proc_per_cpu = proc_per_cpu 41 | self.batch_size = batch_size 42 | self.num_workers = num_workers 43 | self.preprocessed_folder = preprocessed_folder 44 | self.preprocessed_path = os.path.join( 45 | preprocessed_folder, 46 | "untars", 47 | f"%s/%s_%s.npz", 48 | ) 49 | self.repacked_path = os.path.join( 50 | preprocessed_folder, 51 | "tars", 52 | ) 53 | self.log_path = os.path.join( 54 | preprocessed_folder, 55 | f"%s\\%s_%s.log", 56 | ) 57 | 58 | def __call__(self, dataset): 59 | self.dataset = dataset 60 | assert torch.cuda.is_available(), "GPU required for preprocessing" 61 | procs = [] 62 | mp.set_start_method("spawn") 63 | ready_queue = mp.Queue() 64 | proc_type_locks = {proc_type: mp.Value("b", 0) for proc_type in self.proc_types} 65 | proc_per_machine = 0 66 | for value in self.proc_per_gpu: 67 | proc_per_machine += len(value) 68 | for proc_type in self.proc_per_gpu: 69 | devices = len(self.proc_per_gpu[proc_type]) 70 | n_cpus = self.proc_per_cpu[proc_type] 71 | proc_total_type = devices * 
self.machines_total 72 | # GPUs 73 | for proc_id, dev_id in enumerate(self.proc_per_gpu[proc_type]): 74 | p = mp.Process( 75 | target=self.preprocess_single_process, 76 | args=( 77 | proc_type, 78 | self.idx * devices + proc_id, 79 | dev_id, 80 | proc_total_type, 81 | ready_queue, 82 | proc_type_locks[proc_type] 83 | ), 84 | ) 85 | p.start() 86 | procs.append(p) 87 | # CPUs 88 | for proc_id in range(n_cpus): 89 | p = mp.Process( 90 | target=self.preprocess_single_process, 91 | args=( 92 | proc_type, 93 | self.idx * proc_per_machine + len(devices) + proc_id, 94 | "cpu", 95 | proc_total, 96 | ready_queue, 97 | ), 98 | ) 99 | p.start() 100 | procs.append(p) 101 | 102 | self.repacker_process(ready_queue, proc_per_machine, proc_type_locks) 103 | for proc in procs: 104 | proc.join() 105 | 106 | def preprocess_single_process(self, proc_type, proc_id, dev_id, proc_total, ready_queue, proc_type_lock): 107 | os.environ["RANK"] = str(proc_id) 108 | os.environ["TYPE"] = proc_type 109 | os.environ["WORLD_SIZE"] = str(proc_total) 110 | ready_queue.put("%s/%s/Init/%s" %(proc_id, proc_total, proc_type)) 111 | dataset = hydra.utils.instantiate(self.dataset, ready_queue=ready_queue) 112 | dataloader = DataLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) 113 | correct_names = [] 114 | torch.set_num_threads(15) 115 | if dev_id != "cpu": 116 | torch.cuda.set_device( # https://github.com/pytorch/pytorch/issues/21819#issuecomment-553310128 117 | dev_id 118 | ) 119 | device = f"cuda:{dev_id}" 120 | else: 121 | device = "cpu" 122 | processor = self.proc_types[proc_type](device=device) 123 | log_path = self.log_path % (proc_id, proc_total, proc_type) 124 | self.check_path(log_path) 125 | with open(log_path, "w") as logfile: 126 | x = 0 127 | t = time() 128 | times = [] 129 | for batch in tqdm(dataloader, file=logfile): 130 | if proc_type_lock.value: 131 | print("Waiting slower preprocessors ", proc_type) 132 | while proc_type_lock.value: 133 | sleep(0.1) 134 | print("Waiting released", proc_type) 135 | times.append(str(time()-t)) 136 | t = time() 137 | imgnames, tarnames, images = batch 138 | times.append(str(time()-t)) 139 | t = time() 140 | batched_data = processor(images*255.) 
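# `processor` here is one of Detectron2 / HumanParts / HumanFace. The webdataset "rgb" decoder yields
# float images in [0, 1], hence the *255. rescale above; `batched_data` maps output keys (segmentation
# masks, edges, bounding boxes) to per-image arrays, which are split per sample below and saved as one
# .npz file per (tarname, image, proc_type).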
141 | times.append(str(time()-t)) 142 | t = time() 143 | #print(proc_id, proc_total, imgnames, tarnames) 144 | #print(batched_data["box_things"]) 145 | for i in range(len(imgnames)): 146 | tarname, imgname = tarnames[i], imgnames[i] 147 | save_path = self.preprocessed_path % (tarname, imgname, proc_type) 148 | self.check_path(save_path) 149 | data = {key: batched_data[key][i] for key in batched_data} 150 | np.savez(save_path, **data) 151 | times.append(str(time()-t)) 152 | t = time() 153 | #ready_queue.put("%s/%s/info/%s" % (proc_id, proc_type, ",".join(times))) 154 | times = [] 155 | ready_queue.put("%s/%s/done/und" % (str(proc_id) +"-"+ proc_type, proc_type)) 156 | 157 | def repacker_process(self, ready_queue, proc_per_machine, proc_type_locks): 158 | proc_done = 0 159 | procs = [] 160 | max_repackings = 20 161 | repackings_queue = queue.Queue() 162 | repacking_done = mp.Queue() 163 | info = {"repackings": 0} 164 | progress = {"panoptic": 0, "human": 0, "face": 0} 165 | while proc_done < proc_per_machine: 166 | command = ready_queue.get() 167 | print("Got", command) 168 | proc_id, worker, state, tarname = command.split("/") 169 | if state == "done": 170 | proc_done += 1 171 | for worker in info[proc_id]: 172 | tarname = info[proc_id][worker] 173 | info[tarname] = 0 if tarname not in info else info[tarname] 174 | info[tarname] += 1 175 | if info[tarname] == 3: 176 | repackings_queue.put(tarname) 177 | #self.repack_single_tar(tarname) 178 | elif state == "started": 179 | if proc_id not in info: 180 | info[proc_id] = {} 181 | info[proc_id][worker] = tarname 182 | elif state == "processed": 183 | info[tarname] = 0 if tarname not in info else info[tarname] 184 | info[tarname] += 1 185 | proc_type = proc_id.split("-")[1] 186 | progress[proc_type] += 1 187 | if info[tarname] == 3: 188 | repackings_queue.put(tarname) 189 | #self.repack_single_tar(tarname) 190 | if progress[proc_type] - np.min(list(progress.values())) > 30: 191 | proc_type_locks[proc_type].value = 1 192 | if progress[proc_type] == np.min(list(progress.values())): 193 | for t in proc_type_locks: 194 | proc_type_locks[t].value = 0 195 | 196 | 197 | elif state== "info": 198 | continue 199 | 200 | while repacking_done.qsize() > 0: 201 | msg = repacking_done.get() 202 | with open("info.log", "a") as f: 203 | f.write(msg+"\n") 204 | info["repackings"] -= 1 205 | 206 | # Repack original and processed data to new tar 207 | while info["repackings"] < max_repackings and repackings_queue.qsize() > 0: 208 | tarname = repackings_queue.get() 209 | print("Started repacking", tarname) 210 | p = mp.Process( 211 | target=self.repack_single_tar, 212 | args=( 213 | tarname, 214 | repacking_done, 215 | ), 216 | ) 217 | p.start() 218 | procs.append(p) 219 | info["repackings"] += 1 220 | 221 | 222 | with open("info.state", "w") as f: 223 | line = json.dumps(info, indent=3, sort_keys=True) 224 | f.write(str(line)) 225 | with open("info.log", "a") as f: 226 | f.write(command+"\n") 227 | 228 | 229 | 230 | for p in procs: 231 | p.join() 232 | print("Processed all data!") 233 | 234 | def repack_single_tar(self, tarname, repacking_done): 235 | if os.path.isdir(self.dataset.root): 236 | root = self.datatset.root 237 | else: 238 | root = os.path.dirname(self.dataset.root) 239 | old_data = WebDataset(os.path.join(root, tarname), handler=warn_and_continue) 240 | output_folder = "s3://s-mas/" + self.output_folder 241 | fs, output_path = fsspec.core.url_to_fs(output_folder) 242 | tar_fd = fs.open(f"{output_path}/{tarname.split(' ')[0]}", "wb") 243 | new_data = 
TarWriter(tar_fd) 244 | for sample in old_data: 245 | new_sample = {} 246 | imgname = new_sample["__key__"] = sample["__key__"] 247 | new_sample["jpg"] = sample["jpg"] 248 | new_sample["txt"] = sample["txt"] 249 | 250 | data_face = np.load(self.preprocessed_path % (tarname, imgname, "face"), allow_pickle=True) 251 | data_human = np.load(self.preprocessed_path % (tarname, imgname, "human")) 252 | data_panoptic = np.load(self.preprocessed_path % (tarname, imgname, "panoptic")) 253 | 254 | data_merged = {} 255 | data_merged["seg_panoptic"] = data_panoptic["seg_panoptic"] 256 | data_merged["edge_panoptic"] = data_panoptic["edges"] 257 | data_merged["box_things"] = data_panoptic["box_things"] 258 | data_merged["seg_human"] = data_human["seg_human"] 259 | data_merged["edge_human"] = data_human["edges"] 260 | data_merged["seg_face"] = data_face["seg_face"] 261 | data_merged["box_face"] = data_face["box_face"] 262 | new_sample["npz"] = data_merged 263 | 264 | new_data.write(new_sample) 265 | print("Finished repacking", tarname) 266 | shutil.rmtree(os.path.join(self.preprocessed_folder, "untars", tarname)) 267 | new_data.close() 268 | repacking_done.put("Finished repacking "+ tarname) 269 | 270 | 271 | 272 | def check_path(self, path): 273 | dirname = os.path.dirname(path) 274 | if not os.path.exists(dirname): 275 | try: 276 | os.makedirs(dirname) 277 | except FileExistsError: 278 | pass 279 | -------------------------------------------------------------------------------- /Data/preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | from .detectron2_preprocessor import PanopticPreprocesor as Detectron2 2 | from .human_parts_preprocessor import HumanPartsPreprocessor as HumanParts 3 | from .face_alignment_preprocessor import FaceAlignmentPreprocessor as HumanFace 4 | from .edge_extractor import get_edges 5 | -------------------------------------------------------------------------------- /Data/preprocessors/detectron2_preprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import torch.nn.functional as F 5 | # from torchvision.ops import masks_to_boxes 6 | import numpy as np 7 | import detectron2 8 | from detectron2.projects.panoptic_deeplab import add_panoptic_deeplab_config 9 | from detectron2.engine import DefaultPredictor 10 | from detectron2.checkpoint import DetectionCheckpointer 11 | from detectron2.config import get_cfg 12 | from detectron2.modeling import build_model 13 | from .edge_extractor import get_edges 14 | 15 | 16 | def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: 17 | if masks.numel() == 0: 18 | return torch.zeros((0, 4), device=masks.device, dtype=torch.float) 19 | n = masks.shape[0] 20 | bounding_boxes = torch.zeros((n, 4), device=masks.device, dtype=torch.float) 21 | for index, mask in enumerate(masks): 22 | y, x = torch.where(mask != 0) 23 | 24 | bounding_boxes[index, 0] = torch.min(x) 25 | bounding_boxes[index, 1] = torch.min(y) 26 | bounding_boxes[index, 2] = torch.max(x) 27 | bounding_boxes[index, 3] = torch.max(y) 28 | 29 | return bounding_boxes 30 | 31 | 32 | class Predictor: 33 | def __init__(self, cfg): 34 | self.cfg = cfg.clone() 35 | self.model = build_model(self.cfg) 36 | self.model.eval() 37 | checkpointer = DetectionCheckpointer(self.model) 38 | checkpointer.load(cfg.MODEL.WEIGHTS) 39 | self.input_format = cfg.INPUT.FORMAT 40 | 41 | def __call__(self, imgs: np.array): 42 | # imgs should be numpy b x c x h x w 43 
| with torch.no_grad(): 44 | #imgs = torch.as_tensor(imgs.astype("float32")) 45 | if self.input_format == "RGB": 46 | # whether the model expects BGR inputs or RGB 47 | imgs = imgs.flip([1]) 48 | 49 | height, width = imgs.shape[2:] 50 | 51 | inputs = [{"image": image, "height": height, "width": width} for image in imgs] 52 | predictions = self.model(inputs) 53 | return predictions 54 | 55 | 56 | class PanopticPreprocesor: 57 | proc_type = "panoptic" 58 | def __init__( 59 | self, 60 | config="/home/ubuntu/anaconda3/envs/schp/lib/python3.8/site-packages/detectron2/projects/panoptic_deeplab/configs/COCO-PanopticSegmentation/panoptic_deeplab_R_52_os16_mg124_poly_200k_bs64_crop_640_640_coco_dsconv.yaml", 61 | num_classes=133, 62 | model_weights="/home/ubuntu/anaconda3/envs/schp/lib/python3.8/site-packages/detectron2/projects/panoptic_deeplab/configs/COCO-PanopticSegmentation/model_final_5e6da2.pkl", 63 | device=None, 64 | ): 65 | self.num_classes = num_classes 66 | 67 | cfg = get_cfg() 68 | print(cfg.MODEL.DEVICE) 69 | opts = ["MODEL.WEIGHTS", model_weights] 70 | add_panoptic_deeplab_config(cfg) 71 | cfg.merge_from_file(config) 72 | cfg.merge_from_list(opts) 73 | if device is not None: 74 | cfg.MODEL.DEVICE = device 75 | cfg.freeze() 76 | print("Building model") 77 | self.predictor = Predictor(cfg) 78 | 79 | def bounding_boxes(self, panoptics): 80 | all_boxes = [] 81 | for panoptic in panoptics: 82 | #panoptic = torch.Tensor(panoptic) 83 | obj_ids = torch.unique(panoptic) 84 | thing_ids = obj_ids[obj_ids/1000 < 80] # according to panopticapi first 80 classes are "things" 85 | binary_masks = panoptic == thing_ids[:, None, None] 86 | boxes = masks_to_boxes(binary_masks) 87 | all_boxes.append(boxes.cpu().numpy()) 88 | return all_boxes 89 | 90 | def __call__(self, imgs: np.array): 91 | # imgs should be numpy b x c x h x w 92 | data = {} 93 | # Returns tensor of shape [H, W] with values equal to 1000*class_id + instance_idx 94 | # panoptic = self.predictor(imgs)["panoptic_seg"][0].cpu() 95 | panoptic = self.predictor(imgs) 96 | panoptic = list(map(lambda pan: pan["panoptic_seg"][0], panoptic)) 97 | bounding_boxes = self.bounding_boxes(panoptic) 98 | panoptic = list(map(lambda pan: pan.cpu().numpy(), panoptic)) 99 | panoptic = np.array(panoptic) 100 | edges = get_edges(panoptic) 101 | data["seg_panoptic"] = np.array(panoptic // 1000, dtype=np.uint8) 102 | data["box_things"] = bounding_boxes 103 | data["edges"] = edges.astype(bool) 104 | return data 105 | 106 | 107 | if __name__ == "__main__": 108 | # img = cv2.imread("humans.jpg") 109 | img = np.random.randint(0, 255, (5, 3, 300, 300), dtype=np.uint8) 110 | detectron2 = PanopticPreprocesor() 111 | panoptic = detectron2(img) 112 | print(panoptic) 113 | # torch.save(panoptic, "test.pth") 114 | -------------------------------------------------------------------------------- /Data/preprocessors/edge_extractor.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | THICKNESS = 1 5 | 6 | 7 | def get_edges(masks): 8 | #!face contours are not used 9 | all_edges = np.zeros(masks.shape) 10 | for i, mask in enumerate(masks): 11 | edges = np.zeros(masks[0].shape) 12 | contours, _ = cv2.findContours( 13 | mask, cv2.RETR_FLOODFILL, cv2.CHAIN_APPROX_SIMPLE 14 | ) 15 | edges = cv2.drawContours(edges, contours, -1, 1, THICKNESS) 16 | all_edges[i] = edges 17 | return all_edges 18 | -------------------------------------------------------------------------------- 
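Note: the edge map returned by get_edges above ends up as the single extra channel of the final segmentation map described in Data/README.md (133 panoptic + 20 human-parts + 5 face + 1 edge = 159). The following is a minimal, self-contained sketch of that assembly step, mirroring `load_segmentation` in Data/dataset_preprocessor.py; the exact label/background conventions of the saved .npz files may differ, so the value ranges used here are illustrative assumptions.

import numpy as np
import torch
import torch.nn.functional as F

def assemble_seg_map(seg_panoptic, seg_human, seg_face, edges_panoptic, edges_human):
    # seg_panoptic / seg_human: (H, W) int label maps, assumed to use -1 for "no label" (hence the +1 shift);
    # seg_face: (H, W) int label map with 0 as background; edges_*: (H, W) bool boundary maps.
    panoptic = F.one_hot(torch.from_numpy(seg_panoptic + 1).long(), num_classes=134)[..., 1:]  # 133 channels
    human = F.one_hot(torch.from_numpy(seg_human + 1).long(), num_classes=21)[..., 1:]         # 20 channels
    face = F.one_hot(torch.from_numpy(seg_face).long(), num_classes=6)[..., 1:]                # 5 channels
    edges = torch.from_numpy(np.logical_or(edges_panoptic, edges_human)).unsqueeze(-1).long()  # 1 edge channel
    return torch.cat([panoptic, human, face, edges], dim=-1).float()                           # (H, W, 159)

if __name__ == "__main__":
    h, w = 64, 64
    seg_map = assemble_seg_map(
        np.random.randint(-1, 133, (h, w)),
        np.random.randint(-1, 20, (h, w)),
        np.random.randint(0, 6, (h, w)),
        np.zeros((h, w), dtype=bool),
        np.zeros((h, w), dtype=bool),
    )
    print(seg_map.shape)  # torch.Size([64, 64, 159])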
/Data/preprocessors/face_alignment_preprocessor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | import face_alignment 5 | import ssl 6 | import matplotlib.pyplot as plt 7 | ssl._create_default_https_context = ssl._create_unverified_context 8 | from time import time 9 | 10 | BEARD = 0 11 | BROW = 1 12 | NOSE = 2 13 | EYE = 3 14 | MOUTH = 4 15 | 16 | 17 | class FaceAlignmentPreprocessor: 18 | proc_type = "face" 19 | last_beard = 17 20 | last_brow = 27 21 | last_nose = 36 22 | last_eye = 48 23 | last_mouth = 68 24 | 25 | def __init__(self, n_classes=5, face_confidence=0.95, device="cuda"): 26 | self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, face_detector_kwargs={"filter_threshold":face_confidence}, device=device) 27 | self.face_confidence = face_confidence 28 | self.n_classes = n_classes 29 | self.class_idxs = { 30 | BEARD: (0, self.last_beard), 31 | BROW: (self.last_beard, self.last_brow), 32 | NOSE: (self.last_brow, self.last_nose), 33 | EYE: (self.last_nose, self.last_eye), 34 | MOUTH: (self.last_eye, self.last_mouth), 35 | } 36 | #self.class_idxs = { 37 | # BEARD: torch.arange(0, self.last_beard), 38 | # BROW: torch.arange(self.last_beard, self.last_brow), 39 | # NOSE: torch.arange(self.last_brow, self.last_nose), 40 | # EYE: torch.arange(self.last_nose, self.last_eye), 41 | # MOUTH: torch.arange(self.last_eye, self.last_mouth), 42 | #} 43 | 44 | 45 | def process_image(self, img): 46 | img = img[:, :, ::-1] # face_alignment work with BGR colorspace 47 | points = self.fa.get_landmarks(img) 48 | seg_mask = torch.zeros(*img.shape[:-1]) 49 | if points is None: 50 | return seg_mask 51 | for face in points: 52 | face = face.astype(int) 53 | for class_id in range(self.n_classes): 54 | for point_id in self.class_idxs[class_id]: 55 | try: 56 | seg_mask[face[point_id, 1], face[point_id, 0]] = class_id + 1 57 | except IndexError: 58 | # Probably only part of the face on the image 59 | pass 60 | return seg_mask # F.one_hot(seg_mask.to(torch.long), num_classes=6)[..., 1:].permute(2, 0, 1) 61 | 62 | def interpolate_face(self, face): 63 | interpolation = [] 64 | for class_id in range(self.n_classes): 65 | part_interpolation = [] 66 | part = face[self.class_idxs[class_id]] 67 | for idx, (i, j) in enumerate(zip(part, part[1:])): 68 | if self.class_idxs[class_id][idx] in (21, 41): # to avoid that both eyes (or both brows) are connected 69 | continue 70 | # print(self.class_idxs[class_id][idx]) 71 | part_interpolation.extend( 72 | list(np.round(np.linspace(i, j, 100)).astype(np.int32)) + 73 | list(np.round(np.linspace(i, j, 100)).astype(np.int32) + [0, 1]) + 74 | list(np.round(np.linspace(i, j, 100)).astype(np.int32) + [0, -1]) + 75 | list(np.round(np.linspace(i, j, 100)).astype(np.int32) + [1, 0]) + 76 | list(np.round(np.linspace(i, j, 100)).astype(np.int32) + [-1, 0]) 77 | ) 78 | interpolation.append(part_interpolation) 79 | return interpolation 80 | def process_image_interpolated(self, imgs: np.array): 81 | # imgs should be numpy b x c x h x w 82 | imgs = imgs.flip([1]) # face_alignment works with BGR colorspace 83 | faces = self.fa.face_detector.detect_from_batch(imgs) 84 | # faces = list(filter(lambda face: face[-1] > self.face_confidence, faces)) 85 | faces = list(map(lambda img: list(filter(lambda face: face[-1] > self.face_confidence, img)), faces)) 86 | batched_points = self.fa.get_landmarks_from_batch(imgs, detected_faces=faces) 87 | seg_mask = np.zeros((imgs.shape[0], 
*imgs.shape[2:])) 88 | if batched_points is None: 89 | return seg_mask 90 | for i, points in enumerate(batched_points): 91 | for face in points: 92 | face = self.interpolate_face(face.astype(int)) 93 | for class_id in range(self.n_classes): 94 | for point in face[class_id]: 95 | try: 96 | seg_mask[i, point[1], point[0]] = class_id + 1 97 | except IndexError as e: 98 | # Probably only part of the face on the image 99 | pass 100 | boxes = [[face[:-1] for face in faces_in_image] for faces_in_image in faces] 101 | return seg_mask, boxes # F.one_hot(seg_mask.to(torch.long), num_classes=6)[..., 1:].permute(2, 0, 1) 102 | 103 | def draw_interpolated_face(self,mask, face): 104 | for class_id in range(self.n_classes): 105 | start, stop = self.class_idxs[class_id] 106 | if class_id not in (EYE, BROW): 107 | cv2.drawContours(mask, [face[start: stop]], 0 , (class_id+1), 1) 108 | else: 109 | step = (stop-start)//2 110 | cv2.drawContours(mask, [face[start:start+step]], 0, (class_id+1), 1) 111 | cv2.drawContours(mask, [face[start+step: stop]], 0, (class_id+1), 1) 112 | return mask 113 | 114 | def process_image_interpolated_fast(self, imgs: np.array): 115 | # imgs should be numpy b x c x h x w 116 | t = time() 117 | times = [] 118 | imgs = imgs.flip([1]) # face_alignment works with BGR colorspace 119 | faces = self.fa.face_detector.detect_from_batch(imgs) 120 | times.append([time()-t]) 121 | t = time() 122 | batched_points = self.fa.get_landmarks_from_batch(imgs, detected_faces=faces) 123 | seg_mask = np.zeros((imgs.shape[0], *imgs.shape[2:])) 124 | times.append([time()-t]) 125 | t = time() 126 | for image_idx, points in enumerate(batched_points): 127 | for face in points: 128 | seg_mask[image_idx] = self.draw_interpolated_face(seg_mask[image_idx], face.astype(int)) 129 | boxes = [[face[:-1] for face in faces_in_image] for faces_in_image in faces] 130 | times.append([time()-t]) 131 | t = time() 132 | #print("HERE!!!", times) 133 | return seg_mask, boxes # F.one_hot(seg_mask.to(torch.long), num_classes=6)[..., 1:].permute(2, 0, 1) 134 | 135 | def process_image_interpolated_old(self, img): 136 | img = img[:, :, ::-1] # face_alignment work with BGR colorspace 137 | faces = self.fa.face_detector.detect_from_image(img.copy()) 138 | faces = list(filter(lambda face: face[-1] > self.face_confidence, faces)) 139 | points = self.fa.get_landmarks(img, detected_faces=faces) 140 | seg_mask = np.zeros(img.shape[:-1]) 141 | if points is None: 142 | return seg_mask 143 | for face in points: 144 | face = self.interpolate_face(face.astype(int)) 145 | for class_id in range(self.n_classes): 146 | for point in face[class_id]: 147 | try: 148 | seg_mask[point[1], point[0]] = class_id + 1 149 | except IndexError as e: 150 | # Probably only part of the face on the image 151 | pass 152 | boxes = [face[:-1] for face in faces] 153 | return seg_mask, boxes # F.one_hot(seg_mask.to(torch.long), num_classes=6)[..., 1:].permute(2, 0, 1) 154 | 155 | def plot_face(self, seg_mask: torch.Tensor): 156 | plt.imshow(seg_mask.clamp(0, 1).detach().cpu().numpy(), cmap="gray") 157 | plt.show() 158 | 159 | def __call__(self, img): 160 | data = {} 161 | mask, boxes = self.process_image_interpolated_fast(img) 162 | # mask, boxes = self.process_image_interpolated_old(img) 163 | data["seg_face"] = mask.astype(np.uint8) 164 | data["box_face"] = boxes 165 | return data 166 | 167 | 168 | if __name__ == "__main__": 169 | face_alignment_preprocessor = FaceAlignmentPreprocessor() 170 | img = cv2.imread("humans.jpg") # cv2 has other order of channels. 
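# The preprocessors expect batched arrays in (B, C, H, W) layout, so this demo transposes the single
# test image and tiles it five times before calling the preprocessor.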
171 | print(img.shape) 172 | img = np.repeat(img.transpose(2, 0, 1)[None, ...], 5, axis=0) 173 | data = face_alignment_preprocessor(img) 174 | #face_alignment_preprocessor.plot_face(alignment) 175 | print(data["seg_face"].shape) 176 | print(len(data["box_face"])) 177 | print(data["box_face"]) 178 | # torch.save(alignment, "alignment.pth") 179 | -------------------------------------------------------------------------------- /Data/preprocessors/human_parts_preprocessor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | import torchvision.transforms as transforms 7 | from torchgeometry import warp_affine 8 | 9 | HUMAN_PARSER_DIR = "/home/ubuntu/MakeAScene/Self-Correction-Human-Parsing" 10 | sys.path.append(HUMAN_PARSER_DIR) 11 | import networks 12 | from simple_extractor import get_palette, dataset_settings 13 | from collections import OrderedDict 14 | from utils.transforms import get_affine_transform 15 | from .edge_extractor import get_edges 16 | import torchvision 17 | 18 | 19 | def transform_logits(logits, center, scale, width, height, input_size): 20 | trans = torch.Tensor(get_affine_transform(center, scale, 0, input_size, inv=1)).expand(logits.shape[0], -1, -1) 21 | target_logits = warp_affine(logits, trans, (int(width), int(height))) 22 | return target_logits 23 | 24 | 25 | class HumanPartsPreprocessor: 26 | proc_type = "human" 27 | def __init__( 28 | self, 29 | weights=HUMAN_PARSER_DIR + "/checkpoints/final.pth", 30 | device="cuda", 31 | ): 32 | self.device = device 33 | 34 | dataset_info = dataset_settings["lip"] 35 | self.num_classes = dataset_info["num_classes"] 36 | self.input_size = dataset_info["input_size"] 37 | self.model = networks.init_model( 38 | "resnet101", num_classes=self.num_classes, pretrained=None 39 | ) 40 | self.aspect_ratio = self.input_size[1] * 1.0 / self.input_size[0] 41 | 42 | state_dict = torch.load(weights)["state_dict"] 43 | new_state_dict = OrderedDict() 44 | for k, v in state_dict.items(): 45 | name = k[7:] # remove `module.` 46 | new_state_dict[name] = v 47 | self.model.load_state_dict(new_state_dict) 48 | self.model.eval() 49 | self.model.to(device) 50 | 51 | self.transform = transforms.Compose( 52 | [ 53 | # transforms.ToTensor(), 54 | transforms.Normalize( 55 | mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229] 56 | ), 57 | ] 58 | ) 59 | self.upsample = torch.nn.Upsample( 60 | size=self.input_size, mode="bilinear", align_corners=True 61 | ) 62 | 63 | def _box2cs(self, box): 64 | x, y, w, h = box[:4] 65 | return self._xywh2cs(x, y, w, h) 66 | 67 | def _xywh2cs(self, x, y, w, h): 68 | center = np.zeros((2), dtype=np.float32) 69 | center[0] = x + w * 0.5 70 | center[1] = y + h * 0.5 71 | if w > self.aspect_ratio * h: 72 | h = w * 1.0 / self.aspect_ratio 73 | elif w < self.aspect_ratio * h: 74 | w = h * self.aspect_ratio 75 | scale = np.array([w, h], dtype=np.float32) 76 | return center, scale 77 | 78 | def segment_image(self, imgs: np.array): 79 | # imgs should be numpy b x c x h x w 80 | #imgs = torch.Tensor(imgs) 81 | b, _, h, w = imgs.shape # we can assume b x 3 x 512 x 512 82 | # check if h, w in correct order 83 | # Get person center and scale 84 | #person_center, scale = self._box2cs([0, 0, w - 1, h - 1]) 85 | #c = person_center 86 | #s = scale 87 | #r = 0 88 | #trans = torch.Tensor(get_affine_transform(person_center, s, r, self.input_size)).expand(b, -1, -1) 89 | #imgs = warp_affine(imgs, trans, 
(int(self.input_size[1]), int(self.input_size[0]))) 90 | imgs = torchvision.transforms.functional.resize(imgs, [self.input_size[1], self.input_size[0]]) 91 | 92 | imgs = self.transform(imgs/255.).to(self.device) 93 | 94 | with torch.no_grad(): 95 | output = self.model(imgs) 96 | upsample_output = self.upsample(output[0][-1]) # reshapes from 1, 20, 119, 119 to 1, 20, 473, 473 97 | 98 | #logits_result = transform_logits(upsample_output.cpu(), c, s, w, h, input_size=self.input_size) 99 | logits_result = torchvision.transforms.functional.resize(upsample_output, [h, w]) 100 | mask = logits_result.argmax(dim=1) 101 | return mask.cpu().numpy() 102 | 103 | def __call__(self, imgs): 104 | data = {} 105 | mask = self.segment_image(imgs) 106 | edges = get_edges(mask) 107 | data["seg_human"] = mask.astype(np.uint8) 108 | data["edges"] = edges.astype(bool) 109 | return data 110 | 111 | 112 | if __name__ == "__main__": 113 | human_processor = HumanPartsPreprocessor() 114 | img = cv2.imread("humans.jpg") 115 | img = torch.randint_like(torch.Tensor(img), 0, 255, dtype=torch.float).permute(2, 0, 1).expand(5, -1, -1, -1).numpy() 116 | masks = human_processor(img) 117 | print(masks) 118 | -------------------------------------------------------------------------------- /Data/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def check_bbox(bbox): 3 | """Check if bbox minimums are lesser then maximums""" 4 | x_min, y_min, x_max, y_max = bbox[:4] 5 | bbox[0] = 0 if x_min < 0 else x_min 6 | bbox[1] = 0 if y_min < 0 else y_min 7 | bbox[2] = 511 if x_max >511 else x_max 8 | bbox[3] = 511 if y_max >511 else y_max 9 | if x_max <= x_min: 10 | return None 11 | if y_max <= y_min: 12 | return None 13 | if y_max - y_min < 16 or x_max - x_min < 16: 14 | # print(f"removing bbox: {bbox}") 15 | return None 16 | 17 | return bbox 18 | 19 | 20 | def check_bboxes(bboxes): 21 | """Check if bboxes boundaries are in range 0, 1 and minimums are lesser then maximums""" 22 | _bboxes = [] 23 | for bbox in bboxes: 24 | check = check_bbox(bbox) 25 | if check is not None: 26 | _bboxes.append(bbox) 27 | # else: 28 | # print(f"Removing bbox: {bbox}") 29 | return _bboxes 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Casual GAN Papers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Make-A-Scene - PyTorch 2 | Pytorch implementation (unofficial) of Make-A-Scene: Scene-Based Text-to-Image Generation with Human Priors 3 | 4 |

5 | [figure: results] 6 | Figure 1. from paper 7 |

8 | 9 | ## Note: this is work in progress. 10 | We are at the training stage! The process can be followed in the dedicated channel on the LAION Discord: https://discord.gg/DghvZDKu. 11 | Data preprocessing is finished, as is VQ-SEG training. We are currently training VQ-IMG. Training checkpoints will be released soon with demos. 12 | The transformer implementation is in progress and will hopefully start training as soon as VQ-IMG finishes. 13 | 14 | ## Demo 15 | VQ-IMG: https://colab.research.google.com/drive/1SPyQ-epTsAOAu8BEohUokN4-b5RM_TnE?usp=sharing 16 | 17 | ## Paper Description 18 | Make-A-Scene modifies the VQGAN framework. It makes heavy use of semantic segmentation maps as extra conditioning, which gives more control over the generation process. Moreover, it also conditions on text. The main improvements are the following: 19 | 1. Segmentation condition: a separate VQ-VAE (VQ-SEG) is trained, with its loss modified to a weighted binary cross-entropy (3.4); a sketch is given below the Training Pipeline figure 20 | 2. VQGAN training (VQ-IMG) is extended with a Face-Loss and an Object-Loss (3.3 & 3.5) 21 | 3. Classifier Guidance for the autoregressive transformer (3.7) 22 | 23 | ## Training Pipeline 24 |

25 | [figure: results] 26 | Figure 6. from paper 27 |
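For improvement 1 above, VQ-SEG is pretrained with `losses.BCELossWithQuant` (selected in `conf/seg_config.yaml`) instead of the perceptual + discriminator loss used for VQ-IMG. The snippet below is only an illustrative sketch of such a weighted binary cross-entropy over the 159 segmentation channels plus a codebook term, not the repository's exact implementation; the channel layout follows `log_utils.Visualizer.dims`, and the weight values are assumptions.

```python
import torch
import torch.nn.functional as F

def weighted_bce_seg_loss(seg_logits, seg_target, quant_loss,
                          face_weight=5.0, edge_weight=2.0, codebook_weight=1.0):
    # seg_logits / seg_target: (B, 159, H, W). Channel layout: 0-132 panoptic,
    # 133-152 human parts, 153-157 face, 158 edges. Weight values are illustrative.
    weights = torch.ones(159, device=seg_logits.device)
    weights[153:158] = face_weight   # up-weight the small but important face-part channels
    weights[158] = edge_weight       # up-weight object boundaries
    bce = F.binary_cross_entropy_with_logits(
        seg_logits, seg_target.float(), weight=weights.view(1, -1, 1, 1)
    )
    return bce + codebook_weight * quant_loss
```

VQ-IMG training instead uses `losses.loss_img.VQLPIPSWithDiscriminator`, as configured in `conf/img_config.yaml`.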

28 | 29 | ## What needs to be done? 30 | Refer to the different folders to see details. 31 | - [X] [VQ-SEG](https://github.com/CasualGANPapers/Make-A-Scene/tree/main/VQ-SEG) 32 | - [X] [VQ-IMG](https://github.com/CasualGANPapers/Make-A-Scene/tree/main/VQ-IMG) 33 | - [ ] [Transformer]() 34 | - [X] [Data Aggregation](https://github.com/CasualGANPapers/Make-A-Scene/tree/main/Data) 35 | 36 | ## Citation 37 | ```bibtex 38 | @misc{https://doi.org/10.48550/arxiv.2203.13131, 39 | doi = {10.48550/ARXIV.2203.13131}, 40 | url = {https://arxiv.org/abs/2203.13131}, 41 | author = {Gafni, Oran and Polyak, Adam and Ashual, Oron and Sheynin, Shelly and Parikh, Devi and Taigman, Yaniv}, 42 | title = {Make-A-Scene: Scene-Based Text-to-Image Generation with Human Priors}, 43 | publisher = {arXiv}, 44 | year = {2022}, 45 | copyright = {arXiv.org perpetual, non-exclusive license} 46 | } 47 | ``` 48 | -------------------------------------------------------------------------------- /conf/img_config.yaml: -------------------------------------------------------------------------------- 1 | mode: pretrain_image 2 | devices: 3 | - 0 4 | - 1 5 | - 2 6 | - 3 7 | - 4 8 | - 5 9 | - 6 10 | - 7 11 | 12 | total_steps: 800000 13 | accumulate_grad: 8 14 | resume: False 15 | checkpoint: ./outputs/2022-04-03/19-00-36/checkpoint.pt 16 | log_period: 50 17 | batch_size: 2 # 192 for 256 model and 128 for 512 model 18 | 19 | model: 20 | _target_: models.VQBASE 21 | embed_dim: 256 22 | n_embed: 8192 23 | init_steps: 3000 24 | reservoir_size: 12500 # 2e5 / 8 25 | ddconfig: 26 | z_channels: 256 27 | in_channels: 3 28 | out_channels: 3 29 | channels: [128, 128, 128, 256, 512, 512] # [1, 1, 2, 4, 4] 30 | num_res_blocks: 2 31 | resolution: 512 32 | attn_resolutions: 33 | - 32 34 | dropout: 0.0 35 | 36 | optimizer: 37 | vq: 38 | lr: 5e-6 39 | betas: 40 | - 0.5 41 | - 0.9 42 | disc: 43 | lr: 4.5e-6 44 | betas: 45 | - 0.5 46 | - 0.9 47 | 48 | dataset: 49 | _target_: Data.dataset_preprocessor_web.S3ProcessedDataset 50 | resampled: True 51 | names: 52 | - cc3m 53 | - cc12m 54 | # path: file:D:/PycharmProjects/Make-A-Scene/server/Make-A-Scene/dataset/coco/{00000..00004}.tar 55 | # path: file:D:/PycharmProjects/Make-A-Scene/server/Make-A-Scene/dataset/coco/great_dataset.tar 56 | 57 | loss: 58 | #_target_: losses.VQVAEWithBCELoss 59 | _target_: losses.loss_img.VQLPIPSWithDiscriminator 60 | disc_start: 250001 61 | disc_weight: 0.8 62 | codebook_weight: 1.0 63 | 64 | dataloader: 65 | batch_size: ${batch_size} 66 | num_workers: 8 67 | pin_memory: True 68 | 69 | hydra: 70 | job: 71 | chdir: True 72 | run: 73 | dir: ./outputs/${mode}/${now:%Y-%m-%d}/${now:%H-%M-%S} 74 | -------------------------------------------------------------------------------- /conf/preprocess_data.yaml: -------------------------------------------------------------------------------- 1 | mode: preprocess_dataset 2 | root: "/home/silent/hdd/nets/db/mydb" 3 | preprocessed: "/home/silent/hdd/nets/db/mydb/preprocessed" 4 | 5 | preprocessor: 6 | _target_: Data.BasePreprocessor 7 | preprocessed_folder: ${preprocessed} 8 | devices: 9 | - 0 10 | proc_per_gpu: 11 | panoptic: 0 12 | human: 0 13 | face: 0 14 | proc_per_cpu: 15 | panoptic: 1 16 | human: 1 17 | face: 1 18 | 19 | dataset: 20 | _target_: Data.dataset_preprocessor.ConcatDataset 21 | datasets: 22 | - ${coco2014} 23 | - ${coco2017} 24 | 25 | coco2014: 26 | _target_: Data.dataset_preprocessor.COCO2014Dataset 27 | root: ${root} 28 | preprocessed_folder: ${preprocessed} 29 | 30 | coco2017: 31 | _target_: 
Data.dataset_preprocessor.COCO2017Dataset 32 | root: ${root} 33 | preprocessed_folder: ${preprocessed} 34 | 35 | 36 | hydra: 37 | run: 38 | dir: ./outputs/${mode}/${now:%Y-%m-%d}/${now:%H-%M-%S} 39 | -------------------------------------------------------------------------------- /conf/preprocess_data_web.yaml: -------------------------------------------------------------------------------- 1 | mode: preprocess_dataset 2 | #root: "/home/ubuntu/MakeAScene/data/coco/mscoco" 3 | root: "pipe:aws s3 cp s3://s-mas/laion-high-resolution/{00000..01500}.tar -" 4 | preprocessed: "/home/ubuntu/data/laion_en_tmp/" 5 | output_folder: "laion_en_processed" 6 | 7 | preprocessor: 8 | _target_: Data.WebPreprocessor 9 | preprocessed_folder: ${preprocessed} 10 | output_folder: ${output_folder} 11 | batch_size: 32 12 | num_workers: 2 13 | machines_total: 2 14 | machine_idx: 0 15 | devices: 16 | - 9 17 | proc_per_gpu: 18 | panoptic: 19 | - 4 20 | - 5 21 | - 6 22 | - 7 23 | human: 24 | - 2 25 | - 3 26 | face: 27 | - 0 28 | - 1 29 | proc_per_cpu: 30 | panoptic: 0 31 | human: 0 32 | face: 0 33 | 34 | dataset: 35 | _target_: Data.dataset_preprocessor_web.UnprocessedWebDataset 36 | root: ${root} 37 | 38 | 39 | hydra: 40 | run: 41 | dir: ./outputs/${mode}/${now:%Y-%m-%d}/${now:%H-%M-%S} 42 | -------------------------------------------------------------------------------- /conf/seg_config.yaml: -------------------------------------------------------------------------------- 1 | mode: pretrain_segmentation 2 | devices: 3 | - 0 4 | 5 | total_steps: 6000000 6 | accumulate_grad: 3 7 | resume: False 8 | checkpoint: ./outputs/2022-04-03/19-00-36/checkpoint.pt 9 | log_period: 50 10 | batch_size: 2 11 | 12 | model: 13 | _target_: models.VQBASE 14 | embed_dim: 256 15 | n_embed: 1024 16 | ddconfig: 17 | double_z: false 18 | z_channels: 256 19 | resolution: 256 20 | in_channels: 159 21 | out_ch: 159 22 | ch: 128 23 | ch_mult: 24 | - 1 25 | - 1 26 | - 2 27 | - 2 28 | - 4 29 | num_res_blocks: 2 30 | attn_resolutions: 31 | - 16 32 | dropout: 0.0 33 | 34 | optimizer: 35 | lr: 4.5e-6 36 | betas: 37 | - 0.5 38 | - 0.9 39 | 40 | dataset: 41 | _target_: Data.dataset_preprocessor.COCO2014Dataset 42 | root: "D:\\PycharmProjects\\Make-A-Scene\\Data\\coco\\tmpdb_2\\" 43 | preprocessed_folder: "D:\\PycharmProjects\\Make-A-Scene\\Data\\coco\\tmpdb_2\\preprocessed_folder" 44 | force_preprocessing: False 45 | 46 | loss: 47 | #_target_: losses.VQVAEWithBCELoss 48 | _target_: losses.BCELossWithQuant 49 | 50 | dataloader: 51 | batch_size: 2 52 | num_workers: 8 53 | shuffle: True 54 | pin_memory: True 55 | 56 | hydra: 57 | run: 58 | dir: ./outputs/${mode}/${now:%Y-%m-%d}/${now:%H-%M-%S} 59 | -------------------------------------------------------------------------------- /conf/show.yaml: -------------------------------------------------------------------------------- 1 | mode: show_segmentation 2 | devices: 3 | - 0 4 | 5 | checkpoint: ./outputs/pretrain_segmentation/2022-04-04/16-40-10/checkpoint.pt 6 | 7 | model: 8 | _target_: models.VQBASE 9 | embed_dim: 256 10 | n_embed: 1024 11 | image_key: "segmentation" 12 | #n_labels: 182 13 | ddconfig: 14 | double_z: false 15 | z_channels: 256 16 | resolution: 256 17 | in_channels: 159 18 | out_ch: 159 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | num_res_blocks: 2 27 | attn_resolutions: 28 | - 16 29 | dropout: 0.0 30 | 31 | 32 | dataset: 33 | _target_: Data.dataset_preprocessor.COCO2014Dataset 34 | root: "/path_to_coco" 35 | preprocessed_folder: "/path_to_preprocessed_folder" 
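# `root` and `preprocessed_folder` above must point at the COCO image root and at the folder
# written by the preprocessing step (see conf/preprocess_data.yaml).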
36 | force_preprocessing: False 37 | 38 | hydra: 39 | run: 40 | dir: ./outputs/${mode}/${now:%Y-%m-%d}/${now:%H-%M-%S} 41 | -------------------------------------------------------------------------------- /log_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.tensorboard import SummaryWriter 5 | import torchvision.utils as vutils 6 | import os 7 | 8 | 9 | class Logger: 10 | def __init__(self, proc_id, log_dir=".", device="cuda"): 11 | self.proc_id = proc_id 12 | if proc_id != 0: 13 | return 14 | self.writer = SummaryWriter(log_dir) 15 | self.step = 0 16 | os.makedirs("./results") 17 | 18 | def log(self, step=None, img=None, img_rec=None, **kwargs): 19 | if self.proc_id != 0: 20 | return 21 | self.step = step if step is not None else self.step + 1 22 | for key in kwargs: 23 | self.writer.add_scalar(key, kwargs[key].detach().cpu().item(), self.step) 24 | if img is not None and img_rec is not None and self.step%500==0: 25 | img = img.detach().cpu() 26 | img_rec = img_rec.detach().cpu() 27 | pairs = torch.cat([img,img_rec]).detach().cpu() 28 | img_grid = vutils.make_grid(pairs) 29 | self.writer.add_image('samples', img_grid.detach().cpu(), global_step=step) 30 | 31 | 32 | class Visualizer: 33 | dims = { 34 | "panoptic": [0, 133], 35 | "human": [133, 153], 36 | "face": [153, 158], 37 | "edge": [158, 159] 38 | } 39 | 40 | def __init__(self, log_dir=".", device="cuda"): 41 | self.weights = {} 42 | for key in self.dims: 43 | size = self.dims[key][1] - self.dims[key][0] 44 | weight = torch.randn([3, size, 1, 1]).to(device) 45 | self.weights[key] = weight 46 | os.makedirs("./results") 47 | 48 | def log_images(self, seg, seg_rec): 49 | seg = self.colorize(seg) 50 | seg_rec = self.colorize(seg_rec, logits=True) 51 | both = torch.cat((seg, seg_rec)) 52 | grid = vutils.make_grid(both, nrow=2) 53 | vutils.save_image(both, f"./results/{self.step}.jpg", nrow=4) 54 | 55 | def colorize(self, seg, logits=False): 56 | results = [] 57 | for key in self.dims: 58 | seg_key = seg[:, self.dims[key][0]: self.dims[key][1]] 59 | if logits: 60 | n_classes = seg_key.shape[1] 61 | if "face" in key or "edge" in key: 62 | mask = seg_key.sigmoid() > 0.2 63 | seg_key = torch.argmax(seg_key, dim=1, keepdim=False) 64 | seg_key = F.one_hot(seg_key, num_classes=n_classes) 65 | seg_key = seg_key.permute(0, 3, 1, 2).float() 66 | if "face" in key or "edge" in key: 67 | seg_key *= mask 68 | 69 | weight = self.weights[key] 70 | with torch.no_grad(): 71 | x = F.conv2d(seg_key, weight) 72 | x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
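            # Each channel group (panoptic / human / face / edge) is projected to 3 RGB channels
            # with its fixed random 1x1 convolution from self.weights, then min-max normalised to
            # [-1, 1] so the one-hot segmentation can be saved as an ordinary image.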
73 | results.append(x) 74 | return results 75 | seg = seg[:, :133] 76 | if logits: 77 | # seg = (seg.sigmoid()>0.35).to(torch.float) 78 | seg = torch.argmax(seg, dim=1, keepdim=False) 79 | seg = F.one_hot(seg, num_classes=133) 80 | seg = seg.squeeze(1).permute(0, 3, 1, 2).float() 81 | 82 | return x 83 | 84 | def __call__(self, step, image=None, seg=None, seg_rec=None): 85 | results = [image] 86 | if seg is not None: 87 | results.extend(self.colorize(seg)) 88 | results.append(torch.zeros_like(image)) 89 | if seg_rec is not None: 90 | results.extend(self.colorize(seg_rec, logits=True)) 91 | results = torch.cat(results) 92 | vutils.save_image(results, f"./results/result_{step}.jpg", nrow=5) 93 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss_seg import BCELossWithQuant, VQVAEWithBCELoss 2 | from .loss_img import VQLPIPSWithDiscriminator 3 | -------------------------------------------------------------------------------- /losses/discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | PatchGAN Discriminator (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538) 3 | """ 4 | 5 | import torch.nn as nn 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class Discriminator(nn.Module): 18 | def __init__(self, in_channels=3, num_filters_last=64, n_layers=3): 19 | super(Discriminator, self).__init__() 20 | 21 | layers = [nn.Conv2d(in_channels, num_filters_last, 4, 2, 1), nn.LeakyReLU(0.2)] 22 | num_filters_mult = 1 23 | 24 | for i in range(1, n_layers + 1): 25 | num_filters_mult_last = num_filters_mult 26 | num_filters_mult = min(2 ** i, 8) 27 | layers += [ 28 | nn.Conv2d(num_filters_last * num_filters_mult_last, num_filters_last * num_filters_mult, 4, 29 | 2 if i < n_layers else 1, 1, bias=False), 30 | nn.BatchNorm2d(num_filters_last * num_filters_mult), 31 | nn.LeakyReLU(0.2, True) 32 | ] 33 | 34 | layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1)) 35 | self.model = nn.Sequential(*layers) 36 | 37 | def forward(self, x): 38 | return self.model(x) -------------------------------------------------------------------------------- /losses/face_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code taken from and modified https://github.com/cydonia999/VGGFace2-pytorch 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import transforms 8 | from torchvision.transforms.functional import crop 9 | 10 | __all__ = ['FaceLoss'] 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class Bottleneck(nn.Module): 20 | expansion = 4 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(Bottleneck, self).__init__() 24 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 27 | self.bn2 
= nn.BatchNorm2d(planes) 28 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 29 | self.bn3 = nn.BatchNorm2d(planes * 4) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.downsample = downsample 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv3(out) 46 | out = self.bn3(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class FaceLoss(nn.Module): 58 | def __init__(self): 59 | super(FaceLoss, self).__init__() 60 | layers = [3, 4, 6, 3] 61 | self.inplanes = 64 62 | self.alphas = [0.1, 0.25 * 0.01, 0.25 * 0.1, 0.25 * 0.2, 0.25 * 0.02] 63 | 64 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 65 | self.bn1 = nn.BatchNorm2d(64) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) 68 | 69 | self.layer1 = self._make_layer(Bottleneck, 64, layers[0]) 70 | self.layer2 = self._make_layer(Bottleneck, 128, layers[1], stride=2) 71 | self.layer3 = self._make_layer(Bottleneck, 256, layers[2], stride=2) 72 | self.layer4 = self._make_layer(Bottleneck, 512, layers[3], stride=2) 73 | 74 | self.channels = [64, 256, 512, 1024, 2048] 75 | 76 | self.load_state_dict(torch.load("/home/ubuntu/Make-A-Scene/losses/face_loss_weights.pt", map_location="cpu"), strict=False) 77 | 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | self.face_transforms = nn.Sequential( 82 | transforms.Resize(256), 83 | transforms.CenterCrop(254) 84 | ) 85 | self.eval() 86 | 87 | def _make_layer(self, block, planes, blocks, stride=1): 88 | downsample = None 89 | if stride != 1 or self.inplanes != planes * block.expansion: 90 | downsample = nn.Sequential( 91 | nn.Conv2d(self.inplanes, planes * block.expansion, 92 | kernel_size=1, stride=stride, bias=False), 93 | nn.BatchNorm2d(planes * block.expansion), 94 | ) 95 | 96 | layers = [] 97 | layers.append(block(self.inplanes, planes, stride, downsample)) 98 | self.inplanes = planes * block.expansion 99 | for i in range(1, blocks): 100 | layers.append(block(self.inplanes, planes)) 101 | 102 | return nn.Sequential(*layers) 103 | 104 | def _forward(self, x): # 224x224 105 | features = [] 106 | x = self.conv1(x) # 112x112 107 | features.append(x) 108 | x = self.bn1(x) 109 | x = self.relu(x) 110 | x = self.maxpool(x) # 56x56 ignore 111 | 112 | x = self.layer1(x) # 56x56 113 | features.append(x) 114 | x = self.layer2(x) # 28x28 115 | features.append(x) 116 | x = self.layer3(x) # 14x14 ignore (maybe not) 117 | features.append(x) 118 | x = self.layer4(x) # 7x7 119 | features.append(x) 120 | 121 | return features 122 | 123 | def forward(self, img, rec, bbox): 124 | """ 125 | Takes in original image and reconstructed image and feeds it through face network and takes the difference 126 | between the different resolutions and scales by alpha_{i}. 127 | Normalizing the features and applying spatial resolution was taken from LPIPS and wasn't mentioned in the paper. 128 | """ 129 | faces = self.prepare_faces(img, rec, bbox) 130 | if faces is None: 131 | return img.new_tensor(0) 132 | faces = faces[:6] # otherwise it may cause oom error. 
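        # prepare_faces returns the ground-truth crops concatenated with the reconstructed crops,
        # so every feature map below is chunked into two halves and the per-stage L1 differences
        # are scaled by the corresponding entry of self.alphas.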
133 | features = self._forward(faces) 134 | features = [f.chunk(2) for f in features] 135 | # diffs = [a * torch.abs(p[0] - p[1]).sum() for a, p in zip(self.alphas, features)] 136 | diffs = [a * torch.abs(p[0] - p[1]).sum(dim=0).mean() for a, p in zip(self.alphas, features)] 137 | # diffs = [a*torch.abs(self.norm_tensor(tf) - self.norm_tensor(rf)) for a, tf, rf in zip(self.alphas, true_features, rec_features)] 138 | 139 | # diffs = [a * torch.mean(torch.abs(tf - rf)) for a, tf, rf in zip(self.alphas, features)] 140 | return sum(diffs) 141 | # return sum(diffs) / len(diffs) 142 | 143 | def prepare_faces(self, imgs, recs, bboxes): 144 | faces_gt = [] 145 | faces_gen = [] 146 | for img, rec, bboxes in zip(imgs, recs, bboxes): 147 | for bbox in bboxes: 148 | top = bbox[1] 149 | left = bbox[0] 150 | height = bbox[3] - bbox[1] 151 | width = bbox[2] - bbox[0] 152 | crop_img = crop(img, top, left, height, width) 153 | faces_gt.append(self.face_transforms(crop_img)) 154 | crop_rec = crop(rec, top, left, height, width) 155 | faces_gen.append(self.face_transforms(crop_rec)) 156 | if len(faces_gt) == 0: 157 | return None 158 | faces_gt = torch.stack(faces_gt, dim=0) 159 | faces_gen = torch.stack(faces_gen, dim=0) 160 | return torch.cat([faces_gt, faces_gen], dim=0) 161 | 162 | 163 | 164 | if __name__ == '__main__': 165 | model = FaceLoss() 166 | # x = torch.randn(1, 3, 256, 256) 167 | # x_rec = torch.randn(1, 3, 256, 256) 168 | x = torch.randn(2, 3, 101, 101) 169 | x_rec = torch.randn(2, 3, 101, 101) 170 | print(model.forward(x, x_rec)) 171 | -------------------------------------------------------------------------------- /losses/loss_img.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .lpips import LPIPS 5 | from .lpips_with_object import LPIPSWithObject 6 | from .discriminator import Discriminator, weights_init 7 | from .face_loss import FaceLoss 8 | from torchvision.transforms.functional import crop 9 | 10 | 11 | def adopt_weight(weight, global_step, threshold=0, value=0.0): 12 | if global_step < threshold: 13 | weight = value 14 | return weight 15 | 16 | 17 | def hinge_d_loss(logits_real, logits_fake): 18 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 19 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 20 | d_loss = 0.5 * (loss_real + loss_fake) 21 | return d_loss 22 | 23 | 24 | def vanilla_d_loss(logits_real, logits_fake): 25 | d_loss = 0.5 * ( 26 | torch.mean(torch.nn.functional.softplus(-logits_real)) 27 | + torch.mean(torch.nn.functional.softplus(logits_fake)) 28 | ) 29 | return d_loss 30 | 31 | 32 | class VQLPIPSWithDiscriminator(nn.Module): 33 | def __init__( 34 | self, 35 | disc_start, 36 | codebook_weight=1.0, 37 | pixelloss_weight=1.0, 38 | disc_factor=1.0, 39 | disc_weight=1.0, 40 | perceptual_weight=1.0, 41 | ): 42 | super().__init__() 43 | self.codebook_weight = codebook_weight 44 | self.pixel_weight = pixelloss_weight 45 | self.perceptual_loss = LPIPSWithObject().eval() 46 | self.perceptual_weight = perceptual_weight 47 | 48 | self.face_loss = FaceLoss() 49 | #self.object_loss = self.perceptual_loss 50 | 51 | self.discriminator = Discriminator().apply(weights_init) 52 | self.discriminator_iter_start = disc_start 53 | self.disc_factor = disc_factor 54 | self.discriminator_weight = disc_weight 55 | 56 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer): 57 | nll_grads = torch.autograd.grad(nll_loss, last_layer.weight, retain_graph=True)[ 58 
| 0 59 | ] 60 | g_grads = torch.autograd.grad(g_loss, last_layer.weight, retain_graph=True)[0] 61 | 62 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 63 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 64 | d_weight = d_weight * self.discriminator_weight 65 | return d_weight 66 | 67 | def forward( 68 | self, 69 | optimizer_idx, 70 | global_step, 71 | images, 72 | reconstructions, 73 | codebook_loss=None, 74 | bbox_obj=None, 75 | bbox_face=None, 76 | last_layer=None, 77 | ): 78 | if optimizer_idx == 0: # vqvae loss 79 | rec_loss = torch.abs(images.contiguous() - reconstructions.contiguous()) 80 | p_loss = self.perceptual_loss( 81 | images.contiguous(), reconstructions.contiguous(), bbox_obj 82 | ) 83 | rec_loss = rec_loss + self.perceptual_weight * p_loss 84 | 85 | nll_loss = rec_loss 86 | nll_loss = torch.mean(nll_loss) 87 | 88 | face_loss = self.face_loss(images, reconstructions, bbox_face) 89 | 90 | object_loss = images.new_tensor(0) 91 | #for img, rec, bboxes in zip(images, reconstructions, bbox_obj): 92 | # img_object_loss = img.new_tensor(0) 93 | # for bbox in bboxes: 94 | # # xmin, ymin, xmax, ymax 95 | # top = bbox[1] 96 | # left = bbox[0] 97 | # height = bbox[3] - bbox[1] 98 | # width = bbox[2] - bbox[0] 99 | # crop_img = crop(img, top, left, height, width).unsqueeze( 100 | # 0 101 | # ) # bbox needs to be [x, y, height, width] 102 | # crop_rec = crop(rec, top, left, height, width).unsqueeze(0) 103 | # img_object_loss += self.object_loss( 104 | # crop_img.contiguous(), crop_rec.contiguous() 105 | # ).mean() # TODO: check if crops are actually correct 106 | # object_loss += img_object_loss/(len(bboxes)+1) 107 | 108 | logits_fake = self.discriminator( 109 | reconstructions.contiguous() 110 | ) # cont not necessary 111 | g_loss = -torch.mean(logits_fake) 112 | 113 | d_weight = self.calculate_adaptive_weight( 114 | nll_loss, g_loss, last_layer=last_layer 115 | ) 116 | 117 | disc_factor = adopt_weight( 118 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 119 | ) 120 | 121 | loss = ( 122 | nll_loss 123 | + d_weight * disc_factor * g_loss 124 | + self.codebook_weight * codebook_loss.mean() 125 | + face_loss 126 | + object_loss 127 | ) 128 | # 0.001 * face_loss 129 | return loss, (nll_loss, object_loss, face_loss) 130 | # return loss, (nll_loss, images.new_tensor(0), images.new_tensor(0)) 131 | 132 | if optimizer_idx == 1: # gan loss 133 | disc_factor = adopt_weight( 134 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 135 | ) 136 | 137 | logits_real = self.discriminator(images.contiguous().detach()) 138 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 139 | 140 | d_loss = disc_factor * hinge_d_loss(logits_real, logits_fake) 141 | return d_loss 142 | -------------------------------------------------------------------------------- /losses/loss_seg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BCELossWithQuant(nn.Module): 7 | def __init__(self, image_channels=159, codebook_weight=1.0): 8 | super().__init__() 9 | self.codebook_weight = codebook_weight 10 | self.register_buffer( 11 | "weight", 12 | torch.ones(image_channels).index_fill(0, torch.arange(153, 158), 20), 13 | ) 14 | 15 | def forward(self, qloss, target, prediction): 16 | bce_loss = F.binary_cross_entropy_with_logits( 17 | prediction.permute(0, 2, 3, 1), 18 | target.permute(0, 2, 3, 1), 19 | 
pos_weight=self.weight, 20 | ) 21 | loss = bce_loss + self.codebook_weight * qloss 22 | return loss 23 | 24 | 25 | class VQVAEWithBCELoss(nn.Module): 26 | def __init__(self, image_channels=159, codebook_weight=1.0): 27 | super().__init__() 28 | self.codebook_weight = codebook_weight 29 | self.register_buffer( 30 | "weight", 31 | torch.ones(image_channels).index_fill(0, torch.arange(153, 158), 20), 32 | ) 33 | 34 | def forward(self, qloss, target, prediction): 35 | bce_mse_loss = F.mse_loss(prediction.sigmoid(), target) + F.binary_cross_entropy_with_logits( 36 | prediction.permute(0, 2, 3, 1), 37 | target.permute(0, 2, 3, 1), 38 | pos_weight=self.weight, 39 | ) 40 | loss = bce_mse_loss + self.codebook_weight * qloss 41 | return loss 42 | -------------------------------------------------------------------------------- /losses/lpips.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torchvision.models import vgg16 5 | from collections import namedtuple 6 | import requests 7 | from tqdm import tqdm 8 | 9 | 10 | URL_MAP = { 11 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 12 | } 13 | 14 | CKPT_MAP = { 15 | "vgg_lpips": "/home/ubuntu/Make-A-Scene/weights/vgg.pth" 16 | } 17 | 18 | 19 | def download(url, local_path, chunk_size=1024): 20 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 21 | with requests.get(url, stream=True) as r: 22 | total_size = int(r.headers.get("content-length", 0)) 23 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 24 | with open(local_path, "wb") as f: 25 | for data in r.iter_content(chunk_size=chunk_size): 26 | if data: 27 | f.write(data) 28 | pbar.update(chunk_size) 29 | 30 | 31 | def get_ckpt_path(name, root): 32 | assert name in URL_MAP 33 | #path = os.path.join(root, CKPT_MAP[name]) 34 | path = os.path.join(CKPT_MAP[name]) 35 | if not os.path.exists(path): 36 | print(f"Downloading {name} model from {URL_MAP[name]} to {path}") 37 | download(URL_MAP[name], path) 38 | return path 39 | 40 | 41 | class LPIPS(nn.Module): 42 | def __init__(self): 43 | super(LPIPS, self).__init__() 44 | self.scaling_layer = ScalingLayer() 45 | self.channels = [64, 128, 256, 512, 512] 46 | self.vgg = VGG16() 47 | self.lin0 = NetLinLayer(self.channels[0]) 48 | self.lin1 = NetLinLayer(self.channels[1]) 49 | self.lin2 = NetLinLayer(self.channels[2]) 50 | self.lin3 = NetLinLayer(self.channels[3]) 51 | self.lin4 = NetLinLayer(self.channels[4]) 52 | self.load_from_pretrained() 53 | self.lins = [ 54 | self.lin0, 55 | self.lin1, 56 | self.lin2, 57 | self.lin3, 58 | self.lin4 59 | ] 60 | 61 | for param in self.parameters(): 62 | param.requires_grad = False 63 | 64 | def load_from_pretrained(self, name="vgg_lpips"): 65 | ckpt = get_ckpt_path(name, "vgg_lpips") 66 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 67 | 68 | def forward(self, real_x, fake_x): 69 | features_real = self.vgg(self.scaling_layer(real_x)) 70 | features_fake = self.vgg(self.scaling_layer(fake_x)) 71 | diffs = {} 72 | 73 | for i in range(len(self.channels)): 74 | diffs[i] = (norm_tensor(features_real[i]) - norm_tensor(features_fake[i])) ** 2 75 | 76 | return sum([spatial_average(self.lins[i].model(diffs[i])) for i in range(len(self.channels))]) 77 | 78 | 79 | class ScalingLayer(nn.Module): 80 | def __init__(self): 81 | super(ScalingLayer, self).__init__() 82 | self.register_buffer("shift", torch.Tensor([-.030, -.088, -.188])[None, :, 
None, None]) 83 | self.register_buffer("scale", torch.Tensor([.458, .448, .450])[None, :, None, None]) 84 | 85 | def forward(self, x): 86 | return (x - self.shift) / self.scale 87 | 88 | 89 | class NetLinLayer(nn.Module): 90 | def __init__(self, in_channels, out_channels=1): 91 | super(NetLinLayer, self).__init__() 92 | self.model = nn.Sequential( 93 | nn.Dropout(), 94 | nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False) 95 | ) 96 | 97 | 98 | class VGG16(nn.Module): 99 | def __init__(self): 100 | super(VGG16, self).__init__() 101 | vgg_pretrained_features = vgg16(pretrained=True).features 102 | slices = [vgg_pretrained_features[i] for i in range(30)] 103 | self.slice1 = nn.Sequential(*slices[0:4]) 104 | self.slice2 = nn.Sequential(*slices[4:9]) 105 | self.slice3 = nn.Sequential(*slices[9:16]) 106 | self.slice4 = nn.Sequential(*slices[16:23]) 107 | self.slice5 = nn.Sequential(*slices[23:30]) 108 | 109 | for param in self.parameters(): 110 | param.requires_grad = False 111 | 112 | def forward(self, x): 113 | h = self.slice1(x) 114 | h_relu1 = h 115 | h = self.slice2(h) 116 | h_relu2 = h 117 | h = self.slice3(h) 118 | h_relu3 = h 119 | h = self.slice4(h) 120 | h_relu4 = h 121 | h = self.slice5(h) 122 | h_relu5 = h 123 | vgg_outputs = namedtuple("VGGOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 124 | return vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 125 | 126 | 127 | def norm_tensor(x): 128 | """ 129 | Normalize images by their length to make them unit vector? 130 | :param x: batch of images 131 | :return: normalized batch of images 132 | """ 133 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 134 | return x / (norm_factor + 1e-10) 135 | 136 | 137 | def spatial_average(x): 138 | """ 139 | imgs have: batch_size x channels x width x height --> average over width and height channel 140 | :param x: batch of images 141 | :return: averaged images along width and height 142 | """ 143 | return x.mean([2, 3], keepdim=True) 144 | 145 | -------------------------------------------------------------------------------- /losses/lpips_with_object.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Function 5 | from torchvision.models import vgg16 6 | from collections import namedtuple 7 | import requests 8 | from tqdm import tqdm 9 | from .lpips import LPIPS 10 | 11 | 12 | class WeightThingGrad(Function): 13 | def forward(ctx, input, bboxes): 14 | weights = torch.ones_like(input) 15 | for weight, img_bboxes in zip(weights, bboxes): 16 | for x_min, y_min, x_max, y_max in img_bboxes: 17 | weight[:, x_min:x_max, y_min:y_max] 18 | 19 | ctx.save_for_backward(weights) 20 | return input 21 | 22 | def backward(ctx, grad_output): 23 | weights = ctx.saved_tensors[0] 24 | return grad_output*weights, None 25 | 26 | weight_thing_grad = WeightThingGrad.apply 27 | 28 | class LPIPSWithObject(LPIPS): 29 | def forward(self, real_x, fake_x, object_boxes): 30 | fake_x = weight_thing_grad(fake_x, object_boxes) 31 | return super().forward(real_x, fake_x) 32 | 33 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vqvae import VQBASE 2 | -------------------------------------------------------------------------------- /models/modules.py: 
-------------------------------------------------------------------------------- 1 | # Taken from https://github.com/CompVis/taming-transformers 2 | # pytorch_diffusion + derived encoder decoder 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from fast_pytorch_kmeans import KMeans 9 | from torch import einsum 10 | import torch.distributed as dist 11 | from einops import rearrange 12 | 13 | 14 | def get_timestep_embedding(timesteps, embedding_dim): 15 | """ 16 | This matches the implementation in Denoising Diffusion Probabilistic Models: 17 | From Fairseq. 18 | Build sinusoidal embeddings. 19 | This matches the implementation in tensor2tensor, but differs slightly 20 | from the description in Section 3.5 of "Attention Is All You Need". 21 | """ 22 | assert len(timesteps.shape) == 1 23 | 24 | half_dim = embedding_dim // 2 25 | emb = math.log(10000) / (half_dim - 1) 26 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 27 | emb = emb.to(device=timesteps.device) 28 | emb = timesteps.float()[:, None] * emb[None, :] 29 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 30 | if embedding_dim % 2 == 1: # zero pad 31 | emb = torch.nn.functional.pad(emb, (0,1,0,0)) 32 | return emb 33 | 34 | 35 | def nonlinearity(x): 36 | # swish 37 | return x*torch.sigmoid(x) 38 | 39 | 40 | def Normalize(in_channels): 41 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 42 | 43 | 44 | class Upsample(nn.Module): 45 | def __init__(self, in_channels, with_conv): 46 | super().__init__() 47 | self.with_conv = with_conv 48 | if self.with_conv: 49 | self.conv = torch.nn.Conv2d(in_channels, 50 | in_channels, 51 | kernel_size=3, 52 | stride=1, 53 | padding=1) 54 | 55 | def forward(self, x): 56 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 57 | if self.with_conv: 58 | x = self.conv(x) 59 | return x 60 | 61 | 62 | class Downsample(nn.Module): 63 | def __init__(self, in_channels, with_conv): 64 | super().__init__() 65 | self.with_conv = with_conv 66 | if self.with_conv: 67 | # no asymmetric padding in torch conv, must do it ourselves 68 | self.conv = torch.nn.Conv2d(in_channels, 69 | in_channels, 70 | kernel_size=3, 71 | stride=2, 72 | padding=0) 73 | 74 | def forward(self, x): 75 | if self.with_conv: 76 | pad = (0,1,0,1) 77 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 78 | x = self.conv(x) 79 | else: 80 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 81 | return x 82 | 83 | 84 | class ResnetBlock(nn.Module): 85 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout): 86 | super().__init__() 87 | self.in_channels = in_channels 88 | out_channels = in_channels if out_channels is None else out_channels 89 | self.out_channels = out_channels 90 | self.use_conv_shortcut = conv_shortcut 91 | 92 | self.norm1 = Normalize(in_channels) 93 | self.conv1 = torch.nn.Conv2d(in_channels, 94 | out_channels, 95 | kernel_size=3, 96 | stride=1, 97 | padding=1) 98 | self.norm2 = Normalize(out_channels) 99 | self.dropout = torch.nn.Dropout(dropout) 100 | self.conv2 = torch.nn.Conv2d(out_channels, 101 | out_channels, 102 | kernel_size=3, 103 | stride=1, 104 | padding=1) 105 | if self.in_channels != self.out_channels: 106 | if self.use_conv_shortcut: 107 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 108 | out_channels, 109 | kernel_size=3, 110 | stride=1, 111 | padding=1) 112 | else: 113 | self.nin_shortcut = 
torch.nn.Conv2d(in_channels, 114 | out_channels, 115 | kernel_size=1, 116 | stride=1, 117 | padding=0) 118 | 119 | def forward(self, x): 120 | h = x 121 | h = self.norm1(h) 122 | h = nonlinearity(h) 123 | h = self.conv1(h) 124 | 125 | h = self.norm2(h) 126 | h = nonlinearity(h) 127 | h = self.dropout(h) 128 | h = self.conv2(h) 129 | 130 | if self.in_channels != self.out_channels: 131 | if self.use_conv_shortcut: 132 | x = self.conv_shortcut(x) 133 | else: 134 | x = self.nin_shortcut(x) 135 | 136 | return x+h 137 | 138 | 139 | class AttnBlock(nn.Module): 140 | def __init__(self, in_channels): 141 | super().__init__() 142 | self.in_channels = in_channels 143 | 144 | self.norm = Normalize(in_channels) 145 | self.q = torch.nn.Conv2d(in_channels, 146 | in_channels, 147 | kernel_size=1, 148 | stride=1, 149 | padding=0) 150 | self.k = torch.nn.Conv2d(in_channels, 151 | in_channels, 152 | kernel_size=1, 153 | stride=1, 154 | padding=0) 155 | self.v = torch.nn.Conv2d(in_channels, 156 | in_channels, 157 | kernel_size=1, 158 | stride=1, 159 | padding=0) 160 | self.proj_out = torch.nn.Conv2d(in_channels, 161 | in_channels, 162 | kernel_size=1, 163 | stride=1, 164 | padding=0) 165 | 166 | 167 | def forward(self, x): 168 | h_ = x 169 | h_ = self.norm(h_) 170 | q = self.q(h_) 171 | k = self.k(h_) 172 | v = self.v(h_) 173 | 174 | # compute attention 175 | b,c,h,w = q.shape 176 | q = q.reshape(b,c,h*w) 177 | q = q.permute(0,2,1) # b,hw,c 178 | k = k.reshape(b,c,h*w) # b,c,hw 179 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 180 | w_ = w_ * (int(c)**(-0.5)) 181 | w_ = torch.nn.functional.softmax(w_, dim=2) 182 | 183 | # attend to values 184 | v = v.reshape(b,c,h*w) 185 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) 186 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 187 | h_ = h_.reshape(b,c,h,w) 188 | 189 | h_ = self.proj_out(h_) 190 | 191 | return x+h_ 192 | 193 | 194 | class Swish(nn.Module): 195 | def forward(self, x): 196 | return x * torch.sigmoid(x) 197 | 198 | 199 | class Encoder(nn.Module): 200 | """ 201 | Encoder of VQ-GAN to map input batch of images to latent space. 
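    (The walk-through below is illustrative; the actual widths and spatial sizes depend on the
    channels, resolution and num_res_blocks arguments passed to the constructor.)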
202 | Dimension Transformations: 203 | 3x256x256 --Conv2d--> 32x256x256 204 | for loop: 205 | --ResBlock--> 64x256x256 --DownBlock--> 64x128x128 206 | --ResBlock--> 128x128x128 --DownBlock--> 128x64x64 207 | --ResBlock--> 256x64x64 --DownBlock--> 256x32x32 208 | --ResBlock--> 512x32x32 209 | --ResBlock--> 512x32x32 210 | --NonLocalBlock--> 512x32x32 211 | --ResBlock--> 512x32x32 212 | --GroupNorm--> 213 | --Swish--> 214 | --Conv2d-> 256x32x32 215 | """ 216 | 217 | def __init__(self, in_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32], resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs): 218 | super(Encoder, self).__init__() 219 | layers = [nn.Conv2d(in_channels, channels[0], 3, 1, 1)] 220 | for i in range(len(channels) - 1): 221 | in_channels = channels[i] 222 | out_channels = channels[i + 1] 223 | for j in range(num_res_blocks): 224 | layers.append(ResnetBlock(in_channels=in_channels, out_channels=out_channels, dropout=0.0)) 225 | in_channels = out_channels 226 | if resolution in attn_resolutions: 227 | layers.append(AttnBlock(in_channels)) 228 | if i < len(channels) - 2: 229 | layers.append(Downsample(channels[i + 1], with_conv=True)) 230 | resolution //= 2 231 | layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=0.0)) 232 | layers.append(AttnBlock(channels[-1])) 233 | layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=0.0)) 234 | layers.append(Normalize(channels[-1])) 235 | layers.append(Swish()) 236 | layers.append(nn.Conv2d(channels[-1], z_channels, 3, 1, 1)) 237 | self.model = nn.Sequential(*layers) 238 | 239 | def forward(self, x): 240 | return self.model(x) 241 | 242 | 243 | # class Encoder(nn.Module): 244 | # def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 245 | # attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 246 | # resolution, z_channels, double_z=True, **ignore_kwargs): 247 | # super().__init__() 248 | # self.ch = ch 249 | # self.temb_ch = 0 250 | # self.num_resolutions = len(ch_mult) 251 | # self.num_res_blocks = num_res_blocks 252 | # self.resolution = resolution 253 | # self.in_channels = in_channels 254 | # 255 | # # downsampling 256 | # self.conv_in = torch.nn.Conv2d(in_channels, 257 | # self.ch, 258 | # kernel_size=3, 259 | # stride=1, 260 | # padding=1) 261 | # 262 | # curr_res = resolution 263 | # in_ch_mult = (1,)+tuple(ch_mult) 264 | # self.down = nn.ModuleList() 265 | # for i_level in range(self.num_resolutions): 266 | # block = nn.ModuleList() 267 | # attn = nn.ModuleList() 268 | # block_in = ch*in_ch_mult[i_level] 269 | # block_out = ch*ch_mult[i_level] 270 | # for i_block in range(self.num_res_blocks): 271 | # block.append(ResnetBlock(in_channels=block_in, 272 | # out_channels=block_out, 273 | # temb_channels=self.temb_ch, 274 | # dropout=dropout)) 275 | # block_in = block_out 276 | # if curr_res in attn_resolutions: 277 | # attn.append(AttnBlock(block_in)) 278 | # down = nn.Module() 279 | # down.block = block 280 | # down.attn = attn 281 | # if i_level != self.num_resolutions-1: 282 | # down.downsample = Downsample(block_in, resamp_with_conv) 283 | # curr_res = curr_res // 2 284 | # self.down.append(down) 285 | # 286 | # # middle 287 | # self.mid = nn.Module() 288 | # self.mid.block_1 = ResnetBlock(in_channels=block_in, 289 | # out_channels=block_in, 290 | # temb_channels=self.temb_ch, 291 | # dropout=dropout) 292 | # self.mid.attn_1 = AttnBlock(block_in) 293 | # self.mid.block_2 = 
ResnetBlock(in_channels=block_in, 294 | # out_channels=block_in, 295 | # temb_channels=self.temb_ch, 296 | # dropout=dropout) 297 | # 298 | # # end 299 | # self.norm_out = Normalize(block_in) 300 | # self.conv_out = torch.nn.Conv2d(block_in, 301 | # 2*z_channels if double_z else z_channels, 302 | # kernel_size=3, 303 | # stride=1, 304 | # padding=1) 305 | # 306 | # 307 | # def forward(self, x): 308 | # #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) 309 | # 310 | # # timestep embedding 311 | # temb = None 312 | # 313 | # # downsampling 314 | # hs = [self.conv_in(x)] 315 | # for i_level in range(self.num_resolutions): 316 | # for i_block in range(self.num_res_blocks): 317 | # h = self.down[i_level].block[i_block](hs[-1], temb) 318 | # if len(self.down[i_level].attn) > 0: 319 | # h = self.down[i_level].attn[i_block](h) 320 | # hs.append(h) 321 | # if i_level != self.num_resolutions-1: 322 | # hs.append(self.down[i_level].downsample(hs[-1])) 323 | # 324 | # # middle 325 | # h = hs[-1] 326 | # h = self.mid.block_1(h, temb) 327 | # h = self.mid.attn_1(h) 328 | # h = self.mid.block_2(h, temb) 329 | # 330 | # # end 331 | # h = self.norm_out(h) 332 | # h = nonlinearity(h) 333 | # h = self.conv_out(h) 334 | # return h 335 | 336 | 337 | class Decoder(nn.Module): 338 | def __init__(self, out_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32], resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs): 339 | super(Decoder, self).__init__() 340 | ch_mult = channels[1:] 341 | num_resolutions = len(ch_mult) 342 | block_in = ch_mult[num_resolutions - 1] 343 | curr_res = resolution// 2 ** (num_resolutions - 1) 344 | 345 | layers = [nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1), 346 | ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=0.0), 347 | AttnBlock(block_in), 348 | ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=0.0) 349 | ] 350 | 351 | for i in reversed(range(num_resolutions)): 352 | block_out = ch_mult[i] 353 | for i_block in range(num_res_blocks+1): 354 | layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=0.0)) 355 | block_in = block_out 356 | if curr_res in attn_resolutions: 357 | layers.append(AttnBlock(block_in)) 358 | if i > 0: 359 | layers.append(Upsample(block_in, with_conv=True)) 360 | curr_res = curr_res * 2 361 | 362 | layers.append(Normalize(block_in)) 363 | layers.append(Swish()) 364 | layers.append(nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)) 365 | 366 | self.model = nn.Sequential(*layers) 367 | 368 | def forward(self, x): 369 | return self.model(x) 370 | 371 | 372 | # class Decoder(nn.Module): 373 | # def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 374 | # attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 375 | # resolution, z_channels, **ignorekwargs): 376 | # super().__init__() 377 | # self.temb_ch = 0 378 | # self.num_resolutions = len(ch_mult) 379 | # self.num_res_blocks = num_res_blocks 380 | # self.resolution = resolution 381 | # self.in_channels = in_channels 382 | # 383 | # block_in = ch*ch_mult[self.num_resolutions-1] 384 | # curr_res = resolution // 2**(self.num_resolutions-1) 385 | # self.z_shape = (1,z_channels,curr_res,curr_res) 386 | # 387 | # # z to block_in 388 | # self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) 389 | # 390 | # # middle 391 | # self.mid = nn.Module() 392 | # 
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) 393 | # self.mid.attn_1 = AttnBlock(block_in) 394 | # self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) 395 | # 396 | # # upsampling 397 | # self.up = nn.ModuleList() 398 | # for i_level in reversed(range(self.num_resolutions)): 399 | # block = nn.ModuleList() 400 | # attn = nn.ModuleList() 401 | # block_out = ch*ch_mult[i_level] 402 | # for i_block in range(self.num_res_blocks+1): 403 | # block.append(ResnetBlock(in_channels=block_in, 404 | # out_channels=block_out, 405 | # temb_channels=self.temb_ch, 406 | # dropout=dropout)) 407 | # block_in = block_out 408 | # if curr_res in attn_resolutions: 409 | # attn.append(AttnBlock(block_in)) 410 | # up = nn.Module() 411 | # up.block = block 412 | # up.attn = attn 413 | # if i_level != 0: 414 | # up.upsample = Upsample(block_in, resamp_with_conv) 415 | # curr_res = curr_res * 2 416 | # self.up.insert(0, up) # prepend to get consistent order 417 | # 418 | # # end 419 | # self.norm_out = Normalize(block_in) 420 | # self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 421 | # 422 | # def forward(self, z): 423 | # self.last_z_shape = z.shape 424 | # 425 | # # timestep embedding 426 | # temb = None 427 | # 428 | # # z to block_in 429 | # h = self.conv_in(z) 430 | # 431 | # # middle 432 | # h = self.mid.block_1(h, temb) 433 | # h = self.mid.attn_1(h) 434 | # h = self.mid.block_2(h, temb) 435 | # 436 | # # upsampling 437 | # for i_level in reversed(range(self.num_resolutions)): 438 | # for i_block in range(self.num_res_blocks+1): 439 | # h = self.up[i_level].block[i_block](h, temb) 440 | # if len(self.up[i_level].attn) > 0: 441 | # h = self.up[i_level].attn[i_block](h) 442 | # if i_level != 0: 443 | # h = self.up[i_level].upsample(h) 444 | # 445 | # h = self.norm_out(h) 446 | # h = nonlinearity(h) 447 | # h = self.conv_out(h) 448 | # return h 449 | 450 | 451 | class Codebook(nn.Module): 452 | """ 453 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 454 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 
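    Training schedule implemented below: for the first init_steps * 3 training forward passes the
    input is returned unquantised (identity); from init_steps onwards encoder outputs are sampled
    into a reservoir, and between init_steps * 3 and init_steps * 30 passes the codebook is
    periodically re-initialised with k-means centroids of the (all-gathered, when distributed)
    reservoir. After that, plain nearest-neighbour quantisation with a commitment term weighted
    by beta is used.

    Minimal shape-only sketch (sizes are made up, not taken from any config in this repo):

        codebook = Codebook(codebook_size=1024, codebook_dim=256, beta=0.25, init_steps=10)
        z = torch.randn(2, 256, 16, 16)       # (batch, codebook_dim, h, w)
        z_q, loss, indices = codebook(z)      # early training passes: z_q == z, loss == 0, indices is None
        codebook.eval()
        z_q, loss, indices = codebook(z)      # eval mode: nearest-neighbour quantisation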
455 | """ 456 | def __init__(self, codebook_size, codebook_dim, beta, init_steps=2000, reservoir_size=2e5): 457 | super().__init__() 458 | self.codebook_size = codebook_size 459 | self.codebook_dim = codebook_dim 460 | self.beta = beta 461 | 462 | self.embedding = nn.Embedding(self.codebook_size, self.codebook_dim) 463 | self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) 464 | 465 | self.q_start_collect, self.q_init, self.q_re_end, self.q_re_step = init_steps, init_steps * 3, init_steps * 30, init_steps // 2 466 | self.q_counter = 0 467 | self.reservoir_size = int(reservoir_size) 468 | self.reservoir = None 469 | 470 | def forward(self, z): 471 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 472 | batch_size = z.size(0) 473 | z_flattened = z.view(-1, self.codebook_dim) 474 | if self.training: 475 | self.q_counter += 1 476 | # x_flat = x.permute(0, 2, 3, 1).reshape(-1, z.shape(1)) 477 | if self.q_counter > self.q_start_collect: 478 | z_new = z_flattened.clone().detach().view(batch_size, -1, self.codebook_dim) 479 | z_new = z_new[:, torch.randperm(z_new.size(1))][:, :10].reshape(-1, self.codebook_dim) 480 | self.reservoir = z_new if self.reservoir is None else torch.cat([self.reservoir, z_new], dim=0) 481 | self.reservoir = self.reservoir[torch.randperm(self.reservoir.size(0))[:self.reservoir_size]].detach() 482 | if self.q_counter < self.q_init: 483 | z_q = rearrange(z, 'b h w c -> b c h w').contiguous() 484 | return z_q, z_q.new_tensor(0), None # z_q, loss, min_encoding_indices 485 | else: 486 | # if self.q_counter < self.q_init + self.q_re_end: 487 | if self.q_init <= self.q_counter < self.q_re_end: 488 | if (self.q_counter - self.q_init) % self.q_re_step == 0 or self.q_counter == self.q_init + self.q_re_end - 1: 489 | kmeans = KMeans(n_clusters=self.codebook_size) 490 | world_size = dist.get_world_size() 491 | print("Updating codebook from reservoir.") 492 | if world_size > 1: 493 | global_reservoir = [torch.zeros_like(self.reservoir) for _ in range(world_size)] 494 | dist.all_gather(global_reservoir, self.reservoir.clone()) 495 | global_reservoir = torch.cat(global_reservoir, dim=0) 496 | else: 497 | global_reservoir = self.reservoir 498 | kmeans.fit_predict(global_reservoir) # reservoir is 20k encoded latents 499 | self.embedding.weight.data = kmeans.centroids.detach() 500 | 501 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 502 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 503 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 504 | 505 | min_encoding_indices = torch.argmin(d, dim=1) 506 | z_q = self.embedding(min_encoding_indices).view(z.shape) 507 | 508 | # compute loss for embedding 509 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) 510 | 511 | # preserve gradients 512 | z_q = z + (z_q - z).detach() 513 | 514 | # reshape back to match original input shape 515 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 516 | 517 | return z_q, loss, min_encoding_indices 518 | 519 | def get_codebook_entry(self, indices, shape): 520 | # get quantized latent vectors 521 | z_q = self.embedding(indices) 522 | 523 | if shape is not None: 524 | z_q = z_q.view(shape) 525 | # reshape back to match original input shape 526 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 527 | 528 | return z_q 529 | 530 | 531 | if __name__ == '__main__': 532 | enc = Encoder() 533 | dec = Decoder() 534 | print(sum([p.numel() for p in enc.parameters()])) 535 | 
print(sum([p.numel() for p in dec.parameters()])) 536 | x = torch.randn(1, 3, 512, 512) 537 | res = enc(x) 538 | print(res.shape) 539 | res = dec(res) 540 | print(res.shape) 541 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | taken from: https://github.com/ai-forever/ru-dalle/blob/master/rudalle/dalle/transformer.py slightly modified 3 | """ 4 | 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | 10 | 11 | @torch.jit.script 12 | def gelu(x): 13 | """OpenAI's gelu implementation.""" 14 | return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) 15 | 16 | 17 | class SelfAttention(nn.Module): 18 | def __init__(self, 19 | hidden_dim, 20 | num_attn_heads, 21 | attn_dropout_prob, 22 | out_dropout_prob, 23 | cogview_pb_relax=True, 24 | rudalle_relax=False 25 | ): 26 | super(SelfAttention, self).__init__() 27 | 28 | self.hidden_dim = hidden_dim 29 | self.num_attn_heads = num_attn_heads 30 | self.d = math.sqrt(self.hidden_dim // self.num_attn_heads) 31 | self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim) 32 | self.attn_drop = nn.Dropout(attn_dropout_prob) 33 | 34 | self.out_proj = nn.Linear(hidden_dim, hidden_dim) 35 | self.out_drop = nn.Dropout(out_dropout_prob) 36 | 37 | self.cogview_pb_relax = cogview_pb_relax 38 | self.rudalle_relax = rudalle_relax 39 | 40 | def split_heads(self, x): 41 | new_shape = x.size()[:-1] + (self.num_attn_heads, self.hidden_dim // self.num_attn_heads) 42 | return x.view(*new_shape).permute(0, 2, 1, 3) 43 | 44 | def calculate_attention(self, q, k, mask): 45 | k_t = k.transpose(-1, -2) 46 | mask_value = 10000. 47 | if self.cogview_pb_relax: 48 | if self.rudalle_relax: 49 | sigma = k_t.std() 50 | attn_scores = torch.matmul(q / self.d, k_t / sigma) 51 | attn_scores_max = attn_scores.detach().max(dim=-1)[0] 52 | attn_scores_min = (attn_scores.detach() + 65504).min(dim=-1)[0] 53 | shift = torch.min(attn_scores_min, attn_scores_max).unsqueeze(-1).expand_as(attn_scores) / 2 54 | attn_scores = (attn_scores - shift) / sigma 55 | mask_value = 65504 56 | else: 57 | attn_scores = torch.matmul(q / self.d, k_t) 58 | else: 59 | attn_scores = torch.matmul(q, k_t) / self.d 60 | 61 | mask = mask[:, :, -attn_scores.shape[-2]:] 62 | attn_scores = mask * attn_scores - (1. 
- mask) * mask_value 63 | if self.cogview_pb_relax and not self.rudalle_relax: 64 | alpha = 32 65 | attn_scores_scaled = attn_scores / alpha 66 | attn_scores_scaled_max, _ = attn_scores_scaled.detach().view( 67 | [attn_scores.shape[0], attn_scores.shape[1], -1]).max(dim=-1) 68 | attn_scores_scaled_max = attn_scores_scaled_max[..., None, None].expand( 69 | [-1, -1, attn_scores.size(2), attn_scores.size(3)]) 70 | attn_scores = (attn_scores_scaled - attn_scores_scaled_max) * alpha 71 | return attn_scores 72 | 73 | def forward(self, x, mask, use_cache=False, cache=None): 74 | if use_cache and cache is not None: 75 | qkv = self.qkv(x[:, cache[0].shape[-2]:, :]) 76 | else: 77 | qkv = self.qkv(x) 78 | 79 | q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1) # probably use different dim 80 | q, k, v = self.split_heads(q), self.split_heads(k), self.split_heads(v) 81 | 82 | if use_cache and cache is not None: 83 | past_k, past_v, past_output = cache 84 | k = torch.cat([past_k, k], dim=-2) 85 | v = torch.cat([past_v, v], dim=-2) 86 | attn_scores = self.calculate_attention(q, k, mask) 87 | else: 88 | attn_scores = self.calculate_attention(q, k, mask) 89 | 90 | attn_probs = nn.Softmax(dim=-1)(attn_scores) # [b, np, s, s] 91 | attn_probs = self.attn_drop(attn_probs) 92 | 93 | if self.rudalle_relax: 94 | scale = v.detach().max().item() 95 | context = torch.matmul(attn_probs, v / scale) 96 | else: 97 | context = torch.matmul(attn_probs, v) 98 | 99 | context = context.permute(0, 2, 1, 3).contiguous() 100 | context = context.view(context.shape[0], context.shape[1], 101 | context.shape[2] * context.shape[3]) 102 | 103 | if self.rudalle_relax: 104 | scale = context.detach().max().item() 105 | context /= scale 106 | 107 | out = self.out_proj(context) 108 | if use_cache and cache is not None: 109 | out = torch.concat([past_output, out], dim=-2) 110 | 111 | if use_cache: 112 | cache = k, v, out 113 | 114 | out = self.out_drop(out) 115 | return out, cache 116 | 117 | 118 | class MLP(nn.Module): 119 | def __init__(self, 120 | hidden_dim, 121 | dropout_prob, 122 | rudalle_relax=False 123 | ): 124 | super(MLP, self).__init__() 125 | self.lin1 = nn.Linear(hidden_dim, 4 * hidden_dim) 126 | self.lin2 = nn.Linear(4 * hidden_dim, hidden_dim) 127 | self.dropout = nn.Dropout(dropout_prob) 128 | self.rudalle_relax = rudalle_relax 129 | 130 | def forward(self, x): 131 | x = self.lin1(x) 132 | x = gelu(x) 133 | if self.rudalle_relax: 134 | scale = x.detach().max().item() / 4 135 | x = self.lin2(x / scale) 136 | x = (x / x.detach().max(dim=-1)[0].unsqueeze(-1)) * scale 137 | else: 138 | x = self.lin2(x) 139 | return self.dropout(x) 140 | 141 | 142 | class TransformerLayer(nn.Module): 143 | def __init__(self, 144 | hidden_dim, 145 | num_attn_heads, 146 | attn_dropout_prop, 147 | out_dropout_prob, 148 | cogview_pb_relax=True, 149 | cogview_sandwich_layernorm=True, 150 | cogview_layernorm_prescale=False, 151 | rudalle_relax=False 152 | ): 153 | super().__init__() 154 | self.cogview_pb_relax = cogview_pb_relax 155 | self.cogview_sandwich_layernorm = cogview_sandwich_layernorm 156 | self.cogview_layernorm_prescale = cogview_layernorm_prescale 157 | self.rudalle_relax = rudalle_relax 158 | 159 | self.ln_in = nn.LayerNorm(hidden_dim, eps=1e-5) 160 | self.ln_out = nn.LayerNorm(hidden_dim, eps=1e-5) 161 | if cogview_sandwich_layernorm: 162 | self.first_ln_sandwich = nn.LayerNorm(hidden_dim, eps=1e-5) 163 | self.second_ln_sandwich = nn.LayerNorm(hidden_dim, eps=1e-5) 164 | 165 | self.attn = SelfAttention(hidden_dim=hidden_dim, 166 | 
num_attn_heads=num_attn_heads, 167 | attn_dropout_prob=attn_dropout_prop, 168 | out_dropout_prob=out_dropout_prob, 169 | cogview_pb_relax=cogview_pb_relax, 170 | rudalle_relax=rudalle_relax) 171 | 172 | self.mlp = MLP(hidden_dim=hidden_dim, 173 | dropout_prob=out_dropout_prob, 174 | rudalle_relax=rudalle_relax) 175 | 176 | def forward(self, x, mask, cache=None, use_cache=False, mlp_cache=False): 177 | if self.cogview_layernorm_prescale: 178 | ln_in = self.ln_in(x / x.detach().max(dim=-1)[0].unsqueeze(-1)) 179 | else: 180 | ln_in = self.ln_in(x) 181 | attn_out, new_cache = self.attn(ln_in, mask, cache, use_cache) 182 | 183 | if self.cogview_sandwich_layernorm: 184 | if self.cogview_layernorm_prescale: 185 | attn_out = self.first_ln_sandwich(attn_out / attn_out.detach().max(dim=-1)[0].unsqueeze(-1)) 186 | else: 187 | attn_out = self.first_ln_sandwich(attn_out) 188 | 189 | x = x + attn_out 190 | cached = 0 if cache is None else cache[0].shape[2] 191 | 192 | if self.cogview_layernorm_prescale: 193 | ln_out = self.ln_out(x / x.detach().max(dim=-1)[0].unsqueeze(-1)) 194 | else: 195 | ln_out = self.ln_out(x) 196 | 197 | if use_cache and cached: 198 | mlp_out = torch.cat( 199 | (cache[-1] if mlp_cache else ln_out[..., :cached, :], self.mlp(ln_out[..., :cached, :])), dim=-2) 200 | if mlp_cache: 201 | new_cache = new_cache + (mlp_out,) 202 | else: 203 | mlp_out = self.mlp(ln_out) 204 | 205 | if self.cogview_sandwich_layernorm: 206 | mlp_out = self.second_ln_sandwich(mlp_out) 207 | 208 | x = x + mlp_out 209 | 210 | return x, new_cache 211 | 212 | 213 | class Transformer(nn.Module): 214 | def __init__(self, 215 | num_layers, 216 | hidden_dim, 217 | num_attn_heads, 218 | image_tokens_per_dim, 219 | seg_tokens_per_dim, 220 | text_length, 221 | attn_dropout_prop=0, 222 | out_dropout_prob=0, 223 | cogview_pb_relax=True, 224 | cogview_sandwich_layernorm=True, 225 | cogview_layernorm_prescale=False, 226 | rudalle_relax=False 227 | ): 228 | super(Transformer, self).__init__() 229 | self.num_layers = num_layers 230 | self.cogview_pb_relax = cogview_pb_relax 231 | self.rudalle_relax = rudalle_relax 232 | 233 | self.layers = nn.ModuleList([ 234 | TransformerLayer( 235 | hidden_dim, 236 | num_attn_heads, 237 | attn_dropout_prop, 238 | out_dropout_prob, 239 | cogview_pb_relax, 240 | cogview_sandwich_layernorm, 241 | cogview_layernorm_prescale, 242 | rudalle_relax 243 | ) for _ in range(num_layers) 244 | ]) 245 | 246 | self.register_buffer("mask", self._create_mask(text_length, seg_tokens_per_dim, image_tokens_per_dim)) 247 | self.final_ln = nn.LayerNorm(hidden_dim, eps=1e-5) 248 | 249 | def _create_mask(self, text_length, seg_tokens_per_dim, image_tokens_per_dim): 250 | size = text_length + seg_tokens_per_dim ** 2 + image_tokens_per_dim ** 2 251 | return torch.tril(torch.ones(size, size, dtype=torch.float32)) 252 | 253 | def get_block_size(self): 254 | return self.block_size 255 | 256 | def forward(self, x, attn_mask, cache=None, use_cache=None): 257 | if cache is None: 258 | cache = {} 259 | 260 | for i, layer in enumerate(self.layers): 261 | mask = attn_mask 262 | layer_mask = self.mask[:mask.size(2), :mask.size(3)] 263 | mask = torch.mul(attn_mask, layer_mask) 264 | x, layer_cache = layer(x, mask, cache.get(i), mlp_cache=i == len(self.layers) - 1, use_cache=use_cache) 265 | cache[i] = layer_cache 266 | 267 | if self.rudalle_relax: 268 | ln_out = self.final_ln(x / x.detach().max(dim=-1)[0].unsqueeze(-1)) 269 | else: 270 | ln_out = self.final_ln(x) 271 | 272 | return ln_out, cache 273 | 274 | 275 | class 
MakeAScene(nn.Module): 276 | def __init__(self, 277 | num_layers, 278 | hidden_dim, 279 | num_attn_heads, 280 | image_vocab_size, 281 | seg_vocab_size, 282 | text_vocab_size, 283 | image_tokens_per_dim, 284 | seg_tokens_per_dim, 285 | text_length 286 | ): 287 | super(MakeAScene, self).__init__() 288 | self.image_tokens_per_dim = image_tokens_per_dim 289 | self.seg_tokens_per_dim = seg_tokens_per_dim 290 | self.image_length = image_tokens_per_dim ** 2 291 | self.seg_length = seg_tokens_per_dim ** 2 292 | self.text_length = text_length 293 | self.total_length = self.text_length + self.seg_length + self.image_length 294 | self.text_vocab_size = text_vocab_size 295 | 296 | self.transformer = Transformer(num_layers, hidden_dim, num_attn_heads, 297 | image_tokens_per_dim, seg_tokens_per_dim, 298 | text_length) 299 | 300 | self.image_token_embedding = nn.Embedding(image_vocab_size, hidden_dim) 301 | self.seg_token_embedding = nn.Embedding(seg_vocab_size, hidden_dim) 302 | self.text_token_embedding = nn.Embedding(text_vocab_size, hidden_dim) 303 | 304 | self.text_pos_embeddings = torch.nn.Embedding(text_length, hidden_dim) 305 | self.seg_row_embeddings = torch.nn.Embedding(seg_tokens_per_dim, hidden_dim) 306 | self.seg_col_embeddings = torch.nn.Embedding(seg_tokens_per_dim, hidden_dim) 307 | self.image_row_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_dim) 308 | self.image_col_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_dim) 309 | self._init_weights(self.text_pos_embeddings) 310 | self._init_weights(self.seg_row_embeddings) 311 | self._init_weights(self.seg_col_embeddings) 312 | self._init_weights(self.image_row_embeddings) 313 | self._init_weights(self.image_col_embeddings) 314 | 315 | self.to_logits = torch.nn.Sequential( 316 | torch.nn.LayerNorm(hidden_dim), # TODO: check if this is redundant 317 | torch.nn.Linear(hidden_dim, image_vocab_size), 318 | ) 319 | 320 | def _init_weights(self, module): 321 | if isinstance(module, (nn.Linear, nn.Embedding)): 322 | module.weight.data.normal_(mean=0.0, std=0.02) 323 | if isinstance(module, nn.Linear) and module.bias is not None: 324 | module.bias.data.zero_() 325 | elif isinstance(module, nn.LayerNorm): 326 | module.bias.data.zero_() 327 | module.weight.data.fill_(1.0) 328 | 329 | def get_seg_pos_embeddings(self, seg_input_ids): 330 | input_shape = seg_input_ids.size() 331 | row_ids = torch.arange(input_shape[-1], 332 | dtype=torch.long, device=self.device) // self.seg_tokens_per_dim 333 | row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1]) 334 | col_ids = torch.arange(input_shape[-1], 335 | dtype=torch.long, device=self.device) % self.seg_tokens_per_dim 336 | col_ids = col_ids.unsqueeze(0).view(-1, input_shape[-1]) 337 | return self.seg_row_embeddings(row_ids) + self.seg_col_embeddings(col_ids) 338 | 339 | def get_image_pos_embeddings(self, image_input_ids, past_length=0): 340 | input_shape = image_input_ids.size() 341 | row_ids = torch.arange(past_length, input_shape[-1] + past_length, 342 | dtype=torch.long, device=self.device) // self.image_tokens_per_dim 343 | row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1]) 344 | col_ids = torch.arange(past_length, input_shape[-1] + past_length, 345 | dtype=torch.long, device=self.device) % self.image_tokens_per_dim 346 | col_ids = col_ids.unsqueeze(0).view(-1, input_shape[-1]) 347 | return self.image_row_embeddings(row_ids) + self.image_col_embeddings(col_ids) 348 | 349 | def forward(self, text_tokens, seg_tokens, img_tokens): 350 | text_range = 
torch.arange(self.text_length) 351 | text_range += (self.text_vocab_size - self.text_length) 352 | text_range = text_range.to(self.device) 353 | text_tokens = torch.where(text_tokens == 0, text_range, text_tokens) 354 | text_pos = self.text_pos_embeddings(torch.arange(text_tokens.shape[1], device=self.device)) 355 | text_embeddings = self.text_token_embedding(text_tokens) + text_pos 356 | 357 | seg_pos = self.get_seg_pos_embeddings(seg_tokens) 358 | seg_embeddings = self.seg_token_embedding(seg_tokens) + seg_pos 359 | 360 | embeddings = torch.cat((text_embeddings, seg_embeddings), dim=1) 361 | if img_tokens is not None: 362 | img_pos = self.get_image_pos_embeddings(img_tokens) 363 | image_embeddings = self.image_token_embedding(img_tokens) + img_pos 364 | embeddings = torch.cat((embeddings, image_embeddings), dim=1) 365 | 366 | attention_mask = torch.tril( 367 | torch.ones((embeddings.shape[0], 1, self.total_length, self.total_length), device=self.device) 368 | ) 369 | attention_mask[:, :, :-self.image_length, :-self.image_length] = 1 370 | attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]] 371 | 372 | transformer_output, present_cache = self.transformer( 373 | embeddings, attention_mask, 374 | cache=None, use_cache=False 375 | ) 376 | 377 | logits = self.to_logits(transformer_output) 378 | return logits[:, -self.image_length-1:-1, :] 379 | 380 | 381 | if __name__ == '__main__': 382 | from omegaconf import OmegaConf 383 | batch_size = 2 384 | 385 | num_layers = 2 386 | hidden_dim = 64 387 | num_attn_heads = 8 388 | 389 | text_length = 128 390 | seg_per_dim = 16 391 | image_per_dim = 32 392 | 393 | text_vocab_size = 128 394 | seg_vocab_size = 128 395 | image_vocab_size = 128 396 | 397 | model = MakeAScene(num_layers, hidden_dim, num_attn_heads, image_vocab_size, seg_vocab_size, text_vocab_size+text_length, image_per_dim, seg_per_dim, text_length).to('cuda') 398 | model.device = torch.device('cuda') 399 | pred_logits = model(torch.randint(high=text_vocab_size+text_length, size=(batch_size, text_length)).cuda(), 400 | torch.randint(high=seg_vocab_size, size=(batch_size, seg_per_dim**2)).cuda(), 401 | torch.randint(high=image_vocab_size, size=(batch_size, image_per_dim**2)).cuda()) 402 | 403 | assert pred_logits.shape == torch.Size([batch_size, image_per_dim ** 2, image_vocab_size]) 404 | -------------------------------------------------------------------------------- /models/vqvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from .modules import Encoder, Decoder 5 | from .modules import Codebook 6 | 7 | 8 | class VQBASE(nn.Module): 9 | def __init__(self, ddconfig, n_embed, embed_dim, init_steps, reservoir_size): 10 | super(VQBASE, self).__init__() 11 | self.encoder = Encoder(**ddconfig) 12 | self.decoder = Decoder(**ddconfig) 13 | self.quantize = Codebook(n_embed, embed_dim, beta=0.25, init_steps=init_steps, reservoir_size=reservoir_size) # TODO: change length_one_epoch 14 | self.quant_conv = nn.Sequential( 15 | nn.Conv2d(ddconfig["z_channels"], embed_dim, 1), 16 | nn.SyncBatchNorm(embed_dim) 17 | ) 18 | self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 19 | 20 | def encode(self, x): 21 | h = self.encoder(x) 22 | h = self.quant_conv(h) 23 | quant, emb_loss, info = self.quantize(h) 24 | return quant, emb_loss 25 | 26 | def decode(self, quant): 27 | quant = self.post_quant_conv(quant) 28 | dec = self.decoder(quant) 29 | 
return dec 30 | 31 | def decode_code(self, code_b): 32 | quant_b = self.quantize.embed_code(code_b) 33 | dec = self.decode(quant_b) 34 | return dec 35 | 36 | def forward(self, input): 37 | quant, diff = self.encode(input) 38 | dec = self.decode(quant) 39 | return dec, diff 40 | 41 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.distributed as dist 5 | from torch.utils.data import DataLoader 6 | from torch.nn.parallel import DistributedDataParallel 7 | import torchvision 8 | import torch.multiprocessing as mp 9 | from tqdm import tqdm 10 | import hydra 11 | from log_utils import Logger, Visualizer 12 | from utils import collate_fn, change_requires_grad 13 | import os 14 | from time import time 15 | import traceback 16 | import torch.nn.functional as F 17 | from random import random 18 | def train(proc_id, cfg): 19 | print(proc_id) 20 | parallel = len(cfg.devices) > 1 21 | if parallel: 22 | torch.cuda.set_device(proc_id) 23 | torch.backends.cudnn.benchmark = True 24 | dist.init_process_group(backend="nccl", init_method="env://", world_size=len(cfg.devices), rank=proc_id) 25 | device = torch.device(proc_id) 26 | dataset = hydra.utils.instantiate(cfg.dataset, _recursive_=False) 27 | print(cfg.dataset) 28 | dataloader = DataLoader(dataset, **cfg.dataloader, collate_fn=collate_fn) 29 | model = hydra.utils.instantiate(cfg.model).to(device) 30 | loss_fn = hydra.utils.instantiate(cfg.loss).to(device) 31 | if parallel: 32 | model = DistributedDataParallel(model, device_ids=[device], output_device=device) # add find_unused_parameters=True if DDP reports unused parameters 33 | if "discriminator" in loss_fn._modules.keys(): 34 | loss_fn = DistributedDataParallel(loss_fn, device_ids=[device], output_device=device,) 35 | logger = Logger(proc_id, device=device) 36 | 37 | if cfg.mode == "pretrain_segmentation": # VQ pretraining on segmentation maps 38 | optim = torch.optim.Adam(model.parameters(), **cfg.optimizer) 39 | dataloader_iter = iter(dataloader) 40 | for step in range(cfg.total_steps): 41 | data = next(dataloader_iter) # note: raises StopIteration if total_steps exceeds one pass over the dataloader 42 | _, seg = data 43 | seg = seg.to(device) 44 | seg_rec, q_loss = model(seg) 45 | loss = loss_fn(q_loss, seg, seg_rec) 46 | 47 | if step % cfg.log_period == 0: 48 | logger.log(loss, q_loss, seg, seg_rec, step) 49 | torch.save(model.state_dict(), "checkpoint.pt") 50 | 51 | loss.backward() 52 | if step % cfg.accumulate_grad == 0: 53 | optim.step() 54 | optim.zero_grad() 55 | 56 | if step == cfg.total_steps - 1: 57 | torch.save(model.state_dict(), "final.pt") 58 | return 59 | 60 | elif cfg.mode == "pretrain_image": # VQ pretraining on images with discriminator and face losses 61 | vq_optim = torch.optim.Adam(model.parameters(), **cfg.optimizer.vq) 62 | for param_group in vq_optim.param_groups: 63 | param_group["lr"] /= cfg.accumulate_grad 64 | disc_optim = torch.optim.Adam(loss_fn.module.discriminator.parameters(), **cfg.optimizer.disc) 65 | for param_group in disc_optim.param_groups: 66 | param_group["lr"] /= cfg.accumulate_grad 67 | 68 | start = 0 69 | if cfg.resume: 70 | checkpoint = torch.load(cfg.checkpoint, map_location=device) 71 | model.module.load_state_dict(checkpoint["model"]) 72 | loss_fn.module.discriminator.load_state_dict(checkpoint["discriminator"]) 73 | vq_optim.load_state_dict(checkpoint["optim"]) 74 | disc_optim.load_state_dict(checkpoint["disc_optim"]) 75 | start = checkpoint["step"] 76 | model.module.quantize.q_counter = start 77 | 78 | pbar = tqdm(enumerate(dataloader, start=start), total=cfg.total_steps, initial=start) if proc_id == 0 else enumerate(dataloader, start=start) 79
| try: 80 | for step, data in pbar: 81 | img, _, bbox_objects, bbox_faces, _ = data 82 | img = img.to(device) 83 | bbox_objects = bbox_objects.to(device).to(torch.float32) 84 | img_rec, q_loss = model(img) 85 | 86 | change_requires_grad(model, False) # freeze the VQ model while the discriminator loss is backpropagated 87 | d_loss, (d_loss_ema,) = loss_fn(optimizer_idx=1, global_step=step, images=img, reconstructions=img_rec) 88 | d_loss.backward() 89 | change_requires_grad(model, True) 90 | 91 | change_requires_grad(loss_fn.module.discriminator, False) # freeze the discriminator for the generator/codebook update 92 | loss, (nll_loss, face_loss, g_loss) = loss_fn(optimizer_idx=0, global_step=step, images=img, 93 | reconstructions=img_rec, 94 | codebook_loss=q_loss, bbox_obj=bbox_objects, 95 | bbox_face=bbox_faces, 96 | last_layer=model.module.decoder.model[-1]) 97 | loss.backward() 98 | change_requires_grad(loss_fn.module.discriminator, True) 99 | if step % cfg.accumulate_grad == 0: 100 | disc_optim.step() 101 | disc_optim.zero_grad() 102 | vq_optim.step() 103 | vq_optim.zero_grad() 104 | 105 | if step % cfg.log_period == 0 and proc_id == 0: 106 | logger.log(loss=loss, q_loss=q_loss, img=img, img_rec=img_rec, d_loss=d_loss, nll_loss=nll_loss, 107 | face_loss=face_loss, g_loss=g_loss, d_loss_ema=d_loss_ema, step=step) 108 | if step % cfg.save_period == 0 and proc_id == 0: 109 | state = { 110 | "model": model.module.state_dict(), 111 | "discriminator": loss_fn.module.discriminator.state_dict(), 112 | "optim": vq_optim.state_dict(), 113 | "disc_optim": disc_optim.state_dict(), 114 | "step": step 115 | } 116 | torch.save(state, f"checkpoint_{int(step // 5e4)}.pt") 117 | 118 | if step == cfg.total_steps: 119 | state = { 120 | "model": model.module.state_dict(), 121 | "discriminator": loss_fn.module.discriminator.state_dict(), 122 | "optim": vq_optim.state_dict(), 123 | "disc_optim": disc_optim.state_dict(), 124 | "step": step 125 | } 126 | torch.save(state, "final.pt") 127 | return 128 | except Exception as e: 129 | print('Caught exception in worker thread (x = %d):' % proc_id) 130 | # This prints the type, value, and stack trace of the 131 | # current exception being handled. 132 | with open("error.log", "a") as f: 133 | traceback.print_exc(file=f) 134 | raise e 135 | 136 | elif cfg.mode == "train_transformer": 137 | optim = torch.optim.Adam(model.parameters(), **cfg.optimizer.stage2) 138 | 139 | pbar = tqdm(enumerate(dataloader), total=cfg.total_steps) if proc_id == 0 else enumerate(dataloader) 140 | try: 141 | for step, data in pbar: 142 | img_token, seg_token, _, _, text_token = data 143 | img_token = img_token.to(device) 144 | seg_token = seg_token.to(device) 145 | text_token = text_token.to(device) 146 | 147 | if step >= cfg.start_uncond and random() < cfg.uncond_p: # randomly drop the text condition for unconditional training 148 | text_token *= 0 149 | 150 | pred_logit = model(text_token, seg_token, img_token) 151 | 152 | loss = F.cross_entropy(pred_logit.view(-1, pred_logit.shape[-1]), img_token.view(-1)) 153 | loss.backward() 154 | if step % cfg.accumulate_grad == 0: 155 | optim.step() 156 | optim.zero_grad() 157 | 158 | ### LOGGING PART 159 | if step % cfg.log_period == 0: 160 | logger.log(loss=loss, step=step) 161 | torch.save(model.module.state_dict(), "checkpoint.pt") 162 | 163 | if step == cfg.total_steps: 164 | torch.save(model.module.state_dict(), "final.pt") 165 | return 166 | except Exception as e: 167 | print('Caught exception in worker thread (x = %d):' % proc_id) 168 | 169 | # This prints the type, value, and stack trace of the 170 | # current exception being handled.
171 | with open("error.log", "a") as f: 172 | traceback.print_exc(file=f) 173 | raise e 174 | 175 | def visualize(cfg): 176 | device = torch.device(cfg.devices[0]) 177 | model = torch.nn.DataParallel(hydra.utils.instantiate(cfg.model)).to(device) 178 | checkpoint = hydra.utils.to_absolute_path(cfg.checkpoint) 179 | state_dict = torch.load(checkpoint, map_location=device) 180 | model.load_state_dict(state_dict) 181 | model = model.module 182 | model.eval() 183 | img = torch.rand(1, 159, 256, 256).to(device) # placeholder; overwritten by the dataset samples below 184 | dataset = hydra.utils.instantiate(cfg.dataset) 185 | visualizer = Visualizer(device=device) 186 | print("Processing...") 187 | for i, data in enumerate(dataset): 188 | if i == 40: 189 | break 190 | print("Processing image ", i) 191 | img, seg = data 192 | img = img.to(device).unsqueeze(0) 193 | seg = seg.to(device).unsqueeze(0) 194 | seg_rec, _ = model(seg) 195 | visualizer(i, image=img, seg=seg, seg_rec=seg_rec, ) 196 | 197 | print(model(dataset[0][1].unsqueeze(0).to(device))[0].shape) # debug: reconstruction shape of the first sample 198 | 199 | 200 | def preprocess_dataset(cfg): 201 | # dataset = hydra.utils.instantiate(cfg.dataset,) 202 | dataset = cfg.dataset 203 | preprocessor = hydra.utils.instantiate(cfg.preprocessor) 204 | preprocessor(dataset) 205 | 206 | 207 | @hydra.main(config_path="conf", config_name="img_config", version_base="1.2") 208 | def launch(cfg): 209 | if "pretrain" in cfg.mode: 210 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(d) for d in cfg.devices]) 211 | cfg.checkpoint = hydra.utils.to_absolute_path(cfg.checkpoint) 212 | if len(cfg.devices) == 1: 213 | train(0, cfg) 214 | else: 215 | os.environ["MASTER_ADDR"] = "localhost" 216 | os.environ["MASTER_PORT"] = "33751" 217 | mp.spawn(train, nprocs=len(cfg.devices), args=(cfg,)) 218 | elif "show" in cfg.mode: 219 | visualize(cfg) 220 | elif "preprocess_dataset" in cfg.mode: 221 | preprocess_dataset(cfg) 222 | 223 | 224 | if __name__ == "__main__": 225 | launch() 226 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def collate_fn(batch, need_seg=False): 5 | images = torch.stack([i[0] for i in batch], dim=0) 6 | if need_seg: 7 | segmentation_maps = torch.stack([torch.Tensor(i[1]).permute(2, 0, 1) for i in batch], dim=0) 8 | else: 9 | segmentation_maps = [] 10 | object_boxes = [list(map(lambda bbox: list(map(int, bbox))[:-1], i[2])) for i in batch] # drop the trailing class label from each bbox 11 | face_boxes = [list(map(lambda bbox: list(map(int, bbox))[:-1], i[3])) for i in batch] 12 | captions = [i[4] for i in batch] 13 | return [images, segmentation_maps, object_boxes, face_boxes, captions] 14 | 15 | 16 | def collate_fn_(batch): # alternative collate: accumulate per-sample tensors, then stack once 17 | images, segmentation_maps = [], [] 18 | object_boxes, face_boxes, captions = [], [], [] 19 | for i in batch: 20 | images.append(i[0]) 21 | segmentation_maps.append(torch.Tensor(i[1]).permute(2, 0, 1)) 22 | object_boxes.append(i[2][:-1]) 23 | face_boxes.append(i[3][:-1]) 24 | captions.append(i[4]) 25 | return [torch.stack(images, dim=0), torch.stack(segmentation_maps, dim=0), object_boxes, face_boxes, captions] 26 | 27 | def change_requires_grad(model, state): 28 | for parameter in model.parameters(): 29 | parameter.requires_grad = state 30 | 31 | --------------------------------------------------------------------------------
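Usage note: the training loops above unpack batches produced by `collate_fn`, so each dataset item is expected to be a 5-tuple of (image, segmentation map, object bboxes, face bboxes, caption). Below is a minimal sketch of that contract, assuming the repo root is on `PYTHONPATH`; the sample values are made up for illustration and are not part of the repository.

```python
import torch
from utils import collate_fn  # assumes utils.py from this repo is importable

# A made-up toy item mirroring the layout the training loops index into.
item = (
    torch.rand(3, 256, 256),             # image tensor (CxHxW)
    torch.zeros(256, 256, 159).numpy(),  # segmentation map (HxWx159), matching the shape used in visualize()
    [[10, 20, 110, 120, 7]],             # object bboxes: [x1, y1, x2, y2, label]
    [[30, 40, 60, 80, 0]],               # face bboxes, same layout
    "a person standing in a park",       # caption
)

images, seg_maps, obj_boxes, face_boxes, captions = collate_fn([item, item], need_seg=True)
print(images.shape)    # torch.Size([2, 3, 256, 256])
print(seg_maps.shape)  # torch.Size([2, 159, 256, 256]) -- channels-first after the permute
print(obj_boxes[0])    # [[10, 20, 110, 120]] -- trailing label dropped, boxes stay Python lists
```

This is the same callable that `train.py` passes to `DataLoader(dataset, **cfg.dataloader, collate_fn=collate_fn)`: images (and, with `need_seg=True`, segmentation maps) arrive stacked as tensors, while bounding boxes and captions remain plain Python lists.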