├── LICENSE ├── README.md ├── configs ├── celeb.yaml ├── city.yaml └── cocostuff.yaml ├── file ├── img.jpg ├── mask.png └── teaser.png ├── module ├── data │ ├── builder.py │ ├── celeb.py │ ├── cityscapes.py │ ├── cocostuff.py │ ├── hook.py │ ├── load_dataset.py │ ├── prepare_text.py │ ├── transform.py │ └── utils.py └── pipe │ ├── pipe.py │ └── val.py ├── requirements.txt └── scripts ├── demo.py └── semrf.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Chaoyang Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SemFlow: Binding Semantic Segmentation and Image Synthesis via Rectified Flow 2 | 3 | 4 |
5 | 6 | [![arXiv](https://img.shields.io/badge/arXiv-2405.20282-b31b1b.svg)](https://arxiv.org/abs/2405.20282) 7 | [![Project Website](https://img.shields.io/badge/🔗-Project_Website-blue.svg)](https://wang-chaoyang.github.io/project/semflow) 8 | 9 | 10 | 11 | 12 | 13 | NeurIPS 2024 14 | 15 | 16 | 17 | 18 | 19 |
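
## Method Sketch

SemFlow casts semantic segmentation and mask-conditioned image synthesis as the two directions of a single rectified flow between the VAE latent of an RGB image and the VAE latent of its color-coded segmentation mask. The sketch below is a simplified reading of `scripts/semrf.py` and `module/pipe/pipe.py` (linear interpolation for training, Euler integration for sampling); the function names are illustrative only, not repository APIs, and this is not a runnable entry point of the repo.

```python
# Minimal sketch of the SemFlow transport, following scripts/semrf.py and
# module/pipe/pipe.py. z0: VAE latent of the RGB image, z1: VAE latent of the
# color-coded mask (a small uniform perturbation, controlled by pert.co in the
# configs, is added to the mask image before encoding).
import torch

def rf_training_pair(z0: torch.Tensor, z1: torch.Tensor, t: torch.Tensor):
    # Interpolated latent fed to the UNet and the constant velocity it regresses.
    zt = t * z1 + (1.0 - t) * z0
    target = z1 - z0
    return zt, target

def euler_integrate(v_fn, z: torch.Tensor, num_steps: int, reverse: bool = False):
    # Image -> mask adds dt * v at each step (as in pipeline_rf); mask -> image
    # subtracts it (as in pipeline_rf_reverse).
    dt = 1.0 / num_steps
    for i in range(num_steps):
        v = v_fn(z, i)
        z = z - dt * v if reverse else z + dt * v
    return z
```
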
20 | 21 | ## Requirements 22 | 23 | 1. Install `torch==2.1.0`. 24 | 2. Install other pip packages via `pip install -r requirements.txt`. 25 | 3. Our model is based on [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Please download it and put it under `dataset/pretrain`. 26 | 27 | ## Demo 28 | 29 | We provide a [demo](scripts/demo.py) that visualizes the bidirectional generation capability of the model. First, download the [checkpoint](https://huggingface.co/chaoyangw/semflow/resolve/main/face.pt). Then, rename it to `diffusion_pytorch_model.bin` and put it into `demo/unet`. Finally, copy `stable-diffusion-v1-5/unet/config.json` to `demo/unet`. 30 | 31 | 32 | ``` 33 | # mask2image 34 | 35 | python scripts/demo.py --pretrain_model dataset/pretrain/stable-diffusion-v1-5 --valmode gen --ckpt demo/unet --im_path file/img.jpg --mask_path file/mask.png 36 | 37 | # image2mask 38 | 39 | python scripts/demo.py --pretrain_model dataset/pretrain/stable-diffusion-v1-5 --valmode seg --ckpt demo/unet --im_path file/img.jpg --mask_path file/mask.png 40 | ``` 41 | 42 | ## Data Preparation 43 | Please download the datasets from [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ), [COCOStuff](https://github.com/nightrome/cocostuff), and [Cityscapes](https://www.cityscapes-dataset.com/). Put them under `dataset` and rearrange them as follows. 44 | 45 | ``` 46 | dataset 47 | ├── celebAmask 48 | │ ├── CelebA-512-img 49 | │ ├── CelebAMask-HQ-mergemask 50 | │ └── metas 51 | ├── cityscapes 52 | │ ├── gtFine 53 | │ ├── leftImg8bit 54 | │ └── metas 55 | └── cocostuff 56 | ├── annotations 57 | │ ├── train2017 58 | │ └── val2017 59 | ├── metas 60 | │ ├── train.txt 61 | │ └── val.txt 62 | ├── train2017 63 | └── val2017 64 | ``` 65 | 66 | We specify the training and testing data with indices, which are recorded in `metas/train.txt` and `metas/val.txt`. The indices should look like the following: 67 | 68 | ``` 69 | aachen/aachen_000000_000019 # cityscapes 70 | 000000000285 # cocostuff 71 | 0 # celebAmask 72 | ``` 73 | 74 | ## Training Scripts 75 | 76 | The training [script](scripts/semrf.py) is launched with `accelerate`. Outputs are saved to the path specified by `env.output_dir`. 77 | ``` 78 | # run 79 | accelerate launch --multi_gpu --num_processes [GPUS] scripts/semrf.py [CONFIG] [PY_ARGS] 80 | ``` 81 | You can also use `deepspeed` and `bf16` to reduce VRAM usage. 
82 | ``` 83 | accelerate launch --multi_gpu --num_processes [GPUS] scripts/semrf.py [CONFIG] env.mixed_precision=bf16 env.deepspeed=true [PY_ARGS] 84 | 85 | ``` 86 | 87 | 88 | ## Citation 89 | 90 | If you find SemFlow useful for your work, please kindly consider citing our paper: 91 | 92 | ```bibtex 93 | @article{wang2024semflow, 94 | author = {Wang, Chaoyang and Li, Xiangtai and Qi, Lu and Ding, Henghui and Tong, Yunhai and Yang, Ming-Hsuan}, 95 | title = {SemFlow: Binding Semantic Segmentation and Image Synthesis via Rectified Flow}, 96 | journal = {arXiv}, 97 | year = {2024} 98 | } 99 | ``` 100 | 101 | ## License 102 | 103 | MIT license -------------------------------------------------------------------------------- /configs/celeb.yaml: -------------------------------------------------------------------------------- 1 | pretrain_model: # 'dataset/pretrain/stable-diffusion-v1-5' 2 | resume_from_checkpoint: latest 3 | eval_only: False 4 | mode: 'semantic' 5 | db: celeb 6 | 7 | pa: 8 | k: 6 9 | s: 50 10 | 11 | pert: 12 | t: u 13 | co: 0.1 14 | 15 | cfg: 16 | cond_prob: 0 17 | guide: 1 18 | image_guide: 1 19 | text: null 20 | continus: true 21 | sampler: ddpm 22 | 23 | env: 24 | output_dir: work_dirs/semflow 25 | logging_dir: 'log' 26 | report_to: 'tensorboard' 27 | gradient_accumulation_steps: 1 28 | seed: 123 29 | mixed_precision: 'bf16' 30 | deepspeed: True 31 | allow_tf32: False 32 | scale_lr: False 33 | use_xformers: False 34 | ema: False 35 | max_train_steps: 80000 36 | max_grad_norm: 1.0 37 | checkpointing_steps: 10000 38 | size: 512 39 | val_iter: 5000 40 | splsave: false 41 | quickval: 42 | lossmask: false 43 | vis: true 44 | 45 | 46 | 47 | train: 48 | batch_size: 32 49 | num_workers: 8 50 | find_unused_parameters: False 51 | gradient_checkpointing: False 52 | 53 | 54 | 55 | eval: 56 | mask_th: 0.5 57 | count_th: 512 58 | overlap_th: 0.5 59 | batch_size: 8 60 | num_workers: 8 61 | 62 | 63 | optim: 64 | name: adamw 65 | lr: 2.0e-5 66 | beta1: 0.9 67 | beta2: 0.999 68 | weight_decay: 0.01 69 | epsilon: 1.0e-08 70 | 71 | 72 | 73 | lr_scheduler: 74 | name: linear 75 | final_lr: 0.000001 76 | warmup_steps: 500 77 | 78 | 79 | transformation: 80 | flip: True 81 | crop: random 82 | 83 | 84 | ds_base: 85 | ignore_label: 0 86 | 87 | -------------------------------------------------------------------------------- /configs/city.yaml: -------------------------------------------------------------------------------- 1 | pretrain_model: # 'dataset/pretrain/stable-diffusion-v1-5' 2 | resume_from_checkpoint: latest 3 | eval_only: False 4 | mode: 'semantic' 5 | db: city 6 | 7 | pa: 8 | k: 6 9 | s: 50 10 | 11 | pert: 12 | t: u 13 | co: 0.1 14 | 15 | cfg: 16 | cond_prob: 0 17 | guide: 1 18 | image_guide: 1 19 | text: null 20 | continus: true 21 | sampler: ddpm 22 | 23 | env: 24 | output_dir: work_dirs/semflow 25 | logging_dir: 'log' 26 | report_to: 'tensorboard' 27 | gradient_accumulation_steps: 1 28 | seed: 123 29 | mixed_precision: 'bf16' 30 | deepspeed: True 31 | allow_tf32: False 32 | scale_lr: False 33 | use_xformers: False 34 | ema: False 35 | max_train_steps: 8000 36 | max_grad_norm: 1.0 37 | checkpointing_steps: 10000 38 | size: 5120 39 | val_iter: 1000 40 | splsave: false 41 | quickval: 42 | lossmask: false 43 | vis: true 44 | 45 | 46 | 47 | train: 48 | batch_size: 32 49 | num_workers: 8 50 | find_unused_parameters: False 51 | gradient_checkpointing: False 52 | 53 | 54 | 55 | eval: 56 | mask_th: 0.5 57 | count_th: 512 58 | overlap_th: 0.5 59 | batch_size: 8 60 | num_workers: 8 61 | 62 | 63 | 
optim: 64 | name: adamw 65 | lr: 5.0e-5 66 | beta1: 0.9 67 | beta2: 0.999 68 | weight_decay: 0.01 69 | epsilon: 1.0e-08 70 | 71 | 72 | 73 | lr_scheduler: 74 | name: linear 75 | final_lr: 0.000001 76 | warmup_steps: 500 77 | 78 | 79 | transformation: 80 | flip: True 81 | crop: random 82 | 83 | 84 | ds_base: 85 | ignore_label: 0 86 | 87 | -------------------------------------------------------------------------------- /configs/cocostuff.yaml: -------------------------------------------------------------------------------- 1 | pretrain_model: # 'dataset/pretrain/stable-diffusion-v1-5' 2 | resume_from_checkpoint: latest 3 | eval_only: False 4 | mode: 'semantic' 5 | db: cocostuff 6 | 7 | valmode: seg 8 | valstep: 25 9 | valsample: euler 10 | 11 | 12 | pa: 13 | k: 6 14 | s: 50 15 | 16 | 17 | pert: 18 | t: u 19 | co: 0.1 20 | mode: none 21 | 22 | cfg: 23 | cond_prob: 0 24 | guide: 1 25 | image_guide: 1 26 | text: null 27 | continus: true 28 | sampler: ddpm 29 | 30 | env: 31 | output_dir: work_dirs/semflow 32 | logging_dir: 'log' 33 | report_to: 'tensorboard' 34 | gradient_accumulation_steps: 1 35 | seed: 123 36 | mixed_precision: "no" 37 | deepspeed: False 38 | allow_tf32: False 39 | scale_lr: False 40 | use_xformers: False 41 | ema: False 42 | max_train_steps: 320000 43 | max_grad_norm: 1.0 44 | checkpointing_steps: 10000 45 | size: 512 46 | val_iter: 5000 47 | splsave: false 48 | quickval: 49 | lossmask: false 50 | vis: true 51 | 52 | 53 | train: 54 | batch_size: 16 55 | num_workers: 8 56 | find_unused_parameters: False 57 | gradient_checkpointing: False 58 | 59 | 60 | eval: 61 | mask_th: 0.5 62 | count_th: 512 63 | overlap_th: 0.5 64 | batch_size: 8 65 | num_workers: 8 66 | 67 | optim: 68 | name: adamw 69 | lr: 1.0e-5 70 | beta1: 0.9 71 | beta2: 0.999 72 | weight_decay: 0.01 73 | epsilon: 1.0e-08 74 | 75 | 76 | lr_scheduler: 77 | name: constant_with_warmup 78 | final_lr: 0.000001 79 | warmup_steps: 10000 80 | 81 | 82 | transformation: 83 | flip: True 84 | crop: random 85 | 86 | ds_base: 87 | ignore_label: 0 88 | 89 | -------------------------------------------------------------------------------- /file/img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-chaoyang/SemFlow/0e56cb0a148648a3edbc844c87c6f6d3be547456/file/img.jpg -------------------------------------------------------------------------------- /file/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-chaoyang/SemFlow/0e56cb0a148648a3edbc844c87c6f6d3be547456/file/mask.png -------------------------------------------------------------------------------- /file/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-chaoyang/SemFlow/0e56cb0a148648a3edbc844c87c6f6d3be547456/file/teaser.png -------------------------------------------------------------------------------- /module/data/builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | 8 | def build_palette(k=6,s=None): 9 | if s==None: 10 | s = 250 // (k-1) 11 | else: 12 | assert s*(k-1)<255 13 | palette = [] 14 | for m0 in range(k): 15 | for m1 in range(k): 16 | for m2 in range(k): 17 | palette.extend([s*m0,s*m1,s*m2]) 18 | return palette 19 | 20 | 21 | 22 | 23 | if __name__ == 
'__main__': 24 | pass 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /module/data/celeb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | import torch 5 | import time 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | 10 | from .builder import build_palette 11 | 12 | class Normalize(object): 13 | 14 | def __init__(self, mean=0.5, std=0.5): 15 | self.mean = mean 16 | self.std = std 17 | 18 | def __call__(self, sample): 19 | 20 | for k in sample.keys(): 21 | if k in ['image', 'image_panseg','image_semseg']: 22 | sample[k] = (sample[k]-self.mean)/self.std 23 | return sample 24 | 25 | 26 | class Celeb(data.Dataset): 27 | COCO_CATEGORY_NAMES = ['background','skin', 'nose', 'eye_g', 'l_eye', 28 | 'r_eye', 'l_brow', 'r_brow', 'l_ear', 'r_ear', 29 | 'mouth', 'u_lip', 'l_lip', 'hair', 'hat', 30 | 'ear_r', 'neck_l', 'neck', 'cloth'] 31 | 32 | def __init__( 33 | self, 34 | data_root: str, 35 | split: str = 'val', 36 | transform = None, 37 | args_palette = None, 38 | ): 39 | print('init celebA, semseg') 40 | _img_dir = osp.join(data_root,'CelebA-512-img') 41 | _seg_dir = osp.join(data_root,'CelebAMask-HQ-mergemask') 42 | self.meta_data = {'category_names':self.COCO_CATEGORY_NAMES} 43 | self.transform = transform 44 | self.post_norm = Normalize() 45 | 46 | self.palette = build_palette(args_palette[0],args_palette[1]) 47 | 48 | if split == 'train': 49 | _datalist = osp.join(data_root,'metas','train.txt') 50 | elif split == 'val': 51 | _datalist = osp.join(data_root,'metas','val.txt') 52 | self.images = [] 53 | self.semsegs = [] 54 | with open(_datalist,'r') as f: 55 | for line in f: 56 | self.images.append(osp.join(_img_dir,line.strip()+'.jpg')) 57 | self.semsegs.append(osp.join(_seg_dir,line.strip()+'.png')) 58 | print(f'processing {len(self.images)} images') 59 | 60 | def __len__(self): 61 | return len(self.images) 62 | 63 | def prepare_pm(self,x): 64 | assert len(x.shape)==2 65 | h,w = x.shape 66 | pm = np.ones((h,w,3))*255 67 | clslist = np.unique(x).tolist() 68 | if 255 in clslist: 69 | raise ValueError() 70 | for _c in clslist: 71 | _x,_y = np.where(x==_c) 72 | pm[_x,_y,:] = self.palette[int(_c)*3:(int(_c)+1)*3] 73 | return pm 74 | 75 | def __getitem__(self, index): 76 | sample = {} 77 | _img = Image.open(self.images[index]).convert('RGB') 78 | sample['image'] = _img 79 | sample['gt_semseg'] = Image.open(self.semsegs[index]) 80 | sample['mask'] = np.ones_like(np.array(sample['gt_semseg'])) 81 | sample['mask'] = Image.fromarray(sample['mask']) 82 | sample['image_semseg'] = Image.fromarray(self.prepare_pm(np.array(sample['gt_semseg'])).astype(np.uint8)) 83 | sample['text'] = None 84 | 85 | sample['meta'] = { 86 | 'im_size': (_img.size[1], _img.size[0]), 87 | 'image_file': self.images[index], 88 | "image_id": int(os.path.basename(self.images[index]).split(".")[0]) 89 | } 90 | 91 | if self.transform is not None: 92 | sample = self.transform(sample) 93 | 94 | sample = self.post_norm(sample) 95 | 96 | return sample -------------------------------------------------------------------------------- /module/data/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | import torch 5 | import time 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | # from .map_tbl import palette 10 
| from .builder import build_palette 11 | 12 | 13 | class Normalize(object): 14 | def __init__(self, mean=0.5, std=0.5): 15 | self.mean = mean 16 | self.std = std 17 | 18 | def __call__(self, sample): 19 | 20 | for k in sample.keys(): 21 | if k in ['image', 'image_panseg','image_semseg']: 22 | sample[k] = (sample[k]-self.mean)/self.std 23 | return sample 24 | 25 | 26 | class Cityscapes(data.Dataset): 27 | COCO_CATEGORY_NAMES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 28 | 'traffic light', 'traffic sign', 'vegetation', 'terrain', 29 | 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 30 | 'motorcycle', 'bicycle'] 31 | 32 | def __init__( 33 | self, 34 | data_root: str, 35 | split: str = 'val', 36 | transform = None, 37 | size=1024, 38 | args_palette = None, 39 | ): 40 | 41 | self.palette = build_palette(args_palette[0],args_palette[1]) 42 | 43 | self.size = size 44 | self.data_root = data_root 45 | self.split = split 46 | if split=='train': 47 | self.training = True 48 | elif split=='val': 49 | self.training = False 50 | else: 51 | raise ValueError() 52 | print(f'init cityspace, semseg, size: {size}') 53 | if size==1024: 54 | _img_dir = osp.join(data_root,'leftImg8bit') 55 | _seg_dir = osp.join(data_root,'gtFine') 56 | else: 57 | raise ValueError() 58 | 59 | self.meta_data = {'category_names':self.COCO_CATEGORY_NAMES} 60 | self.transform = transform 61 | self.post_norm = Normalize() 62 | 63 | if split == 'train': 64 | _datalist = osp.join(data_root,'metas','train.txt') 65 | elif split == 'val': 66 | _datalist = osp.join(data_root,'metas','val.txt') 67 | self.images = [] 68 | self.semsegs = [] 69 | with open(_datalist,'r') as f: 70 | for line in f: 71 | self.images.append(osp.join(_img_dir,split,line.strip()+'_leftImg8bit.png')) 72 | self.semsegs.append(osp.join(_seg_dir,split,line.strip()+'_gtFine_labelTrainIds.png')) 73 | print(f'processing {len(self.images)} images') 74 | 75 | 76 | def __len__(self): 77 | return len(self.images) 78 | 79 | def prepare_pm(self,x): 80 | assert len(x.shape)==2 81 | h,w = x.shape 82 | pm = np.ones((h,w,3))*255 83 | clslist = np.unique(x).tolist() 84 | if 255 in clslist: 85 | clslist.remove(255) 86 | for _c in clslist: 87 | _x,_y = np.where(x==_c) 88 | pm[_x,_y,:] = self.palette[int(_c)*3:(int(_c)+1)*3] 89 | return pm 90 | 91 | def __getitem__(self, index): 92 | sample = {} 93 | _img = Image.open(self.images[index]).convert('RGB') 94 | sample['gt_semseg'] = Image.open(self.semsegs[index]) 95 | 96 | sample['image'] = _img 97 | 98 | sample['image_semseg'] = Image.fromarray(self.prepare_pm(np.array(sample['gt_semseg'])).astype(np.uint8)) 99 | sample['text'] = None 100 | 101 | sample['meta'] = { 102 | 'im_size': (_img.size[1], _img.size[0]), 103 | 'image_file': self.images[index], 104 | "image_id": os.path.basename(self.images[index]).split(".")[0] 105 | } 106 | 107 | if self.transform is not None: 108 | sample = self.transform(sample) 109 | 110 | sample['mask'] = torch.ones(sample['image'].shape[1:]).long() 111 | 112 | sample = self.post_norm(sample) 113 | 114 | return sample -------------------------------------------------------------------------------- /module/data/cocostuff.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import torch.utils.data as data 5 | import torchvision.transforms as T 6 | from PIL import Image 7 | from typing import Optional, Any, Tuple 8 | from .builder import build_palette 9 | 10 | 11 | class 
Normalize(object): 12 | 13 | def __init__(self, mean=0.5, std=0.5): 14 | self.mean = mean 15 | self.std = std 16 | 17 | def __call__(self, sample): 18 | 19 | for k in sample.keys(): 20 | if k in ['image', 'image_panseg','image_semseg']: 21 | sample[k] = (sample[k]-self.mean)/self.std 22 | return sample 23 | 24 | 25 | class COCOStuff(data.Dataset): 26 | 27 | COCO_CATEGORY_NAMES = [ 28 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 29 | 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 30 | 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 31 | 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 32 | 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 33 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 34 | 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 35 | 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 36 | 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 37 | 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 38 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 39 | 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 40 | 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 41 | 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 42 | 'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet', 43 | 'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile', 44 | 'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain', 45 | 'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble', 46 | 'floor-other', 'floor-stone', 'floor-tile', 'floor-wood', 'flower', 47 | 'fog', 'food-other', 'fruit', 'furniture-other', 'grass', 'gravel', 48 | 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat', 'metal', 49 | 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net', 50 | 'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 51 | 'platform', 'playingfield', 'railing', 'railroad', 'river', 'road', 52 | 'rock', 'roof', 'rug', 'salad', 'sand', 'sea', 'shelf', 53 | 'sky-other', 'skyscraper', 'snow', 'solid-other', 'stairs', 54 | 'stone', 'straw', 'structural-other', 'table', 'tent', 55 | 'textile-other', 'towel', 'tree', 'vegetable', 'wall-brick', 56 | 'wall-concrete', 'wall-other', 'wall-panel', 'wall-stone', 57 | 'wall-tile', 'wall-wood', 'water-other', 'waterdrops', 58 | 'window-blind', 'window-other', 'wood'] 59 | 60 | class_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18, 19, 61 | 20, 21, 22, 23, 24, 26, 27, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 62 | 40, 41, 42, 43, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 63 | 60, 61, 62, 63, 64, 66, 69, 71, 72, 73, 74, 75, 76, 77, 78, 79, 64 | 80, 81, 83, 84, 85, 86, 87, 88, 89, 91, 92, 93, 94, 95, 96, 97, 98, 99, 65 | 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 66 | 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 67 | 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 68 | 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 69 | 180, 181] 70 | 71 | mappings = {0 : 0 , 72 | 1 : 1 , 73 | 2 : 2 , 74 | 3 : 3 , 75 | 4 : 4 , 76 | 5 : 5 , 77 | 6 : 6 , 78 | 7 : 7 , 79 | 8 : 8 , 80 | 9 : 9 , 81 | 10 : 10 , 82 | 12 : 11 , 83 | 13 : 12 , 84 | 14 : 13 , 85 | 15 : 14 , 86 | 16 : 15 , 87 | 17 : 16 , 88 | 18 : 17 , 89 | 19 : 18 , 90 | 
20 : 19 , 91 | 21 : 20 , 92 | 22 : 21 , 93 | 23 : 22 , 94 | 24 : 23 , 95 | 26 : 24 , 96 | 27 : 25 , 97 | 30 : 26 , 98 | 31 : 27 , 99 | 32 : 28 , 100 | 33 : 29 , 101 | 34 : 30 , 102 | 35 : 31 , 103 | 36 : 32 , 104 | 37 : 33 , 105 | 38 : 34 , 106 | 39 : 35 , 107 | 40 : 36 , 108 | 41 : 37 , 109 | 42 : 38 , 110 | 43 : 39 , 111 | 45 : 40 , 112 | 46 : 41 , 113 | 47 : 42 , 114 | 48 : 43 , 115 | 49 : 44 , 116 | 50 : 45 , 117 | 51 : 46 , 118 | 52 : 47 , 119 | 53 : 48 , 120 | 54 : 49 , 121 | 55 : 50 , 122 | 56 : 51 , 123 | 57 : 52 , 124 | 58 : 53 , 125 | 59 : 54 , 126 | 60 : 55 , 127 | 61 : 56 , 128 | 62 : 57 , 129 | 63 : 58 , 130 | 64 : 59 , 131 | 66 : 60 , 132 | 69 : 61 , 133 | 71 : 62 , 134 | 72 : 63 , 135 | 73 : 64 , 136 | 74 : 65 , 137 | 75 : 66 , 138 | 76 : 67 , 139 | 77 : 68 , 140 | 78 : 69 , 141 | 79 : 70 , 142 | 80 : 71 , 143 | 81 : 72 , 144 | 83 : 73 , 145 | 84 : 74 , 146 | 85 : 75 , 147 | 86 : 76 , 148 | 87 : 77 , 149 | 88 : 78 , 150 | 89 : 79 , 151 | 91 : 80 , 152 | 92 : 81 , 153 | 93 : 82 , 154 | 94 : 83 , 155 | 95 : 84 , 156 | 96 : 85 , 157 | 97 : 86 , 158 | 98 : 87 , 159 | 99 : 88 , 160 | 100 : 89 , 161 | 101 : 90 , 162 | 102 : 91 , 163 | 103 : 92 , 164 | 104 : 93 , 165 | 105 : 94 , 166 | 106 : 95 , 167 | 107 : 96 , 168 | 108 : 97 , 169 | 109 : 98 , 170 | 110 : 99 , 171 | 111 : 100 , 172 | 112 : 101 , 173 | 113 : 102 , 174 | 114 : 103 , 175 | 115 : 104 , 176 | 116 : 105 , 177 | 117 : 106 , 178 | 118 : 107 , 179 | 119 : 108 , 180 | 120 : 109 , 181 | 121 : 110 , 182 | 122 : 111 , 183 | 123 : 112 , 184 | 124 : 113 , 185 | 125 : 114 , 186 | 126 : 115 , 187 | 127 : 116 , 188 | 128 : 117 , 189 | 129 : 118 , 190 | 130 : 119 , 191 | 131 : 120 , 192 | 132 : 121 , 193 | 133 : 122 , 194 | 134 : 123 , 195 | 135 : 124 , 196 | 136 : 125 , 197 | 137 : 126 , 198 | 138 : 127 , 199 | 139 : 128 , 200 | 140 : 129 , 201 | 141 : 130 , 202 | 142 : 131 , 203 | 143 : 132 , 204 | 144 : 133 , 205 | 145 : 134 , 206 | 146 : 135 , 207 | 147 : 136 , 208 | 148 : 137 , 209 | 149 : 138 , 210 | 150 : 139 , 211 | 151 : 140 , 212 | 152 : 141 , 213 | 153 : 142 , 214 | 154 : 143 , 215 | 155 : 144 , 216 | 156 : 145 , 217 | 157 : 146 , 218 | 158 : 147 , 219 | 159 : 148 , 220 | 160 : 149 , 221 | 161 : 150 , 222 | 162 : 151 , 223 | 163 : 152 , 224 | 164 : 153 , 225 | 165 : 154 , 226 | 166 : 155 , 227 | 167 : 156 , 228 | 168 : 157 , 229 | 169 : 158 , 230 | 170 : 159 , 231 | 171 : 160 , 232 | 172 : 161 , 233 | 173 : 162 , 234 | 174 : 163 , 235 | 175 : 164 , 236 | 176 : 165 , 237 | 177 : 166 , 238 | 178 : 167 , 239 | 179 : 168 , 240 | 180 : 169 , 241 | 181 : 170 , 242 | 255 : 255} 243 | 244 | 245 | def __init__( 246 | self, 247 | data_root: str, 248 | split: str = 'val', 249 | transform: Optional[Any] = None, 250 | args_palette = None, 251 | ): 252 | print('init cocostuff, semseg') 253 | self.meta_data = {'category_names':self.COCO_CATEGORY_NAMES} 254 | self.transform = transform 255 | self.post_norm = Normalize() 256 | self.palette = build_palette(args_palette[0],args_palette[1]) 257 | 258 | if split == 'train': 259 | _datalist = osp.join(data_root,'metas','train.txt') 260 | elif split == 'val': 261 | _datalist = osp.join(data_root,'metas','val.txt') 262 | _img_dir = osp.join(data_root,f'{split}2017') 263 | _seg_dir = osp.join(data_root,f'annotations/{split}2017') 264 | self.images = [] 265 | self.semsegs = [] 266 | with open(_datalist,'r') as f: 267 | for line in f: 268 | self.images.append(osp.join(_img_dir,line.strip()+'.jpg')) 269 | self.semsegs.append(osp.join(_seg_dir,line.strip()+'.png')) 270 | print(f'split={split}, 
processing {len(self.images)} images') 271 | 272 | def prepare_pm(self,x): 273 | assert len(x.shape)==2 274 | h,w = x.shape 275 | pm = np.ones((h,w,3))*255 276 | clslist = np.unique(x).tolist() 277 | if 255 in clslist: 278 | clslist.remove(255) 279 | for _c in clslist: 280 | _x,_y = np.where(x==_c) 281 | pm[_x,_y,:] = self.palette[int(_c)*3:(int(_c)+1)*3] 282 | return pm 283 | 284 | def remap_id(self,x): 285 | assert len(x.shape)==2 286 | clslist = np.unique(x) 287 | newx = np.zeros_like(x) 288 | for _c in clslist: 289 | _x,_y = np.where(x==_c) 290 | newx[_x,_y] = self.mappings[_c] 291 | return newx.astype(np.uint8) 292 | 293 | 294 | def __getitem__(self, index): 295 | sample = {} 296 | _img = Image.open(self.images[index]).convert('RGB') 297 | sample['image'] = _img 298 | 299 | # make sure np.uint8 300 | gt_semseg = Image.open(self.semsegs[index]) 301 | gt_semseg = self.remap_id(np.array(gt_semseg)) 302 | sample['gt_semseg'] = Image.fromarray(gt_semseg) 303 | sample['image_semseg'] = Image.fromarray(self.prepare_pm(gt_semseg).astype(np.uint8)) 304 | 305 | sample['text'] = "" 306 | 307 | # mask with ones for valid pixels 308 | sample['mask'] = np.ones(_img.size[::-1]) 309 | sample['mask'] = Image.fromarray(sample['mask']) 310 | 311 | # meta data 312 | sample['meta'] = { 313 | 'im_size': (_img.size[1], _img.size[0]), # h,w 314 | 'image_file': self.images[index], 315 | "image_id": int(os.path.basename(self.images[index]).split(".")[0]), 316 | } 317 | 318 | if self.transform is not None: 319 | sample = self.transform(sample) 320 | 321 | sample = self.post_norm(sample) 322 | 323 | return sample 324 | 325 | 326 | def __len__(self): 327 | return len(self.images) 328 | 329 | 330 | if __name__ == '__main__': 331 | pass -------------------------------------------------------------------------------- /module/data/hook.py: -------------------------------------------------------------------------------- 1 | import os 2 | from accelerate import Accelerator,DistributedType 3 | import shutil 4 | import torch 5 | import warnings 6 | 7 | def resume_state(accelerator:Accelerator,args,num_update_steps_per_epoch,model): 8 | first_epoch = 0 9 | resume_step = 0 10 | global_step = 0 11 | if args.resume_from_checkpoint: 12 | if args.resume_from_checkpoint != "latest": 13 | path = os.path.basename(args.resume_from_checkpoint) 14 | else: 15 | dirs = os.listdir(args.env.output_dir) 16 | dirs = [d for d in dirs if d.startswith("checkpoint")] 17 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 18 | path = dirs[-1] if len(dirs) > 0 else None 19 | 20 | if path is None: 21 | accelerator.print( 22 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
23 | ) 24 | args.resume_from_checkpoint = None 25 | else: 26 | accelerator.print(f"Resuming from checkpoint {path}") 27 | global_step = int(path.split("-")[1]) 28 | resume_global_step = global_step * args.env.gradient_accumulation_steps 29 | first_epoch = global_step // num_update_steps_per_epoch 30 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.env.gradient_accumulation_steps) 31 | accelerator.load_state(os.path.join(args.env.output_dir, path)) 32 | return first_epoch, resume_step, global_step 33 | 34 | 35 | def save_normal(accelerator:Accelerator,args,logger,global_step,model): 36 | accelerator.wait_for_everyone() 37 | args.env.checkpoints_total_limit = 1 38 | if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: 39 | save_path = os.path.join(args.env.output_dir, f"checkpoint-{global_step}") 40 | accelerator.save_state(save_path) 41 | logger.info(f"Saved state to {save_path}") 42 | 43 | accelerator.wait_for_everyone() 44 | if accelerator.is_main_process: 45 | if args.env.checkpoints_total_limit is not None: 46 | checkpoints = os.listdir(args.env.output_dir) 47 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 48 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 49 | 50 | if len(checkpoints) > args.env.checkpoints_total_limit: 51 | num_to_remove = len(checkpoints) - args.env.checkpoints_total_limit 52 | removing_checkpoints = checkpoints[0:num_to_remove] 53 | 54 | logger.info( 55 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 56 | ) 57 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 58 | 59 | for removing_checkpoint in removing_checkpoints: 60 | removing_checkpoint = os.path.join(args.env.output_dir, removing_checkpoint) 61 | shutil.rmtree(removing_checkpoint) 62 | 63 | 64 | -------------------------------------------------------------------------------- /module/data/load_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .utils import get_train_transforms, get_val_transforms, get_dataset 3 | from torch.utils.data import DataLoader 4 | 5 | def collate_fn(batch: dict): 6 | # TODO: make general 7 | semseg = image_panseg = None 8 | images = torch.stack([d['image'] for d in batch]) 9 | if 'panseg' in batch[0]: 10 | semseg = torch.stack([d['panseg'] for d in batch]) 11 | if 'image_panseg' in batch[0]: 12 | image_panseg = torch.stack([d['image_panseg'] for d in batch]) 13 | image_semseg = torch.stack([d['image_semseg'] for d in batch]) 14 | gt_semseg = [d['gt_semseg'] for d in batch] 15 | tokens = mask = inpainting_mask = text = meta = None 16 | if 'tokens' in batch[0]: 17 | tokens = torch.stack([d['tokens'] for d in batch]) 18 | if 'mask' in batch[0]: 19 | mask = torch.stack([d['mask'] for d in batch]) 20 | if 'inpainting_mask' in batch[0]: 21 | inpainting_mask = torch.stack([d['inpainting_mask'] for d in batch]) 22 | if 'text' in batch[0]: 23 | text = [d['text'] for d in batch] 24 | if 'meta' in batch[0]: 25 | meta = [d['meta'] for d in batch] 26 | return { 27 | 'image': images, 28 | 'panseg': semseg, 29 | 'meta': meta, 30 | 'text': text, 31 | 'tokens': tokens, 32 | 'mask': mask, 33 | 'inpainting_mask': inpainting_mask, 34 | 'image_panseg': image_panseg, 35 | 'image_semseg': image_semseg, 36 | 'gt_semseg': gt_semseg 37 | } 38 | 39 | def pr_train_dataloader(p): 40 | transforms = get_train_transforms(p.transformation) 41 | train_dataset = 
get_dataset( 42 | split='train', 43 | db_name=p.db, 44 | transform=transforms, 45 | cfg_palette=p.pa 46 | ) 47 | 48 | train_dataloader = DataLoader( 49 | train_dataset, 50 | batch_size=p.train.batch_size, 51 | num_workers=p.train.num_workers, 52 | shuffle=True, # 53 | pin_memory=True, 54 | drop_last=True, 55 | collate_fn=collate_fn, 56 | ) 57 | 58 | return train_dataloader 59 | 60 | def pr_val_dataloader(p): 61 | transforms_val = get_val_transforms(p.transformation) 62 | val_dataset = get_dataset( 63 | split='val', 64 | db_name=p.db, 65 | transform=transforms_val, 66 | cfg_palette=p.pa 67 | ) 68 | 69 | val_dataloader = DataLoader( 70 | val_dataset, 71 | batch_size=p.eval.batch_size, 72 | num_workers=p.eval.num_workers, 73 | shuffle=False, 74 | pin_memory=True, 75 | drop_last=False, 76 | collate_fn=collate_fn, 77 | ) 78 | 79 | return val_dataloader 80 | 81 | -------------------------------------------------------------------------------- /module/data/prepare_text.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import CLIPTextModel, CLIPTokenizer 3 | 4 | def sd_null_condition(path): 5 | text = "" 6 | text_encoder = CLIPTextModel.from_pretrained(path, subfolder="text_encoder", revision=None) 7 | tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer", revision=None) 8 | with torch.no_grad(): 9 | empty_inputs = tokenizer(text, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt") 10 | emptyembed = text_encoder(empty_inputs.input_ids)[0] 11 | del text_encoder, tokenizer 12 | return emptyembed 13 | 14 | 15 | if __name__ == '__main__': 16 | pass 17 | -------------------------------------------------------------------------------- /module/data/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms.functional as F 4 | import numpy as np 5 | import random 6 | from PIL import Image 7 | 8 | INT_MODES = { 9 | 'image': 'bicubic', 10 | 'panseg': 'nearest', 11 | 'class_labels': 'nearest', 12 | 'mask': 'nearest', 13 | 'image_panseg': 'bilinear', 14 | 'image_class_labels': 'bilinear', 15 | 'image_semseg': 'bilinear', 16 | 'gt_semseg': 'nearest' 17 | } 18 | 19 | class RandomHorizontalFlip(object): 20 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5.""" 21 | 22 | def __call__(self, sample): 23 | 24 | if random.random() < 0.5: 25 | for elem in sample.keys(): 26 | if elem in ['meta', 'text']: 27 | continue 28 | else: 29 | sample[elem] = F.hflip(sample[elem]) 30 | 31 | return sample 32 | 33 | def __str__(self): 34 | return 'RandomHorizontalFlip(p=0.5)' 35 | 36 | class CropResize(object): 37 | def __init__(self, size, crop_mode=None): 38 | self.size = size 39 | self.crop_mode = crop_mode 40 | assert self.crop_mode in ['centre', 'random', None] 41 | 42 | def crop_and_resize(self, img, h, w, mode='bicubic', crop_size=None): 43 | # crop 44 | if self.crop_mode == 'centre': 45 | img_w, img_h = img.size 46 | min_size = min(img_h, img_w) 47 | if min_size == img_h: 48 | margin = (img_w - min_size) // 2 49 | new_img = img.crop((margin, 0, margin+min_size, min_size)) 50 | else: 51 | margin = (img_h - min_size) // 2 52 | new_img = img.crop((0, margin, min_size, margin+min_size)) 53 | elif self.crop_mode == 'random': 54 | new_img = img.crop(crop_size) 55 | else: 56 | new_img = img 57 | 58 | # accelerate 59 | if new_img.size==(w,h): 60 | 
return new_img 61 | # resize 62 | if mode == 'bicubic': 63 | new_img = new_img.resize((w, h), resample=getattr(Image, 'Resampling', Image).BICUBIC, reducing_gap=None) 64 | elif mode == 'bilinear': 65 | new_img = new_img.resize((w, h), resample=getattr(Image, 'Resampling', Image).BILINEAR, reducing_gap=None) 66 | elif mode == 'nearest': 67 | new_img = new_img.resize((w, h), resample=getattr(Image, 'Resampling', Image).NEAREST, reducing_gap=None) 68 | else: 69 | raise NotImplementedError 70 | return new_img 71 | 72 | def rand_decide(self,img): 73 | """ 74 | decide crop size in random mode, the crop size in a sample should be the same 75 | """ 76 | img_w, img_h = img.size 77 | min_size = min(img_h, img_w) 78 | if min_size == img_h: 79 | margin = random.randint(0,img_w-min_size) 80 | return (margin, 0, margin+min_size, min_size) 81 | else: 82 | margin = random.randint(0,img_h-min_size) 83 | return (0, margin, min_size, margin+min_size) 84 | 85 | 86 | def __call__(self, sample): 87 | if self.crop_mode == 'random': 88 | crop_size = self.rand_decide(sample['image']) 89 | else: 90 | crop_size = None 91 | for elem in sample.keys(): 92 | if elem in ['image', 'image_panseg', 'panseg', 'mask', 'class_labels', 'image_class_labels', 'image_semseg']: 93 | sample[elem] = self.crop_and_resize(sample[elem], self.size[0], self.size[1], mode=INT_MODES[elem], crop_size=crop_size) 94 | return sample 95 | 96 | def __str__(self) -> str: 97 | return f"CropResize(size={self.size}, crop_mode={self.crop_mode})" 98 | 99 | 100 | class ToTensor(object): 101 | """Convert ndarrays in sample to Tensors.""" 102 | def __init__(self): 103 | self.to_tensor = torchvision.transforms.ToTensor() 104 | 105 | def __call__(self, sample): 106 | 107 | for elem in sample.keys(): 108 | if 'meta' in elem or 'text' in elem: 109 | continue 110 | 111 | elif elem in ['image', 'image_panseg', 'image_class_labels', 'image_semseg']: 112 | sample[elem] = self.to_tensor(sample[elem]) # Regular ToTensor operation 113 | 114 | elif elem in ['panseg', 'mask', 'class_labels', 'gt_semseg']: 115 | sample[elem] = torch.from_numpy(np.array(sample[elem])).long() # Torch Long 116 | 117 | else: 118 | raise NotImplementedError 119 | 120 | return sample 121 | 122 | def __str__(self): 123 | return 'ToTensor' 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /module/data/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torchvision import transforms as T 5 | from typing import Callable, Dict, Tuple, Any, Optional,List 6 | from .transform import RandomHorizontalFlip, CropResize, ToTensor 7 | 8 | 9 | def get_train_transforms(p: Dict[str, Any]) -> Callable: 10 | size = p.size 11 | crop_mode = p.crop 12 | if size<2000: 13 | real_size = (size,size) 14 | else: 15 | size = size//10 16 | real_size = (size,2*size) 17 | crop_mode = None 18 | print('size:',real_size) 19 | transforms = T.Compose([ 20 | RandomHorizontalFlip() if p.flip else nn.Identity(), 21 | CropResize(real_size, crop_mode=crop_mode), 22 | ToTensor(), 23 | ]) 24 | return transforms 25 | 26 | def get_val_transforms(p: Dict) -> Callable: 27 | size = p.size 28 | if size<2000: 29 | real_size = (size,size) 30 | else: 31 | size = size//10 32 | real_size = (size,2*size) 33 | print('size:',real_size) 34 | transforms = T.Compose([ 35 | CropResize(real_size, crop_mode=None), 36 | ToTensor(), 37 | ]) 38 | return transforms 39 | 40 | def 
get_dataset( 41 | split: Any, 42 | db_name = 'coco', 43 | transform: Optional[Callable] = None, 44 | cfg_palette = None, 45 | ): 46 | 47 | args_palette = (cfg_palette.k,cfg_palette.s) 48 | 49 | if db_name=='celeb': 50 | from .celeb import Celeb 51 | dataset = Celeb( 52 | data_root='dataset/celebAmask', 53 | split=split, 54 | transform=transform, 55 | args_palette=args_palette 56 | ) 57 | 58 | elif db_name=='city': 59 | from .cityscapes import Cityscapes 60 | dataset = Cityscapes( 61 | data_root='dataset/cityscapes', 62 | split=split, 63 | transform=transform, 64 | size=1024, 65 | args_palette=args_palette 66 | ) 67 | 68 | elif db_name=='cocostuff': 69 | from .cocostuff import COCOStuff 70 | dataset = COCOStuff( 71 | data_root='dataset/cocostuff', 72 | split=split, 73 | transform=transform, 74 | args_palette=args_palette 75 | ) 76 | 77 | else: 78 | raise NotImplementedError() 79 | 80 | return dataset 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /module/pipe/pipe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | @torch.no_grad() 4 | def pipeline_rf(timesteps,unet,z0,encoder_hidden_states,blank_feat,guidance_scale,unet_added_conditions=None): 5 | cls_free = guidance_scale>1.0 6 | bsz = z0.shape[0] 7 | if cls_free: 8 | blank_feat = blank_feat.repeat(bsz,1,1) 9 | encoder_hidden_states = torch.cat([encoder_hidden_states, blank_feat], dim=0) 10 | ttlsteps = len(timesteps) 11 | timesteps = timesteps.reshape(ttlsteps,-1).flip([0,1]).squeeze(1)-1 12 | dt = 1.0 / ttlsteps 13 | latents = z0 14 | all_latents = [] 15 | for i, t in enumerate(timesteps): 16 | latent_model_input = torch.cat([latents] * 2) if cls_free else latents 17 | v_pred = unet(latent_model_input, t, encoder_hidden_states, added_cond_kwargs=unet_added_conditions, return_dict=False)[0] 18 | if cls_free: 19 | v_pred_text, v_pred_null = v_pred.chunk(2) 20 | v_pred = v_pred_null + guidance_scale * (v_pred_text - v_pred_null) 21 | latents = latents + dt * v_pred 22 | all_latents.append(latents) 23 | return latents, all_latents 24 | 25 | 26 | @torch.no_grad() 27 | def pipeline_rf_reverse(timesteps,unet,z1,encoder_hidden_states,blank_feat,guidance_scale,unet_added_conditions=None): 28 | cls_free = guidance_scale>1.0 29 | bsz = z1.shape[0] 30 | if cls_free: 31 | blank_feat = blank_feat.repeat(bsz,1,1) 32 | encoder_hidden_states = torch.cat([encoder_hidden_states, blank_feat], dim=0) 33 | ttlsteps = len(timesteps) 34 | timesteps = 1000 - timesteps.max()+timesteps 35 | dt = 1.0 / ttlsteps 36 | latents = z1 37 | all_latents = [] 38 | for i, t in enumerate(timesteps): 39 | latent_model_input = torch.cat([latents] * 2) if cls_free else latents 40 | v_pred = unet(latent_model_input, t, encoder_hidden_states, added_cond_kwargs=unet_added_conditions, return_dict=False)[0] 41 | if cls_free: 42 | v_pred_text, v_pred_null = v_pred.chunk(2) 43 | v_pred = v_pred_null + guidance_scale * (v_pred_text - v_pred_null) 44 | latents = latents - dt * v_pred 45 | all_latents.append(latents) 46 | return latents, all_latents -------------------------------------------------------------------------------- /module/pipe/val.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from accelerate import Accelerator 5 | from typing import Optional 6 | from tqdm import tqdm 7 | from .pipe import pipeline_rf_reverse 8 | import numpy as np 9 | import 
cv2 10 | import numpy as np 11 | from ..data.builder import build_palette 12 | from einops import rearrange 13 | from einops import rearrange 14 | import os 15 | import os.path as osp 16 | 17 | 18 | def get_unet_added_conditions(args,null_condition): 19 | prompt_embeds = null_condition 20 | unet_added_conditions = None 21 | return prompt_embeds, unet_added_conditions 22 | 23 | 24 | def l2i(latents,vae,weight_dtype,file_names=None): 25 | ## 26 | latents = latents * (1. / vae.scaling_factor) 27 | masks_logits_tmp = vae.decode(latents.to(weight_dtype)).sample 28 | imgs = [] 29 | names = [] 30 | for i in range(latents.shape[0]): 31 | _tmp = masks_logits_tmp[i]*127.5+127.5 32 | _tmp = torch.clamp(_tmp,0,255) 33 | img = _tmp.detach().float().cpu().permute(1,2,0).numpy().astype(np.uint8) 34 | img = img[:,:,::-1] 35 | imgs.append(img) 36 | if file_names: 37 | name = file_names[i].split('/')[-1] 38 | names.append(name) 39 | return imgs, names 40 | 41 | 42 | @torch.no_grad() 43 | def valrf( 44 | accelerator: Accelerator, 45 | args, 46 | vae, 47 | unet, 48 | dataloader, 49 | device, 50 | weight_dtype, 51 | null_condition, 52 | num_inference_steps: int = None, 53 | max_iter: Optional[int] = None, 54 | gstep=0, 55 | ): 56 | 57 | num_inference_steps = args.valstep 58 | guidance_scale = args.cfg.guide 59 | palette = build_palette(args.pa.k,args.pa.s) 60 | meta_data = dataloader.dataset.meta_data 61 | assert args.mode=='semantic' 62 | map_tbl = meta_data['category_names'] 63 | cls_num = len(map_tbl) 64 | 65 | table = torch.tensor(palette[:cls_num*3]) 66 | table = rearrange(table,'(q c) -> c q',c=3) 67 | table = table.to(device=device,dtype=weight_dtype)[None,:,:,None,None] # 3,cls 68 | table = table / 127.5 -1.0 69 | 70 | prompt_embeds, unet_added_conditions = get_unet_added_conditions(args,null_condition) 71 | timesteps = torch.arange(1,1000,1000//num_inference_steps).to(device=device).long() 72 | timesteps = timesteps.reshape(len(timesteps),-1).flip([0,1]).squeeze(1) 73 | 74 | for _, data in enumerate(dataloader): 75 | file_names = [x["image_file"] for x in data['meta']] 76 | rgb_images = data['image'].to(device=device, dtype=weight_dtype) 77 | 78 | rgb_latents = vae.encode(rgb_images).latent_dist.mode()*vae.scaling_factor 79 | bsz = rgb_latents.shape[0] 80 | encoder_hidden_states = prompt_embeds.repeat(bsz,1,1) 81 | 82 | if unet_added_conditions is not None: 83 | _unet_added_conditions = {"time_ids":unet_added_conditions["time_ids"].repeat(bsz,1), 84 | "text_embeds":unet_added_conditions["text_embeds"].repeat(bsz,1)} 85 | else: 86 | _unet_added_conditions = None 87 | 88 | if accelerator.is_main_process: 89 | image_semseg = data['image_semseg'].to(device=device, dtype=weight_dtype) 90 | noise = torch.rand_like(image_semseg)-0.5 91 | 92 | image_semseg += args.pert.co * noise 93 | image_latents = vae.encode(image_semseg).latent_dist.mode()*vae.scaling_factor 94 | rlatents,_ = pipeline_rf_reverse(timesteps,unet,image_latents,encoder_hidden_states,prompt_embeds,guidance_scale,_unet_added_conditions) 95 | imgs,names = l2i(rlatents,vae,weight_dtype,file_names) 96 | for i in range(len(imgs)): 97 | fold=osp.join(args.env.output_dir,'vis') 98 | cv2.imwrite(f'{fold}/{gstep}_{names[i]}',imgs[i]) 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.0 2 | torchvision 3 | accelerate==0.25.0 4 | diffusers==0.25.0 5 | huggingface-hub==0.25.0 6 | 
deepspeed 7 | numpy~=1.26.0 8 | transformers 9 | imageio 10 | einops 11 | tqdm 12 | omegaconf 13 | opencv-python -------------------------------------------------------------------------------- /scripts/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | import os 4 | import os.path as osp 5 | import numpy as np 6 | import torch 7 | from accelerate.utils import set_seed 8 | from diffusers import AutoencoderKL, UNet2DConditionModel 9 | from einops import rearrange 10 | from PIL import Image 11 | import argparse 12 | import imageio 13 | from module.data.builder import build_palette 14 | from module.data.prepare_text import sd_null_condition 15 | from module.pipe.val import l2i 16 | from module.pipe.pipe import pipeline_rf,pipeline_rf_reverse 17 | 18 | 19 | @torch.no_grad() 20 | def valrf(data,args,vae,unet,device,weight_dtype,null_condition): 21 | num_inference_steps = args.valstep 22 | guidance_scale = 1.0 23 | palette = build_palette(6,50) 24 | cls_num = 19 25 | table = torch.tensor(palette[:cls_num*3]) 26 | table = rearrange(table,'(q c) -> c q',c=3) 27 | table = table.to(device=device,dtype=weight_dtype)[None,:,:,None,None] 28 | table = table / 127.5 -1.0 29 | prompt_embeds = null_condition 30 | unet_added_conditions = None 31 | timesteps = torch.arange(1,1000,1000//num_inference_steps).to(device=device).long() 32 | timesteps = timesteps.reshape(len(timesteps),-1).flip([0,1]).squeeze(1) 33 | 34 | rgb_images = data['image'].to(device=device, dtype=weight_dtype) 35 | rgb_latents = vae.encode(rgb_images).latent_dist.mode()*vae.scaling_factor 36 | bsz = rgb_latents.shape[0] 37 | encoder_hidden_states = prompt_embeds.repeat(bsz,1,1) 38 | 39 | assert unet_added_conditions is None 40 | _unet_added_conditions = None 41 | 42 | if args.valmode=='seg': 43 | _, all_latents = pipeline_rf(timesteps,unet,rgb_latents,encoder_hidden_states,prompt_embeds,guidance_scale,_unet_added_conditions) 44 | imgs,_ = l2i(torch.concat(all_latents,dim=0),vae,weight_dtype) 45 | with imageio.get_writer(osp.join(args.sv_dir,'output_seg.gif'), mode='I', fps=4) as writer: 46 | for im in imgs: 47 | writer.append_data(im[:,:,::-1]) 48 | elif args.valmode=='gen': 49 | image_semseg = data['image_semseg'].to(device=device, dtype=weight_dtype) 50 | noise = torch.rand_like(image_semseg)-0.5 51 | image_semseg += noise*0.1 52 | image_latents = vae.encode(image_semseg).latent_dist.mode()*vae.scaling_factor 53 | _, all_latents = pipeline_rf_reverse(timesteps,unet,image_latents,encoder_hidden_states,prompt_embeds,guidance_scale,_unet_added_conditions) 54 | imgs,_ = l2i(torch.concat(all_latents,dim=0),vae,weight_dtype) 55 | with imageio.get_writer(osp.join(args.sv_dir,'output_gen.gif'), mode='I', fps=4) as writer: 56 | for im in imgs: 57 | writer.append_data(im[:,:,::-1]) 58 | else: 59 | raise ValueError() 60 | 61 | 62 | 63 | 64 | def main(args): 65 | 66 | device = args.device 67 | if args.seed is not None: 68 | set_seed(args.seed) 69 | vae = AutoencoderKL.from_pretrained(args.pretrain_model, subfolder='vae', revision=None) 70 | unet = UNet2DConditionModel.from_pretrained(args.ckpt) 71 | 72 | weight_dtype = torch.float32 73 | if args.mixed_precision == "fp16": 74 | weight_dtype = torch.float16 75 | elif args.mixed_precision == "bf16": 76 | weight_dtype = torch.bfloat16 77 | 78 | vae.requires_grad_(False) 79 | unet.to(device=device, dtype=weight_dtype) 80 | vae.to(device=device, dtype=weight_dtype) 81 | 82 | null_condition = 
sd_null_condition(args.pretrain_model) 83 | null_condition = null_condition.to(device=device, dtype=weight_dtype) 84 | 85 | if args.allow_tf32: 86 | torch.backends.cuda.matmul.allow_tf32 = True 87 | 88 | torch.cuda.empty_cache() 89 | data = prepare_data(args.im_path,args.mask_path) 90 | unet.eval() 91 | valrf(data,args,vae,unet,device,weight_dtype,null_condition) 92 | 93 | def prepare_pm(x): 94 | palette = build_palette(6,50) 95 | assert len(x.shape)==2 96 | h,w = x.shape 97 | pm = np.ones((h,w,3))*255 98 | clslist = np.unique(x).tolist() 99 | if 255 in clslist: 100 | raise ValueError() 101 | for _c in clslist: 102 | _x,_y = np.where(x==_c) 103 | pm[_x,_y,:] = palette[int(_c)*3:(int(_c)+1)*3] 104 | return pm 105 | 106 | def prepare_data(ipth,mpth): 107 | import torchvision.transforms as T 108 | image = Image.open(ipth).convert('RGB').resize((512,512),resample=Image.Resampling.BILINEAR) 109 | seg = Image.open(mpth).resize((512,512),resample=Image.Resampling.NEAREST) 110 | image_semseg = prepare_pm(np.array(seg)).astype(np.uint8) 111 | tf = T.Compose([T.ToTensor(),T.Normalize(0.5,0.5)]) 112 | image = tf(image) 113 | image_semseg = tf(image_semseg) 114 | data = dict(image=image.unsqueeze(0),image_semseg=image_semseg.unsqueeze(0)) 115 | return data 116 | 117 | 118 | if __name__ == "__main__": 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument('--pretrain_model',type=str,default='dataset/pretrain/stable-diffusion-v1-5') 121 | parser.add_argument('--seed',type=int,default=None) 122 | parser.add_argument('--allow_tf32',type=bool,default=True) 123 | parser.add_argument('--mixed_precision',type=str,default=None) 124 | parser.add_argument('--valmode',type=str,default='gen',choices=['seg','gen']) 125 | parser.add_argument('--valstep',type=int,default=25) 126 | parser.add_argument('--device',type=str,default='cuda') 127 | parser.add_argument('--ckpt',type=str,default='demo/unet') 128 | parser.add_argument('--sv_dir',type=str,default='demo/vis') 129 | parser.add_argument('--im_path',type=str,default='file/img.jpg') 130 | parser.add_argument('--mask_path',type=str,default='file/mask.png') 131 | args = parser.parse_args() 132 | os.makedirs(args.sv_dir,exist_ok=True) 133 | main(args) -------------------------------------------------------------------------------- /scripts/semrf.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | 4 | import logging 5 | import math 6 | import os 7 | 8 | from pathlib import Path 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.utils.checkpoint 14 | import transformers 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import ProjectConfiguration, set_seed, DeepSpeedPlugin 18 | 19 | from tqdm import tqdm 20 | from omegaconf import OmegaConf 21 | import diffusers 22 | from diffusers import AutoencoderKL, UNet2DConditionModel 23 | from diffusers.optimization import get_scheduler 24 | from module.data.load_dataset import pr_val_dataloader, pr_train_dataloader 25 | from module.pipe.val import valrf 26 | from module.data.hook import resume_state, save_normal 27 | from module.data.prepare_text import sd_null_condition 28 | 29 | logger = get_logger(__name__) 30 | 31 | 32 | @torch.no_grad() 33 | def pre_rf(data, vae, device, weight_dtype, args): 34 | rgb_images = data['image'].to(dtype=weight_dtype, device=device) 35 | images = 
data['image_semseg'].to(dtype=weight_dtype, device=device) 36 | noise = torch.rand_like(images)-0.5 37 | images = images + args.pert.co*noise 38 | rgb_latents = vae.encode(rgb_images).latent_dist.sample()* vae.scaling_factor 39 | 40 | latents = vae.encode(images).latent_dist.mode()* vae.scaling_factor 41 | 42 | return latents, rgb_latents 43 | 44 | 45 | def main(args): 46 | 47 | ttlsteps = 1000 48 | args.transformation.size = args.env.size 49 | 50 | print('init SD 1.5') 51 | args.pretrain_model = 'dataset/pretrain/stable-diffusion-v1-5' 52 | args.vae_path = 'dataset/pretrain/stable-diffusion-v1-5/vae' 53 | 54 | train_dataloader = pr_train_dataloader(args) 55 | val_dataloader = pr_val_dataloader(args) 56 | 57 | logging_dir = Path(args.env.output_dir, args.env.logging_dir) 58 | accelerator_project_config = ProjectConfiguration(project_dir=args.env.output_dir, logging_dir=logging_dir) 59 | deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=args.env.gradient_accumulation_steps) 60 | accelerator = Accelerator( 61 | gradient_accumulation_steps=args.env.gradient_accumulation_steps, 62 | mixed_precision=args.env.mixed_precision, 63 | log_with=args.env.report_to, 64 | project_config=accelerator_project_config, 65 | deepspeed_plugin=deepspeed_plugin if args.env.deepspeed else None 66 | ) 67 | 68 | logging.basicConfig( 69 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 70 | datefmt="%m/%d/%Y %H:%M:%S", 71 | level=logging.INFO, 72 | ) 73 | logger.info(accelerator.state, main_process_only=False) 74 | if accelerator.is_local_main_process: 75 | transformers.utils.logging.set_verbosity_warning() 76 | diffusers.utils.logging.set_verbosity_info() 77 | else: 78 | transformers.utils.logging.set_verbosity_error() 79 | diffusers.utils.logging.set_verbosity_error() 80 | 81 | if args.env.seed is not None: 82 | set_seed(args.env.seed) 83 | 84 | 85 | if accelerator.is_main_process: 86 | if args.env.output_dir is not None: 87 | os.makedirs(os.path.join(args.env.output_dir,'vis'), exist_ok=True) 88 | OmegaConf.save(args,os.path.join(args.env.output_dir,'config.yaml')) 89 | 90 | vae = AutoencoderKL.from_pretrained(args.vae_path, revision=None) 91 | unet = UNet2DConditionModel.from_pretrained(args.pretrain_model, subfolder="unet", revision=None) 92 | 93 | 94 | weight_dtype = torch.float32 95 | if accelerator.mixed_precision == "fp16": 96 | weight_dtype = torch.float16 97 | elif accelerator.mixed_precision == "bf16": 98 | weight_dtype = torch.bfloat16 99 | 100 | vae.requires_grad_(False) 101 | unet.to(accelerator.device, dtype=weight_dtype) 102 | vae.to(accelerator.device, dtype=weight_dtype) 103 | 104 | null_condition = sd_null_condition(args.pretrain_model) 105 | null_condition = null_condition.to(accelerator.device, dtype=weight_dtype) 106 | 107 | if args.env.allow_tf32: 108 | torch.backends.cuda.matmul.allow_tf32 = True 109 | 110 | if args.env.scale_lr: 111 | args.optim.lr = ( 112 | args.optim.lr * args.env.gradient_accumulation_steps * args.train.batch_size * accelerator.num_processes 113 | ) 114 | 115 | assert args.optim.name=='adamw' 116 | optimizer_class = torch.optim.AdamW 117 | optimizer = optimizer_class( 118 | unet.parameters(), 119 | lr=args.optim.lr, 120 | betas=(args.optim.beta1, args.optim.beta2), 121 | weight_decay=args.optim.weight_decay, 122 | eps=args.optim.epsilon, 123 | ) 124 | 125 | # # Scheduler and math around the number of training steps. 
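# Note: the warmup and total step counts passed to get_scheduler below are
# scaled by lr_ratio. Without DeepSpeed, accelerate steps the prepared scheduler
# in every process at each update, so both counts are multiplied by
# accelerator.num_processes; with DeepSpeed, lr_ratio stays 1, presumably because
# the scheduler is then stepped only once per global optimizer step.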
126 | overrode_max_train_steps = False 127 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.env.gradient_accumulation_steps) 128 | assert args.env.max_train_steps is not None 129 | 130 | lr_ratio = 1 if args.env.deepspeed else accelerator.num_processes 131 | lr_scheduler = get_scheduler( 132 | args.lr_scheduler.name, 133 | optimizer=optimizer, 134 | num_warmup_steps=args.lr_scheduler.warmup_steps * lr_ratio, 135 | num_training_steps=args.env.max_train_steps * lr_ratio, 136 | ) 137 | 138 | unet, optimizer, train_dataloader, lr_scheduler, val_dataloader = accelerator.prepare( 139 | unet, optimizer, train_dataloader, lr_scheduler, val_dataloader) 140 | 141 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.env.gradient_accumulation_steps) 142 | if overrode_max_train_steps: 143 | args.env.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 144 | # Afterwards we recalculate our number of training epochs 145 | args.num_train_epochs = math.ceil(args.env.max_train_steps / num_update_steps_per_epoch) 146 | 147 | ## 148 | if accelerator.is_main_process: 149 | accelerator.init_trackers("model") 150 | 151 | # Train! 152 | total_batch_size = args.train.batch_size * accelerator.num_processes * args.env.gradient_accumulation_steps 153 | 154 | logger.info("***** Running training *****") 155 | logger.info(f" Num Epochs = {args.num_train_epochs}") 156 | logger.info(f" Instantaneous batch size per device = {args.train.batch_size}") 157 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 158 | logger.info(f" Gradient Accumulation steps = {args.env.gradient_accumulation_steps}") 159 | logger.info(f" Total optimization steps = {args.env.max_train_steps}") 160 | global_step = 0 161 | first_epoch = 0 162 | 163 | first_epoch, resume_step, global_step = resume_state(accelerator,args,num_update_steps_per_epoch,unet) 164 | torch.cuda.empty_cache() 165 | progress_bar = tqdm(range(global_step, args.env.max_train_steps), disable=not accelerator.is_local_main_process) 166 | progress_bar.set_description("Steps") 167 | 168 | device = accelerator.device 169 | 170 | for epoch in range(first_epoch, args.num_train_epochs): 171 | train_loss = 0.0 172 | for step, batch in enumerate(train_dataloader): 173 | unet.train() 174 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 175 | if step % args.env.gradient_accumulation_steps == 0: 176 | progress_bar.update(1) 177 | continue 178 | 179 | with accelerator.accumulate(unet): 180 | 181 | latents, z0 = pre_rf(batch,vae,device,weight_dtype,args) 182 | bsz = latents.shape[0] 183 | if args.cfg.continus: 184 | t = torch.rand((bsz,),device=device,dtype=weight_dtype) 185 | timesteps = t*ttlsteps 186 | else: 187 | timesteps = torch.randint(0, ttlsteps,(bsz,), device=device,dtype=torch.long) 188 | t = timesteps.to(weight_dtype)/ttlsteps 189 | 190 | t = t[:,None,None,None] 191 | perturb_latent = t*latents+(1.-t)*z0 192 | 193 | prompt_embeds = null_condition.repeat(bsz,1,1) 194 | model_pred = unet(perturb_latent, timesteps, prompt_embeds).sample 195 | target = latents - z0 196 | 197 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 198 | 199 | avg_loss = accelerator.gather(loss.repeat(args.train.batch_size)).mean() 200 | train_loss += avg_loss.item() / args.env.gradient_accumulation_steps 201 | 202 | accelerator.backward(loss) 203 | if accelerator.sync_gradients: 204 | accelerator.clip_grad_norm_(unet.parameters(), 
args.env.max_grad_norm) 205 | optimizer.step() 206 | lr_scheduler.step() 207 | optimizer.zero_grad() 208 | 209 | if accelerator.sync_gradients: 210 | progress_bar.update(1) 211 | global_step += 1 212 | accelerator.log({"train_loss": train_loss}, step=global_step) 213 | train_loss = 0.0 214 | 215 | if global_step % args.env.checkpointing_steps == 0: 216 | save_normal(accelerator,args,logger,global_step,unet) 217 | 218 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 219 | progress_bar.set_postfix(**logs) 220 | 221 | if args.env.val_iter > 0 and global_step % args.env.val_iter == 0: 222 | unet.eval() 223 | 224 | valrf( 225 | accelerator, 226 | args, 227 | vae, 228 | unet, 229 | val_dataloader, 230 | device, 231 | weight_dtype, 232 | null_condition, 233 | max_iter=None, 234 | gstep=global_step 235 | ) 236 | 237 | if global_step >= args.env.max_train_steps: 238 | accelerator.wait_for_everyone() 239 | break 240 | 241 | accelerator.end_training() 242 | 243 | 244 | if __name__ == "__main__": 245 | cfg_path = sys.argv[1] 246 | assert os.path.isfile(cfg_path) 247 | args = OmegaConf.load(cfg_path) 248 | cli_config = OmegaConf.from_cli(sys.argv[2:]) 249 | args = OmegaConf.merge(args,cli_config) 250 | main(args) 251 | --------------------------------------------------------------------------------
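
The configs use `pa.k: 6` and `pa.s: 50`, so `build_palette` places the class colors on a 6x6x6 RGB grid with anchors spaced 50 apart; `module/pipe/val.py` builds a normalized `table` of these anchors but the excerpt above only writes visualizations. The sketch below shows one way such a table could be used to map a decoded color mask back to class indices by nearest-anchor matching. This is a hedged illustration of the color-to-class binding, assuming nearest-color assignment; it is not code from this repository.

```python
# Hedged sketch: recover class indices from a decoded color-coded mask by
# nearest-palette matching. Assumes the mask was encoded with build_palette(k, s)
# and decoded to the [-1, 1] range, as in module/pipe/val.py. Not repository code.
import torch
from einops import rearrange
from module.data.builder import build_palette

def colors_to_classes(decoded: torch.Tensor, cls_num: int, k: int = 6, s: int = 50) -> torch.Tensor:
    # decoded: (B, 3, H, W) in [-1, 1]; returns (B, H, W) class indices.
    palette = torch.tensor(build_palette(k, s)[:cls_num * 3],
                           dtype=decoded.dtype, device=decoded.device)
    table = rearrange(palette, '(q c) -> c q', c=3) / 127.5 - 1.0  # (3, cls_num), normalized
    # Squared distance from every pixel to every class anchor: (B, cls_num, H, W).
    dist = (decoded[:, :, None] - table[None, :, :, None, None]).pow(2).sum(dim=1)
    return dist.argmin(dim=1)
```

With k=6 and s=50 the anchors are about 0.39 apart in the normalized range, comfortably larger than the training-time perturbation of at most 0.05 (pert.co = 0.1 times uniform noise in [-0.5, 0.5]), which is what makes a nearest-anchor decode of this kind plausible.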