├── LICENSE ├── README.md ├── checkpoints ├── personalized_models │ ├── README.md │ └── unet_disney │ │ └── config.json └── stable-diffusion-v1-5 │ ├── README.md │ ├── feature_extractor │ └── preprocessor_config.json │ ├── model_index.json │ ├── safety_checker │ └── config.json │ ├── scheduler │ └── scheduler_config.json │ ├── text_encoder │ └── config.json │ ├── tokenizer │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json │ ├── unet │ └── config.json │ └── vae │ └── config.json ├── datasets └── README.md ├── examples ├── RealSRSet │ ├── Lincoln.png │ ├── building.png │ ├── butterfly.png │ ├── butterfly2.png │ ├── chip.png │ ├── comic1.png │ ├── comic2.png │ ├── comic3.png │ ├── computer.png │ ├── dog.png │ ├── dped_crop00061.png │ ├── foreman.png │ ├── frog.png │ ├── oldphoto2.png │ ├── oldphoto3.png │ ├── oldphoto6.png │ ├── painting.png │ ├── pattern.png │ ├── ppt3.png │ └── tiger.png ├── Set14 │ ├── baboon.png │ ├── barbara.png │ ├── bridge.png │ ├── coastguard.png │ ├── comic.png │ ├── face.png │ ├── flowers.png │ ├── foreman.png │ ├── lenna.png │ ├── man.png │ ├── monarch.png │ ├── pepper.png │ ├── ppt3.png │ └── zebra.png ├── Set5 │ ├── baby.png │ ├── bird.png │ ├── butterfly.png │ ├── head.png │ └── woman.png └── dog.png ├── gradio_pasd.py ├── pasd ├── __init__.py ├── annotator │ ├── __init__.py │ ├── ckpts │ │ └── README.md │ ├── retinaface │ │ ├── __init__.py │ │ ├── data │ │ │ ├── FDDB │ │ │ │ └── img_list.txt │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── data_augment.py │ │ │ └── wider_face.py │ │ ├── facemodels │ │ │ ├── __init__.py │ │ │ ├── net.py │ │ │ └── retinaface.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── functions │ │ │ │ └── prior_box.py │ │ │ └── modules │ │ │ │ ├── __init__.py │ │ │ │ └── multibox_loss.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── box_utils.py │ │ │ ├── nms │ │ │ ├── __init__.py │ │ │ └── py_cpu_nms.py │ │ │ └── timer.py │ ├── util.py │ └── yolo │ │ └── __init__.py ├── dataloader │ ├── __init__.py │ ├── basicsr │ │ ├── __init__.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-39.pyc │ │ │ │ ├── ffhq_dataset.cpython-39.pyc │ │ │ │ ├── prefetch_dataloader.cpython-39.pyc │ │ │ │ └── transforms.cpython-39.pyc │ │ │ ├── data_sampler.py │ │ │ ├── data_util.py │ │ │ ├── degradations.py │ │ │ ├── prefetch_dataloader.py │ │ │ └── transforms.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── color_util.cpython-39.pyc │ │ │ ├── diffjpeg.cpython-39.pyc │ │ │ ├── dist_util.cpython-39.pyc │ │ │ ├── file_client.cpython-39.pyc │ │ │ ├── img_process_util.cpython-39.pyc │ │ │ ├── img_util.cpython-39.pyc │ │ │ ├── logger.cpython-39.pyc │ │ │ ├── misc.cpython-39.pyc │ │ │ ├── options.cpython-39.pyc │ │ │ └── registry.cpython-39.pyc │ │ │ ├── color_util.py │ │ │ ├── diffjpeg.py │ │ │ ├── dist_util.py │ │ │ ├── download_util.py │ │ │ ├── file_client.py │ │ │ ├── flow_util.py │ │ │ ├── img_process_util.py │ │ │ ├── img_util.py │ │ │ ├── lmdb_util.py │ │ │ ├── logger.py │ │ │ ├── matlab_functions.py │ │ │ ├── misc.py │ │ │ ├── options.py │ │ │ ├── plot_util.py │ │ │ └── registry.py │ ├── localdatasets.py │ ├── params_realesrgan.yml │ ├── realesrgan.py │ └── webdatasets.py ├── models │ ├── __init__.py │ ├── pasd │ │ ├── __init__.py │ │ ├── controlnet.py │ │ ├── unet_2d_blocks.py │ │ └── unet_2d_condition.py │ └── pasd_light │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── controlnet.py │ │ ├── transformer_2d.py │ │ ├── 
unet_2d_blocks.py │ │ └── unet_2d_condition.py ├── myutils │ ├── __init__.py │ ├── convert_lora_safetensor_to_diffusers.py │ ├── devices.py │ ├── img_util.py │ ├── misc.py │ ├── vaehook.py │ └── wavelet_color_fix.py └── pipelines │ ├── __init__.py │ ├── pipeline_pasd.py │ └── pipeline_pasd_sdxl.py ├── requirements-test.txt ├── requirements.txt ├── runs ├── pasd │ └── README.md ├── pasd_light │ └── README.md ├── pasd_light_rrdb │ └── README.md └── pasd_rrdb │ └── README.md ├── samples ├── 000001x2.gif ├── 000001x2.png ├── 000001x2_comp.png ├── 000001x2_out.png ├── 000004x2.gif ├── 000004x2.png ├── 000004x2_comp.png ├── 000004x2_out.png ├── 000020x2.gif ├── 000020x2.png ├── 000020x2_comp.png ├── 000020x2_out.png ├── 000030x2.png ├── 000030x2_comp.png ├── 000030x2_out.png ├── 000067x2.gif ├── 000067x2.png ├── 000067x2_comp.png ├── 000067x2_out.png ├── 000080x2.gif ├── 000080x2.png ├── 000080x2_comp.png ├── 000080x2_out.png ├── 0c74bd2420d532c2.png ├── 0c74bd2420d532c2_out.png ├── 1125e119c19065f3.png ├── 1125e119c19065f3_out.png ├── 1965223411271f2f.png ├── 1965223411271f2f_out.png ├── 27d38eeb2dbbe7c9.gif ├── 27d38eeb2dbbe7c9.png ├── 27d38eeb2dbbe7c9_comp.png ├── 27d38eeb2dbbe7c9_out.png ├── 2e512b688ef48a43.gif ├── 2e512b688ef48a43.png ├── 2e512b688ef48a43_comp.png ├── 2e512b688ef48a43_out.png ├── 2e753d77bca91095.png ├── 2e753d77bca91095_out.png ├── 38fa6c25e210c3a2.png ├── 38fa6c25e210c3a2_out.png ├── 461a96f62b724eab.png ├── 461a96f62b724eab_out.png ├── 629e4da70703193b.gif ├── 629e4da70703193b.png ├── 629e4da70703193b_comp.png ├── 629e4da70703193b_out.png ├── 8f5cb2715536eef0.png ├── 8f5cb2715536eef0_out.png ├── Lincoln.gif ├── Lincoln.png ├── Lincoln_comp.png ├── Lincoln_out.png ├── RealPhoto60_06.png ├── RealPhoto60_09.png ├── RealPhoto60_22.png ├── RealPhoto60_56.png ├── building.gif ├── building.png ├── building_comp.png ├── building_out.png ├── d4f59e89c1011bc4.png ├── d4f59e89c1011bc4_out.png ├── ed2ec7d15fcbe80e.png ├── ed2ec7d15fcbe80e_comp.png ├── ed2ec7d15fcbe80e_out.png ├── ed84982626af1f44.png ├── ed84982626af1f44_comp.png ├── ed84982626af1f44_out.png ├── f125ee5838471073.gif ├── f125ee5838471073.png ├── f125ee5838471073_comp.png ├── f125ee5838471073_out.png ├── fe1ade76596bffdf.png ├── fe1ade76596bffdf_out.png ├── frog.gif ├── frog.png ├── frog_out.png ├── house.gif ├── house.jpg ├── house_out.png └── pasd_arch.png ├── setup.py ├── test_pasd.py ├── test_pasd_sdxl.py ├── train_pasd.py └── train_pasd.sh /checkpoints/personalized_models/README.md: -------------------------------------------------------------------------------- 1 | Please put personalized models here. You can download them from [civitai](https://civitai.com/models), [huggingface](https://huggingface.co/models) or any other communities. 
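A minimal sketch of fetching such a model, assuming `huggingface_hub` is installed; the `repo_id` below is a hypothetical placeholder (substitute whichever personalized model you actually want), while the file name matches the default `dreambooth_lora_path` used in `gradio_pasd.py`:

```python
# Sketch: download a personalized SD 1.5 checkpoint into this folder.
# "some-user/some-model" is a placeholder repo id, not part of this project.
from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="some-user/some-model",                # hypothetical Hugging Face repo
    filename="majicmixRealistic_v6.safetensors",   # default file expected by gradio_pasd.py
    local_dir="checkpoints/personalized_models",
)
```

Models downloaded from civitai can simply be copied into this folder and referenced by path in the same way.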
-------------------------------------------------------------------------------- /checkpoints/personalized_models/unet_disney/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DConditionModel", 3 | "_diffusers_version": "0.6.0", 4 | "act_fn": "silu", 5 | "attention_head_dim": 8, 6 | "block_out_channels": [ 7 | 320, 8 | 640, 9 | 1280, 10 | 1280 11 | ], 12 | "center_input_sample": false, 13 | "cross_attention_dim": 768, 14 | "down_block_types": [ 15 | "CrossAttnDownBlock2D", 16 | "CrossAttnDownBlock2D", 17 | "CrossAttnDownBlock2D", 18 | "DownBlock2D" 19 | ], 20 | "downsample_padding": 1, 21 | "flip_sin_to_cos": true, 22 | "freq_shift": 0, 23 | "in_channels": 4, 24 | "layers_per_block": 2, 25 | "mid_block_scale_factor": 1, 26 | "norm_eps": 1e-05, 27 | "norm_num_groups": 32, 28 | "out_channels": 4, 29 | "sample_size": 64, 30 | "up_block_types": [ 31 | "UpBlock2D", 32 | "CrossAttnUpBlock2D", 33 | "CrossAttnUpBlock2D", 34 | "CrossAttnUpBlock2D" 35 | ] 36 | } 37 | -------------------------------------------------------------------------------- /checkpoints/stable-diffusion-v1-5/README.md: -------------------------------------------------------------------------------- 1 | Please put SD1.5 models here. You can also put the unets from personalized models here. 2 | 3 | [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) -------------------------------------------------------------------------------- /checkpoints/stable-diffusion-v1-5/feature_extractor/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": 224, 3 | "do_center_crop": true, 4 | "do_convert_rgb": true, 5 | "do_normalize": true, 6 | "do_resize": true, 7 | "feature_extractor_type": "CLIPFeatureExtractor", 8 | "image_mean": [ 9 | 0.48145466, 10 | 0.4578275, 11 | 0.40821073 12 | ], 13 | "image_std": [ 14 | 0.26862954, 15 | 0.26130258, 16 | 0.27577711 17 | ], 18 | "resample": 3, 19 | "size": 224 20 | } 21 | -------------------------------------------------------------------------------- /checkpoints/stable-diffusion-v1-5/model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableDiffusionPipeline", 3 | "_diffusers_version": "0.6.0", 4 | "feature_extractor": [ 5 | "transformers", 6 | "CLIPFeatureExtractor" 7 | ], 8 | "safety_checker": [ 9 | "stable_diffusion", 10 | "StableDiffusionSafetyChecker" 11 | ], 12 | "scheduler": [ 13 | "diffusers", 14 | "PNDMScheduler" 15 | ], 16 | "text_encoder": [ 17 | "transformers", 18 | "CLIPTextModel" 19 | ], 20 | "tokenizer": [ 21 | "transformers", 22 | "CLIPTokenizer" 23 | ], 24 | "unet": [ 25 | "diffusers", 26 | "UNet2DConditionModel" 27 | ], 28 | "vae": [ 29 | "diffusers", 30 | "AutoencoderKL" 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /checkpoints/stable-diffusion-v1-5/safety_checker/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_commit_hash": "4bb648a606ef040e7685bde262611766a5fdd67b", 3 | "_name_or_path": "CompVis/stable-diffusion-safety-checker", 4 | "architectures": [ 5 | "StableDiffusionSafetyChecker" 6 | ], 7 | "initializer_factor": 1.0, 8 | "logit_scale_init_value": 2.6592, 9 | "model_type": "clip", 10 | "projection_dim": 768, 11 | "text_config": { 12 | "_name_or_path": "", 13 | "add_cross_attention": false, 14 | "architectures": null, 15 | 
"attention_dropout": 0.0, 16 | "bad_words_ids": null, 17 | "bos_token_id": 0, 18 | "chunk_size_feed_forward": 0, 19 | "cross_attention_hidden_size": null, 20 | "decoder_start_token_id": null, 21 | "diversity_penalty": 0.0, 22 | "do_sample": false, 23 | "dropout": 0.0, 24 | "early_stopping": false, 25 | "encoder_no_repeat_ngram_size": 0, 26 | "eos_token_id": 2, 27 | "exponential_decay_length_penalty": null, 28 | "finetuning_task": null, 29 | "forced_bos_token_id": null, 30 | "forced_eos_token_id": null, 31 | "hidden_act": "quick_gelu", 32 | "hidden_size": 768, 33 | "id2label": { 34 | "0": "LABEL_0", 35 | "1": "LABEL_1" 36 | }, 37 | "initializer_factor": 1.0, 38 | "initializer_range": 0.02, 39 | "intermediate_size": 3072, 40 | "is_decoder": false, 41 | "is_encoder_decoder": false, 42 | "label2id": { 43 | "LABEL_0": 0, 44 | "LABEL_1": 1 45 | }, 46 | "layer_norm_eps": 1e-05, 47 | "length_penalty": 1.0, 48 | "max_length": 20, 49 | "max_position_embeddings": 77, 50 | "min_length": 0, 51 | "model_type": "clip_text_model", 52 | "no_repeat_ngram_size": 0, 53 | "num_attention_heads": 12, 54 | "num_beam_groups": 1, 55 | "num_beams": 1, 56 | "num_hidden_layers": 12, 57 | "num_return_sequences": 1, 58 | "output_attentions": false, 59 | "output_hidden_states": false, 60 | "output_scores": false, 61 | "pad_token_id": 1, 62 | "prefix": null, 63 | "problem_type": null, 64 | "pruned_heads": {}, 65 | "remove_invalid_values": false, 66 | "repetition_penalty": 1.0, 67 | "return_dict": true, 68 | "return_dict_in_generate": false, 69 | "sep_token_id": null, 70 | "task_specific_params": null, 71 | "temperature": 1.0, 72 | "tf_legacy_loss": false, 73 | "tie_encoder_decoder": false, 74 | "tie_word_embeddings": true, 75 | "tokenizer_class": null, 76 | "top_k": 50, 77 | "top_p": 1.0, 78 | "torch_dtype": null, 79 | "torchscript": false, 80 | "transformers_version": "4.22.0.dev0", 81 | "typical_p": 1.0, 82 | "use_bfloat16": false, 83 | "vocab_size": 49408 84 | }, 85 | "text_config_dict": { 86 | "hidden_size": 768, 87 | "intermediate_size": 3072, 88 | "num_attention_heads": 12, 89 | "num_hidden_layers": 12 90 | }, 91 | "torch_dtype": "float32", 92 | "transformers_version": null, 93 | "vision_config": { 94 | "_name_or_path": "", 95 | "add_cross_attention": false, 96 | "architectures": null, 97 | "attention_dropout": 0.0, 98 | "bad_words_ids": null, 99 | "bos_token_id": null, 100 | "chunk_size_feed_forward": 0, 101 | "cross_attention_hidden_size": null, 102 | "decoder_start_token_id": null, 103 | "diversity_penalty": 0.0, 104 | "do_sample": false, 105 | "dropout": 0.0, 106 | "early_stopping": false, 107 | "encoder_no_repeat_ngram_size": 0, 108 | "eos_token_id": null, 109 | "exponential_decay_length_penalty": null, 110 | "finetuning_task": null, 111 | "forced_bos_token_id": null, 112 | "forced_eos_token_id": null, 113 | "hidden_act": "quick_gelu", 114 | "hidden_size": 1024, 115 | "id2label": { 116 | "0": "LABEL_0", 117 | "1": "LABEL_1" 118 | }, 119 | "image_size": 224, 120 | "initializer_factor": 1.0, 121 | "initializer_range": 0.02, 122 | "intermediate_size": 4096, 123 | "is_decoder": false, 124 | "is_encoder_decoder": false, 125 | "label2id": { 126 | "LABEL_0": 0, 127 | "LABEL_1": 1 128 | }, 129 | "layer_norm_eps": 1e-05, 130 | "length_penalty": 1.0, 131 | "max_length": 20, 132 | "min_length": 0, 133 | "model_type": "clip_vision_model", 134 | "no_repeat_ngram_size": 0, 135 | "num_attention_heads": 16, 136 | "num_beam_groups": 1, 137 | "num_beams": 1, 138 | "num_channels": 3, 139 | "num_hidden_layers": 24, 140 | 
"num_return_sequences": 1, 141 | "output_attentions": false, 142 | "output_hidden_states": false, 143 | "output_scores": false, 144 | "pad_token_id": null, 145 | "patch_size": 14, 146 | "prefix": null, 147 | "problem_type": null, 148 | "pruned_heads": {}, 149 | "remove_invalid_values": false, 150 | "repetition_penalty": 1.0, 151 | "return_dict": true, 152 | "return_dict_in_generate": false, 153 | "sep_token_id": null, 154 | "task_specific_params": null, 155 | "temperature": 1.0, 156 | "tf_legacy_loss": false, 157 | "tie_encoder_decoder": false, 158 | "tie_word_embeddings": true, 159 | "tokenizer_class": null, 160 | "top_k": 50, 161 | "top_p": 1.0, 162 | "torch_dtype": null, 163 | "torchscript": false, 164 | "transformers_version": "4.22.0.dev0", 165 | "typical_p": 1.0, 166 | "use_bfloat16": false 167 | }, 168 | "vision_config_dict": { 169 | "hidden_size": 1024, 170 | "intermediate_size": 4096, 171 | "num_attention_heads": 16, 172 | "num_hidden_layers": 24, 173 | "patch_size": 14 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /checkpoints/stable-diffusion-v1-5/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "PNDMScheduler", 3 | "_diffusers_version": "0.6.0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "num_train_timesteps": 1000, 8 | "set_alpha_to_one": false, 9 | "skip_prk_steps": true, 10 | "steps_offset": 1, 11 | "trained_betas": null, 12 | "clip_sample": false 13 | } 14 | -------------------------------------------------------------------------------- /checkpoints/stable-diffusion-v1-5/text_encoder/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "openai/clip-vit-large-patch14", 3 | "architectures": [ 4 | "CLIPTextModel" 5 | ], 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 0, 8 | "dropout": 0.0, 9 | "eos_token_id": 2, 10 | "hidden_act": "quick_gelu", 11 | "hidden_size": 768, 12 | "initializer_factor": 1.0, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "layer_norm_eps": 1e-05, 16 | "max_position_embeddings": 77, 17 | "model_type": "clip_text_model", 18 | "num_attention_heads": 12, 19 | "num_hidden_layers": 12, 20 | "pad_token_id": 1, 21 | "projection_dim": 768, 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.22.0.dev0", 24 | "vocab_size": 49408 25 | } 26 | -------------------------------------------------------------------------------- /checkpoints/stable-diffusion-v1-5/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "<|endoftext|>", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /checkpoints/stable-diffusion-v1-5/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": { 4 | "__type": 
"AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "do_lower_case": true, 12 | "eos_token": { 13 | "__type": "AddedToken", 14 | "content": "<|endoftext|>", 15 | "lstrip": false, 16 | "normalized": true, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "errors": "replace", 21 | "model_max_length": 77, 22 | "name_or_path": "openai/clip-vit-large-patch14", 23 | "pad_token": "<|endoftext|>", 24 | "special_tokens_map_file": "./special_tokens_map.json", 25 | "tokenizer_class": "CLIPTokenizer", 26 | "unk_token": { 27 | "__type": "AddedToken", 28 | "content": "<|endoftext|>", 29 | "lstrip": false, 30 | "normalized": true, 31 | "rstrip": false, 32 | "single_word": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /checkpoints/stable-diffusion-v1-5/unet/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "UNet2DConditionModel", 3 | "_diffusers_version": "0.6.0", 4 | "act_fn": "silu", 5 | "attention_head_dim": 8, 6 | "block_out_channels": [ 7 | 320, 8 | 640, 9 | 1280, 10 | 1280 11 | ], 12 | "center_input_sample": false, 13 | "cross_attention_dim": 768, 14 | "down_block_types": [ 15 | "CrossAttnDownBlock2D", 16 | "CrossAttnDownBlock2D", 17 | "CrossAttnDownBlock2D", 18 | "DownBlock2D" 19 | ], 20 | "downsample_padding": 1, 21 | "flip_sin_to_cos": true, 22 | "freq_shift": 0, 23 | "in_channels": 4, 24 | "layers_per_block": 2, 25 | "mid_block_scale_factor": 1, 26 | "norm_eps": 1e-05, 27 | "norm_num_groups": 32, 28 | "out_channels": 4, 29 | "sample_size": 64, 30 | "up_block_types": [ 31 | "UpBlock2D", 32 | "CrossAttnUpBlock2D", 33 | "CrossAttnUpBlock2D", 34 | "CrossAttnUpBlock2D" 35 | ] 36 | } 37 | -------------------------------------------------------------------------------- /checkpoints/stable-diffusion-v1-5/vae/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.6.0", 4 | "act_fn": "silu", 5 | "block_out_channels": [ 6 | 128, 7 | 256, 8 | 512, 9 | 512 10 | ], 11 | "down_block_types": [ 12 | "DownEncoderBlock2D", 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D" 16 | ], 17 | "in_channels": 3, 18 | "latent_channels": 4, 19 | "layers_per_block": 2, 20 | "norm_num_groups": 32, 21 | "out_channels": 3, 22 | "sample_size": 512, 23 | "up_block_types": [ 24 | "UpDecoderBlock2D", 25 | "UpDecoderBlock2D", 26 | "UpDecoderBlock2D", 27 | "UpDecoderBlock2D" 28 | ] 29 | } 30 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | Please put loacal training datasets here. 
-------------------------------------------------------------------------------- /examples/RealSRSet/Lincoln.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/Lincoln.png -------------------------------------------------------------------------------- /examples/RealSRSet/building.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/building.png -------------------------------------------------------------------------------- /examples/RealSRSet/butterfly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/butterfly.png -------------------------------------------------------------------------------- /examples/RealSRSet/butterfly2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/butterfly2.png -------------------------------------------------------------------------------- /examples/RealSRSet/chip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/chip.png -------------------------------------------------------------------------------- /examples/RealSRSet/comic1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/comic1.png -------------------------------------------------------------------------------- /examples/RealSRSet/comic2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/comic2.png -------------------------------------------------------------------------------- /examples/RealSRSet/comic3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/comic3.png -------------------------------------------------------------------------------- /examples/RealSRSet/computer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/computer.png -------------------------------------------------------------------------------- /examples/RealSRSet/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/dog.png -------------------------------------------------------------------------------- /examples/RealSRSet/dped_crop00061.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/dped_crop00061.png -------------------------------------------------------------------------------- 
/examples/RealSRSet/foreman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/foreman.png -------------------------------------------------------------------------------- /examples/RealSRSet/frog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/frog.png -------------------------------------------------------------------------------- /examples/RealSRSet/oldphoto2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/oldphoto2.png -------------------------------------------------------------------------------- /examples/RealSRSet/oldphoto3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/oldphoto3.png -------------------------------------------------------------------------------- /examples/RealSRSet/oldphoto6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/oldphoto6.png -------------------------------------------------------------------------------- /examples/RealSRSet/painting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/painting.png -------------------------------------------------------------------------------- /examples/RealSRSet/pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/pattern.png -------------------------------------------------------------------------------- /examples/RealSRSet/ppt3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/ppt3.png -------------------------------------------------------------------------------- /examples/RealSRSet/tiger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/RealSRSet/tiger.png -------------------------------------------------------------------------------- /examples/Set14/baboon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/baboon.png -------------------------------------------------------------------------------- /examples/Set14/barbara.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/barbara.png -------------------------------------------------------------------------------- /examples/Set14/bridge.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/bridge.png -------------------------------------------------------------------------------- /examples/Set14/coastguard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/coastguard.png -------------------------------------------------------------------------------- /examples/Set14/comic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/comic.png -------------------------------------------------------------------------------- /examples/Set14/face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/face.png -------------------------------------------------------------------------------- /examples/Set14/flowers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/flowers.png -------------------------------------------------------------------------------- /examples/Set14/foreman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/foreman.png -------------------------------------------------------------------------------- /examples/Set14/lenna.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/lenna.png -------------------------------------------------------------------------------- /examples/Set14/man.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/man.png -------------------------------------------------------------------------------- /examples/Set14/monarch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/monarch.png -------------------------------------------------------------------------------- /examples/Set14/pepper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/pepper.png -------------------------------------------------------------------------------- /examples/Set14/ppt3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/ppt3.png -------------------------------------------------------------------------------- /examples/Set14/zebra.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set14/zebra.png -------------------------------------------------------------------------------- /examples/Set5/baby.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set5/baby.png -------------------------------------------------------------------------------- /examples/Set5/bird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set5/bird.png -------------------------------------------------------------------------------- /examples/Set5/butterfly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set5/butterfly.png -------------------------------------------------------------------------------- /examples/Set5/head.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set5/head.png -------------------------------------------------------------------------------- /examples/Set5/woman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/Set5/woman.png -------------------------------------------------------------------------------- /examples/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/examples/dog.png -------------------------------------------------------------------------------- /gradio_pasd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import einops 3 | import gradio as gr 4 | import numpy as np 5 | import torch 6 | import random 7 | from PIL import Image 8 | from pathlib import Path 9 | from torchvision import transforms 10 | import torch.nn.functional as F 11 | from torchvision.models import resnet50, ResNet50_Weights 12 | 13 | from pytorch_lightning import seed_everything 14 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor 15 | from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler 16 | 17 | from pasd.pipelines.pipeline_pasd import StableDiffusionControlNetPipeline 18 | from pasd.myutils.misc import load_dreambooth_lora, rand_name 19 | from pasd.myutils.wavelet_color_fix import wavelet_color_fix 20 | from pasd.annotator.retinaface import RetinaFaceDetection 21 | 22 | use_pasd_light = False 23 | face_detector = RetinaFaceDetection() 24 | 25 | if use_pasd_light: 26 | from pasd.models.pasd_light.unet_2d_condition import UNet2DConditionModel 27 | from pasd.models.pasd_light.controlnet import ControlNetModel 28 | else: 29 | from pasd.models.pasd.unet_2d_condition import UNet2DConditionModel 30 | from pasd.models.pasd.controlnet import ControlNetModel 31 | 32 | pretrained_model_path = "checkpoints/stable-diffusion-v1-5" 33 | ckpt_path = "runs/pasd/checkpoint-100000" 34 | #dreambooth_lora_path = "checkpoints/personalized_models/toonyou_beta3.safetensors" 35 | dreambooth_lora_path = "checkpoints/personalized_models/majicmixRealistic_v6.safetensors" 36 | #dreambooth_lora_path = "checkpoints/personalized_models/Realistic_Vision_V5.1.safetensors" 37 | weight_dtype = torch.float16 38 | device = "cuda" 
39 | 40 | scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") 41 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 42 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 43 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 44 | feature_extractor = CLIPImageProcessor.from_pretrained(f"{pretrained_model_path}/feature_extractor") 45 | unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet") 46 | controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet") 47 | vae.requires_grad_(False) 48 | text_encoder.requires_grad_(False) 49 | unet.requires_grad_(False) 50 | controlnet.requires_grad_(False) 51 | 52 | unet, vae, text_encoder = load_dreambooth_lora(unet, vae, text_encoder, dreambooth_lora_path) 53 | 54 | text_encoder.to(device, dtype=weight_dtype) 55 | vae.to(device, dtype=weight_dtype) 56 | unet.to(device, dtype=weight_dtype) 57 | controlnet.to(device, dtype=weight_dtype) 58 | 59 | validation_pipeline = StableDiffusionControlNetPipeline( 60 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor, 61 | unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False, 62 | ) 63 | #validation_pipeline.enable_vae_tiling() 64 | validation_pipeline._init_tiled_vae(decoder_tile_size=224) 65 | 66 | weights = ResNet50_Weights.DEFAULT 67 | preprocess = weights.transforms() 68 | resnet = resnet50(weights=weights) 69 | resnet.eval() 70 | 71 | def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed): 72 | process_size = 768 73 | resize_preproc = transforms.Compose([ 74 | transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR), 75 | ]) 76 | 77 | with torch.no_grad(): 78 | seed_everything(seed) 79 | generator = torch.Generator(device=device) 80 | 81 | input_image = input_image.convert('RGB') 82 | batch = preprocess(input_image).unsqueeze(0) 83 | prediction = resnet(batch).squeeze(0).softmax(0) 84 | class_id = prediction.argmax().item() 85 | score = prediction[class_id].item() 86 | category_name = weights.meta["categories"][class_id] 87 | if score >= 0.1: 88 | prompt += f"{category_name}" if prompt=='' else f", {category_name}" 89 | 90 | prompt = a_prompt if prompt=='' else f"{prompt}, {a_prompt}" 91 | 92 | ori_width, ori_height = input_image.size 93 | resize_flag = False 94 | 95 | rscale = upscale 96 | input_image = input_image.resize((input_image.size[0]*rscale, input_image.size[1]*rscale)) 97 | 98 | if min(input_image.size) < process_size: 99 | input_image = resize_preproc(input_image) 100 | 101 | input_image = input_image.resize((input_image.size[0]//8*8, input_image.size[1]//8*8)) 102 | width, height = input_image.size 103 | resize_flag = True # 104 | 105 | try: 106 | image = validation_pipeline( 107 | None, prompt, input_image, num_inference_steps=denoise_steps, generator=generator, height=height, width=width, guidance_scale=cfg, 108 | negative_prompt=n_prompt, conditioning_scale=alpha, eta=0.0, 109 | ).images[0] 110 | 111 | if True: #alpha<1.0: 112 | image = wavelet_color_fix(image, input_image) 113 | 114 | if resize_flag: 115 | image = image.resize((ori_width*rscale, ori_height*rscale)) 116 | except Exception as e: 117 | print(e) 118 | image = Image.new(mode="RGB", size=(512, 512)) 119 | 120 | return image 121 | 122 | title = "Pixel-Aware Stable Diffusion for 
Real-ISR" 123 | description = "Gradio Demo for PASD Real-ISR. To use it, simply upload your image, or click one of the examples to load them." 124 | article = "
" 125 | examples=[['samples/27d38eeb2dbbe7c9.png'],['samples/629e4da70703193b.png']] 126 | 127 | demo = gr.Interface( 128 | fn=inference, 129 | inputs=[gr.Image(type="pil"), 130 | gr.Textbox(label="Prompt", value="Asian"), 131 | gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece'), 132 | gr.Textbox(label="Negative Prompt",value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'), 133 | gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1), 134 | gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1), 135 | gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1), 136 | gr.Slider(label="Classier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1), 137 | gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)], 138 | outputs=gr.Image(type="pil"), 139 | title=title, 140 | description=description, 141 | article=article, 142 | examples=examples).queue(concurrency_count=1) 143 | 144 | demo.launch( 145 | server_name="0.0.0.0" if os.getenv('GRADIO_LISTEN', '') != '' else "127.0.0.1", 146 | share=True, 147 | root_path=f"/{os.getenv('GRADIO_PROXY_PATH')}" if os.getenv('GRADIO_PROXY_PATH') else "" 148 | ) -------------------------------------------------------------------------------- /pasd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/__init__.py -------------------------------------------------------------------------------- /pasd/annotator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/annotator/__init__.py -------------------------------------------------------------------------------- /pasd/annotator/ckpts/README.md: -------------------------------------------------------------------------------- 1 | Please put checkpoints here. 
2 | 3 | [RetinaFace-R50](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth) | [yolov8n](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/yolov8n.pt) -------------------------------------------------------------------------------- /pasd/annotator/retinaface/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .wider_face import WiderFaceDetection, detection_collate 2 | from .data_augment import * 3 | from .config import * 4 | -------------------------------------------------------------------------------- /pasd/annotator/retinaface/data/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | 3 | cfg_mnet = { 4 | 'name': 'mobilenet0.25', 5 | 'min_sizes': [[16, 32], [64, 128], [256, 512]], 6 | 'steps': [8, 16, 32], 7 | 'variance': [0.1, 0.2], 8 | 'clip': False, 9 | 'loc_weight': 2.0, 10 | 'gpu_train': True, 11 | 'batch_size': 32, 12 | 'ngpu': 1, 13 | 'epoch': 250, 14 | 'decay1': 190, 15 | 'decay2': 220, 16 | 'image_size': 640, 17 | 'pretrain': False, 18 | 'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3}, 19 | 'in_channel': 32, 20 | 'out_channel': 64 21 | } 22 | 23 | cfg_re50 = { 24 | 'name': 'Resnet50', 25 | 'min_sizes': [[16, 32], [64, 128], [256, 512]], 26 | 'steps': [8, 16, 32], 27 | 'variance': [0.1, 0.2], 28 | 'clip': False, 29 | 'loc_weight': 2.0, 30 | 'gpu_train': True, 31 | 'batch_size': 24, 32 | 'ngpu': 4, 33 | 'epoch': 100, 34 | 'decay1': 70, 35 | 'decay2': 90, 36 | 'image_size': 840, 37 | 'pretrain': False, 38 | 'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3}, 39 | 'in_channel': 256, 40 | 'out_channel': 256 41 | } 42 | 43 | -------------------------------------------------------------------------------- /pasd/annotator/retinaface/data/data_augment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | from ..utils.box_utils import matrix_iof 5 | 6 | 7 | def _crop(image, boxes, labels, landm, img_dim): 8 | height, width, _ = image.shape 9 | pad_image_flag = True 10 | 11 | for _ in range(250): 12 | """ 13 | if random.uniform(0, 1) <= 0.2: 14 | scale = 1.0 15 | else: 16 | scale = random.uniform(0.3, 1.0) 17 | """ 18 | PRE_SCALES = [0.3, 0.45, 0.6, 0.8, 1.0] 19 | scale = random.choice(PRE_SCALES) 20 | short_side = min(width, height) 21 | w = int(scale * short_side) 22 | h = w 23 | 24 | if width == w: 25 | l = 0 26 | else: 27 | l = random.randrange(width - w) 28 | if height == h: 29 | t = 0 30 | else: 31 | t = random.randrange(height - h) 32 | roi = np.array((l, t, l + w, t + h)) 33 | 34 | value = matrix_iof(boxes, roi[np.newaxis]) 35 | flag = (value >= 1) 36 | if not flag.any(): 37 | continue 38 | 39 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2 40 | mask_a = np.logical_and(roi[:2] < centers, centers < roi[2:]).all(axis=1) 41 | boxes_t = boxes[mask_a].copy() 42 | labels_t = labels[mask_a].copy() 43 | landms_t = landm[mask_a].copy() 44 | landms_t = landms_t.reshape([-1, 5, 2]) 45 | 46 | if boxes_t.shape[0] == 0: 47 | continue 48 | 49 | image_t = image[roi[1]:roi[3], roi[0]:roi[2]] 50 | 51 | boxes_t[:, :2] = np.maximum(boxes_t[:, :2], roi[:2]) 52 | boxes_t[:, :2] -= roi[:2] 53 | boxes_t[:, 2:] = np.minimum(boxes_t[:, 2:], roi[2:]) 54 | boxes_t[:, 2:] -= roi[:2] 55 | 56 | # landm 57 | landms_t[:, :, :2] = landms_t[:, :, :2] - roi[:2] 58 | landms_t[:, :, :2] = np.maximum(landms_t[:, :, :2], np.array([0, 0])) 59 | 
landms_t[:, :, :2] = np.minimum(landms_t[:, :, :2], roi[2:] - roi[:2]) 60 | landms_t = landms_t.reshape([-1, 10]) 61 | 62 | 63 | # make sure that the cropped image contains at least one face > 16 pixel at training image scale 64 | b_w_t = (boxes_t[:, 2] - boxes_t[:, 0] + 1) / w * img_dim 65 | b_h_t = (boxes_t[:, 3] - boxes_t[:, 1] + 1) / h * img_dim 66 | mask_b = np.minimum(b_w_t, b_h_t) > 0.0 67 | boxes_t = boxes_t[mask_b] 68 | labels_t = labels_t[mask_b] 69 | landms_t = landms_t[mask_b] 70 | 71 | if boxes_t.shape[0] == 0: 72 | continue 73 | 74 | pad_image_flag = False 75 | 76 | return image_t, boxes_t, labels_t, landms_t, pad_image_flag 77 | return image, boxes, labels, landm, pad_image_flag 78 | 79 | 80 | def _distort(image): 81 | 82 | def _convert(image, alpha=1, beta=0): 83 | tmp = image.astype(float) * alpha + beta 84 | tmp[tmp < 0] = 0 85 | tmp[tmp > 255] = 255 86 | image[:] = tmp 87 | 88 | image = image.copy() 89 | 90 | if random.randrange(2): 91 | 92 | #brightness distortion 93 | if random.randrange(2): 94 | _convert(image, beta=random.uniform(-32, 32)) 95 | 96 | #contrast distortion 97 | if random.randrange(2): 98 | _convert(image, alpha=random.uniform(0.5, 1.5)) 99 | 100 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 101 | 102 | #saturation distortion 103 | if random.randrange(2): 104 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) 105 | 106 | #hue distortion 107 | if random.randrange(2): 108 | tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) 109 | tmp %= 180 110 | image[:, :, 0] = tmp 111 | 112 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 113 | 114 | else: 115 | 116 | #brightness distortion 117 | if random.randrange(2): 118 | _convert(image, beta=random.uniform(-32, 32)) 119 | 120 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 121 | 122 | #saturation distortion 123 | if random.randrange(2): 124 | _convert(image[:, :, 1], alpha=random.uniform(0.5, 1.5)) 125 | 126 | #hue distortion 127 | if random.randrange(2): 128 | tmp = image[:, :, 0].astype(int) + random.randint(-18, 18) 129 | tmp %= 180 130 | image[:, :, 0] = tmp 131 | 132 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 133 | 134 | #contrast distortion 135 | if random.randrange(2): 136 | _convert(image, alpha=random.uniform(0.5, 1.5)) 137 | 138 | return image 139 | 140 | 141 | def _expand(image, boxes, fill, p): 142 | if random.randrange(2): 143 | return image, boxes 144 | 145 | height, width, depth = image.shape 146 | 147 | scale = random.uniform(1, p) 148 | w = int(scale * width) 149 | h = int(scale * height) 150 | 151 | left = random.randint(0, w - width) 152 | top = random.randint(0, h - height) 153 | 154 | boxes_t = boxes.copy() 155 | boxes_t[:, :2] += (left, top) 156 | boxes_t[:, 2:] += (left, top) 157 | expand_image = np.empty( 158 | (h, w, depth), 159 | dtype=image.dtype) 160 | expand_image[:, :] = fill 161 | expand_image[top:top + height, left:left + width] = image 162 | image = expand_image 163 | 164 | return image, boxes_t 165 | 166 | 167 | def _mirror(image, boxes, landms): 168 | _, width, _ = image.shape 169 | if random.randrange(2): 170 | image = image[:, ::-1] 171 | boxes = boxes.copy() 172 | boxes[:, 0::2] = width - boxes[:, 2::-2] 173 | 174 | # landm 175 | landms = landms.copy() 176 | landms = landms.reshape([-1, 5, 2]) 177 | landms[:, :, 0] = width - landms[:, :, 0] 178 | tmp = landms[:, 1, :].copy() 179 | landms[:, 1, :] = landms[:, 0, :] 180 | landms[:, 0, :] = tmp 181 | tmp1 = landms[:, 4, :].copy() 182 | landms[:, 4, :] = landms[:, 3, :] 183 | landms[:, 3, :] = tmp1 184 | 
landms = landms.reshape([-1, 10]) 185 | 186 | return image, boxes, landms 187 | 188 | 189 | def _pad_to_square(image, rgb_mean, pad_image_flag): 190 | if not pad_image_flag: 191 | return image 192 | height, width, _ = image.shape 193 | long_side = max(width, height) 194 | image_t = np.empty((long_side, long_side, 3), dtype=image.dtype) 195 | image_t[:, :] = rgb_mean 196 | image_t[0:0 + height, 0:0 + width] = image 197 | return image_t 198 | 199 | 200 | def _resize_subtract_mean(image, insize, rgb_mean): 201 | interp_methods = [cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4] 202 | interp_method = interp_methods[random.randrange(5)] 203 | image = cv2.resize(image, (insize, insize), interpolation=interp_method) 204 | image = image.astype(np.float32) 205 | image -= rgb_mean 206 | return image.transpose(2, 0, 1) 207 | 208 | 209 | class preproc(object): 210 | 211 | def __init__(self, img_dim, rgb_means): 212 | self.img_dim = img_dim 213 | self.rgb_means = rgb_means 214 | 215 | def __call__(self, image, targets): 216 | assert targets.shape[0] > 0, "this image does not have gt" 217 | 218 | boxes = targets[:, :4].copy() 219 | labels = targets[:, -1].copy() 220 | landm = targets[:, 4:-1].copy() 221 | 222 | image_t, boxes_t, labels_t, landm_t, pad_image_flag = _crop(image, boxes, labels, landm, self.img_dim) 223 | image_t = _distort(image_t) 224 | image_t = _pad_to_square(image_t,self.rgb_means, pad_image_flag) 225 | image_t, boxes_t, landm_t = _mirror(image_t, boxes_t, landm_t) 226 | height, width, _ = image_t.shape 227 | image_t = _resize_subtract_mean(image_t, self.img_dim, self.rgb_means) 228 | boxes_t[:, 0::2] /= width 229 | boxes_t[:, 1::2] /= height 230 | 231 | landm_t[:, 0::2] /= width 232 | landm_t[:, 1::2] /= height 233 | 234 | labels_t = np.expand_dims(labels_t, 1) 235 | targets_t = np.hstack((boxes_t, landm_t, labels_t)) 236 | 237 | return image_t, targets_t 238 | -------------------------------------------------------------------------------- /pasd/annotator/retinaface/data/wider_face.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import sys 4 | import torch 5 | import torch.utils.data as data 6 | import cv2 7 | import numpy as np 8 | 9 | class WiderFaceDetection(data.Dataset): 10 | def __init__(self, txt_path, preproc=None): 11 | self.preproc = preproc 12 | self.imgs_path = [] 13 | self.words = [] 14 | f = open(txt_path,'r') 15 | lines = f.readlines() 16 | isFirst = True 17 | labels = [] 18 | for line in lines: 19 | line = line.rstrip() 20 | if line.startswith('#'): 21 | if isFirst is True: 22 | isFirst = False 23 | else: 24 | labels_copy = labels.copy() 25 | self.words.append(labels_copy) 26 | labels.clear() 27 | path = line[2:] 28 | path = txt_path.replace('label.txt','images/') + path 29 | self.imgs_path.append(path) 30 | else: 31 | line = line.split(' ') 32 | label = [float(x) for x in line] 33 | labels.append(label) 34 | 35 | self.words.append(labels) 36 | 37 | def __len__(self): 38 | return len(self.imgs_path) 39 | 40 | def __getitem__(self, index): 41 | img = cv2.imread(self.imgs_path[index]) 42 | height, width, _ = img.shape 43 | 44 | labels = self.words[index] 45 | annotations = np.zeros((0, 15)) 46 | if len(labels) == 0: 47 | return annotations 48 | for idx, label in enumerate(labels): 49 | annotation = np.zeros((1, 15)) 50 | # bbox 51 | annotation[0, 0] = label[0] # x1 52 | annotation[0, 1] = label[1] # y1 53 | annotation[0, 2] = label[0] + label[2] # x2 54 | 
annotation[0, 3] = label[1] + label[3] # y2 55 | 56 | # landmarks 57 | annotation[0, 4] = label[4] # l0_x 58 | annotation[0, 5] = label[5] # l0_y 59 | annotation[0, 6] = label[7] # l1_x 60 | annotation[0, 7] = label[8] # l1_y 61 | annotation[0, 8] = label[10] # l2_x 62 | annotation[0, 9] = label[11] # l2_y 63 | annotation[0, 10] = label[13] # l3_x 64 | annotation[0, 11] = label[14] # l3_y 65 | annotation[0, 12] = label[16] # l4_x 66 | annotation[0, 13] = label[17] # l4_y 67 | if (annotation[0, 4]<0): 68 | annotation[0, 14] = -1 69 | else: 70 | annotation[0, 14] = 1 71 | 72 | annotations = np.append(annotations, annotation, axis=0) 73 | target = np.array(annotations) 74 | if self.preproc is not None: 75 | img, target = self.preproc(img, target) 76 | 77 | return torch.from_numpy(img), target 78 | 79 | def detection_collate(batch): 80 | """Custom collate fn for dealing with batches of images that have a different 81 | number of associated object annotations (bounding boxes). 82 | 83 | Arguments: 84 | batch: (tuple) A tuple of tensor images and lists of annotations 85 | 86 | Return: 87 | A tuple containing: 88 | 1) (tensor) batch of images stacked on their 0 dim 89 | 2) (list of tensors) annotations for a given image are stacked on 0 dim 90 | """ 91 | targets = [] 92 | imgs = [] 93 | for _, sample in enumerate(batch): 94 | for _, tup in enumerate(sample): 95 | if torch.is_tensor(tup): 96 | imgs.append(tup) 97 | elif isinstance(tup, type(np.empty(0))): 98 | annos = torch.from_numpy(tup).float() 99 | targets.append(annos) 100 | 101 | return (torch.stack(imgs, 0), targets) 102 | -------------------------------------------------------------------------------- /pasd/annotator/retinaface/facemodels/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/annotator/retinaface/facemodels/__init__.py -------------------------------------------------------------------------------- /pasd/annotator/retinaface/facemodels/net.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models._utils as _utils 5 | import torchvision.models as models 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | def conv_bn(inp, oup, stride = 1, leaky = 0): 10 | return nn.Sequential( 11 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 12 | nn.BatchNorm2d(oup), 13 | nn.LeakyReLU(negative_slope=leaky, inplace=True) 14 | ) 15 | 16 | def conv_bn_no_relu(inp, oup, stride): 17 | return nn.Sequential( 18 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 19 | nn.BatchNorm2d(oup), 20 | ) 21 | 22 | def conv_bn1X1(inp, oup, stride, leaky=0): 23 | return nn.Sequential( 24 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), 25 | nn.BatchNorm2d(oup), 26 | nn.LeakyReLU(negative_slope=leaky, inplace=True) 27 | ) 28 | 29 | def conv_dw(inp, oup, stride, leaky=0.1): 30 | return nn.Sequential( 31 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 32 | nn.BatchNorm2d(inp), 33 | nn.LeakyReLU(negative_slope= leaky,inplace=True), 34 | 35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 36 | nn.BatchNorm2d(oup), 37 | nn.LeakyReLU(negative_slope= leaky,inplace=True), 38 | ) 39 | 40 | class SSH(nn.Module): 41 | def __init__(self, in_channel, out_channel): 42 | super(SSH, self).__init__() 43 | assert out_channel % 4 == 0 44 | leaky = 0 45 | if (out_channel <= 64): 46 | leaky = 
0.1 47 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel//2, stride=1) 48 | 49 | self.conv5X5_1 = conv_bn(in_channel, out_channel//4, stride=1, leaky = leaky) 50 | self.conv5X5_2 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1) 51 | 52 | self.conv7X7_2 = conv_bn(out_channel//4, out_channel//4, stride=1, leaky = leaky) 53 | self.conv7x7_3 = conv_bn_no_relu(out_channel//4, out_channel//4, stride=1) 54 | 55 | def forward(self, input): 56 | conv3X3 = self.conv3X3(input) 57 | 58 | conv5X5_1 = self.conv5X5_1(input) 59 | conv5X5 = self.conv5X5_2(conv5X5_1) 60 | 61 | conv7X7_2 = self.conv7X7_2(conv5X5_1) 62 | conv7X7 = self.conv7x7_3(conv7X7_2) 63 | 64 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) 65 | out = F.relu(out) 66 | return out 67 | 68 | class FPN(nn.Module): 69 | def __init__(self,in_channels_list,out_channels): 70 | super(FPN,self).__init__() 71 | leaky = 0 72 | if (out_channels <= 64): 73 | leaky = 0.1 74 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1, leaky = leaky) 75 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky) 76 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky) 77 | 78 | self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky) 79 | self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky) 80 | 81 | def forward(self, input): 82 | # names = list(input.keys()) 83 | input = list(input.values()) 84 | 85 | output1 = self.output1(input[0]) 86 | output2 = self.output2(input[1]) 87 | output3 = self.output3(input[2]) 88 | 89 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest") 90 | output2 = output2 + up3 91 | output2 = self.merge2(output2) 92 | 93 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest") 94 | output1 = output1 + up2 95 | output1 = self.merge1(output1) 96 | 97 | out = [output1, output2, output3] 98 | return out 99 | 100 | 101 | 102 | class MobileNetV1(nn.Module): 103 | def __init__(self): 104 | super(MobileNetV1, self).__init__() 105 | self.stage1 = nn.Sequential( 106 | conv_bn(3, 8, 2, leaky = 0.1), # 3 107 | conv_dw(8, 16, 1), # 7 108 | conv_dw(16, 32, 2), # 11 109 | conv_dw(32, 32, 1), # 19 110 | conv_dw(32, 64, 2), # 27 111 | conv_dw(64, 64, 1), # 43 112 | ) 113 | self.stage2 = nn.Sequential( 114 | conv_dw(64, 128, 2), # 43 + 16 = 59 115 | conv_dw(128, 128, 1), # 59 + 32 = 91 116 | conv_dw(128, 128, 1), # 91 + 32 = 123 117 | conv_dw(128, 128, 1), # 123 + 32 = 155 118 | conv_dw(128, 128, 1), # 155 + 32 = 187 119 | conv_dw(128, 128, 1), # 187 + 32 = 219 120 | ) 121 | self.stage3 = nn.Sequential( 122 | conv_dw(128, 256, 2), # 219 +3 2 = 241 123 | conv_dw(256, 256, 1), # 241 + 64 = 301 124 | ) 125 | self.avg = nn.AdaptiveAvgPool2d((1,1)) 126 | self.fc = nn.Linear(256, 1000) 127 | 128 | def forward(self, x): 129 | x = self.stage1(x) 130 | x = self.stage2(x) 131 | x = self.stage3(x) 132 | x = self.avg(x) 133 | # x = self.model(x) 134 | x = x.view(-1, 256) 135 | x = self.fc(x) 136 | return x 137 | 138 | -------------------------------------------------------------------------------- /pasd/annotator/retinaface/facemodels/retinaface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models.detection.backbone_utils as backbone_utils 4 | import torchvision.models._utils as _utils 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | 8 | from 
..facemodels.net import MobileNetV1 as MobileNetV1 9 | from ..facemodels.net import FPN as FPN 10 | from ..facemodels.net import SSH as SSH 11 | 12 | 13 | 14 | class ClassHead(nn.Module): 15 | def __init__(self,inchannels=512,num_anchors=3): 16 | super(ClassHead,self).__init__() 17 | self.num_anchors = num_anchors 18 | self.conv1x1 = nn.Conv2d(inchannels,self.num_anchors*2,kernel_size=(1,1),stride=1,padding=0) 19 | 20 | def forward(self,x): 21 | out = self.conv1x1(x) 22 | out = out.permute(0,2,3,1).contiguous() 23 | 24 | return out.view(out.shape[0], -1, 2) 25 | 26 | class BboxHead(nn.Module): 27 | def __init__(self,inchannels=512,num_anchors=3): 28 | super(BboxHead,self).__init__() 29 | self.conv1x1 = nn.Conv2d(inchannels,num_anchors*4,kernel_size=(1,1),stride=1,padding=0) 30 | 31 | def forward(self,x): 32 | out = self.conv1x1(x) 33 | out = out.permute(0,2,3,1).contiguous() 34 | 35 | return out.view(out.shape[0], -1, 4) 36 | 37 | class LandmarkHead(nn.Module): 38 | def __init__(self,inchannels=512,num_anchors=3): 39 | super(LandmarkHead,self).__init__() 40 | self.conv1x1 = nn.Conv2d(inchannels,num_anchors*10,kernel_size=(1,1),stride=1,padding=0) 41 | 42 | def forward(self,x): 43 | out = self.conv1x1(x) 44 | out = out.permute(0,2,3,1).contiguous() 45 | 46 | return out.view(out.shape[0], -1, 10) 47 | 48 | class RetinaFace(nn.Module): 49 | def __init__(self, cfg = None, phase = 'train'): 50 | """ 51 | :param cfg: Network related settings. 52 | :param phase: train or test. 53 | """ 54 | super(RetinaFace,self).__init__() 55 | self.phase = phase 56 | backbone = None 57 | if cfg['name'] == 'mobilenet0.25': 58 | backbone = MobileNetV1() 59 | if cfg['pretrain']: 60 | checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu')) 61 | from collections import OrderedDict 62 | new_state_dict = OrderedDict() 63 | for k, v in checkpoint['state_dict'].items(): 64 | name = k[7:] # remove module. 
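                    # keys saved from an nn.DataParallel-wrapped model carry a leading
                    # 'module.' prefix (7 characters); k[7:] strips it so the plain
                    # MobileNetV1 backbone can load the weights directly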
65 | new_state_dict[name] = v 66 | # load params 67 | backbone.load_state_dict(new_state_dict) 68 | elif cfg['name'] == 'Resnet50': 69 | import torchvision.models as models 70 | backbone = models.resnet50(pretrained=cfg['pretrain']) 71 | 72 | self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers']) 73 | in_channels_stage2 = cfg['in_channel'] 74 | in_channels_list = [ 75 | in_channels_stage2 * 2, 76 | in_channels_stage2 * 4, 77 | in_channels_stage2 * 8, 78 | ] 79 | out_channels = cfg['out_channel'] 80 | self.fpn = FPN(in_channels_list,out_channels) 81 | self.ssh1 = SSH(out_channels, out_channels) 82 | self.ssh2 = SSH(out_channels, out_channels) 83 | self.ssh3 = SSH(out_channels, out_channels) 84 | 85 | self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel']) 86 | self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) 87 | self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) 88 | 89 | def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2): 90 | classhead = nn.ModuleList() 91 | for i in range(fpn_num): 92 | classhead.append(ClassHead(inchannels,anchor_num)) 93 | return classhead 94 | 95 | def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2): 96 | bboxhead = nn.ModuleList() 97 | for i in range(fpn_num): 98 | bboxhead.append(BboxHead(inchannels,anchor_num)) 99 | return bboxhead 100 | 101 | def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2): 102 | landmarkhead = nn.ModuleList() 103 | for i in range(fpn_num): 104 | landmarkhead.append(LandmarkHead(inchannels,anchor_num)) 105 | return landmarkhead 106 | 107 | def forward(self,inputs): 108 | out = self.body(inputs) 109 | 110 | # FPN 111 | fpn = self.fpn(out) 112 | 113 | # SSH 114 | feature1 = self.ssh1(fpn[0]) 115 | feature2 = self.ssh2(fpn[1]) 116 | feature3 = self.ssh3(fpn[2]) 117 | features = [feature1, feature2, feature3] 118 | 119 | bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) 120 | classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)],dim=1) 121 | ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1) 122 | 123 | if self.phase == 'train': 124 | output = (bbox_regressions, classifications, ldm_regressions) 125 | else: 126 | output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) 127 | return output -------------------------------------------------------------------------------- /pasd/annotator/retinaface/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /pasd/annotator/retinaface/layers/functions/prior_box.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from itertools import product as product 3 | import numpy as np 4 | from math import ceil 5 | 6 | 7 | class PriorBox(object): 8 | def __init__(self, cfg, image_size=None, phase='train'): 9 | super(PriorBox, self).__init__() 10 | self.min_sizes = cfg['min_sizes'] 11 | self.steps = cfg['steps'] 12 | self.clip = cfg['clip'] 13 | self.image_size = image_size 14 | self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps] 15 | self.name = "s" 16 | 17 | def forward(self): 18 | anchors = [] 19 | for k, f 
in enumerate(self.feature_maps): 20 | min_sizes = self.min_sizes[k] 21 | for i, j in product(range(f[0]), range(f[1])): 22 | for min_size in min_sizes: 23 | s_kx = min_size / self.image_size[1] 24 | s_ky = min_size / self.image_size[0] 25 | dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] 26 | dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] 27 | for cy, cx in product(dense_cy, dense_cx): 28 | anchors += [cx, cy, s_kx, s_ky] 29 | 30 | # back to torch land 31 | output = torch.Tensor(anchors).view(-1, 4) 32 | if self.clip: 33 | output.clamp_(max=1, min=0) 34 | return output 35 | -------------------------------------------------------------------------------- /pasd/annotator/retinaface/layers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .multibox_loss import MultiBoxLoss 2 | 3 | __all__ = ['MultiBoxLoss'] 4 | -------------------------------------------------------------------------------- /pasd/annotator/retinaface/layers/modules/multibox_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from ...utils.box_utils import match, log_sum_exp 6 | from ...data import cfg_mnet 7 | GPU = cfg_mnet['gpu_train'] 8 | 9 | class MultiBoxLoss(nn.Module): 10 | """SSD Weighted Loss Function 11 | Compute Targets: 12 | 1) Produce Confidence Target Indices by matching ground truth boxes 13 | with (default) 'priorboxes' that have jaccard index > threshold parameter 14 | (default threshold: 0.5). 15 | 2) Produce localization target by 'encoding' variance into offsets of ground 16 | truth boxes and their matched 'priorboxes'. 17 | 3) Hard negative mining to filter the excessive number of negative examples 18 | that comes with using a large number of default bounding boxes. 19 | (default negative:positive ratio 3:1) 20 | Objective Loss: 21 | L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 22 | Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss 23 | weighted by α which is set to 1 by cross val. 24 | Args: 25 | c: class confidences, 26 | l: predicted boxes, 27 | g: ground truth boxes 28 | N: number of matched default boxes 29 | See: https://arxiv.org/pdf/1512.02325.pdf for more details. 30 | """ 31 | 32 | def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target): 33 | super(MultiBoxLoss, self).__init__() 34 | self.num_classes = num_classes 35 | self.threshold = overlap_thresh 36 | self.background_label = bkg_label 37 | self.encode_target = encode_target 38 | self.use_prior_for_matching = prior_for_matching 39 | self.do_neg_mining = neg_mining 40 | self.negpos_ratio = neg_pos 41 | self.neg_overlap = neg_overlap 42 | self.variance = [0.1, 0.2] 43 | 44 | def forward(self, predictions, priors, targets): 45 | """Multibox Loss 46 | Args: 47 | predictions (tuple): A tuple containing loc preds, conf preds, 48 | and prior boxes from SSD net. 49 | conf shape: torch.size(batch_size,num_priors,num_classes) 50 | loc shape: torch.size(batch_size,num_priors,4) 51 | priors shape: torch.size(num_priors,4) 52 | 53 | ground_truth (tensor): Ground truth boxes and labels for a batch, 54 | shape: [batch_size,num_objs,5] (last idx is the label). 
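
            Example (illustrative sketch, not from the original source; the
                constructor values mirror common RetinaFace training settings and
                are assumptions, as are the tensor names):
                    criterion = MultiBoxLoss(num_classes=2, overlap_thresh=0.35,
                                             prior_for_matching=True, bkg_label=0,
                                             neg_mining=True, neg_pos=7,
                                             neg_overlap=0.35, encode_target=False)
                    # predictions = (loc, conf, landm) with shapes
                    #   (B, num_priors, 4), (B, num_priors, 2), (B, num_priors, 10)
                    # priors: (num_priors, 4); targets: list of B tensors, each (num_objs, 15)
                    loss_l, loss_c, loss_landm = criterion(predictions, priors, targets)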
55 | """ 56 | 57 | loc_data, conf_data, landm_data = predictions 58 | priors = priors 59 | num = loc_data.size(0) 60 | num_priors = (priors.size(0)) 61 | 62 | # match priors (default boxes) and ground truth boxes 63 | loc_t = torch.Tensor(num, num_priors, 4) 64 | landm_t = torch.Tensor(num, num_priors, 10) 65 | conf_t = torch.LongTensor(num, num_priors) 66 | for idx in range(num): 67 | truths = targets[idx][:, :4].data 68 | labels = targets[idx][:, -1].data 69 | landms = targets[idx][:, 4:14].data 70 | defaults = priors.data 71 | match(self.threshold, truths, defaults, self.variance, labels, landms, loc_t, conf_t, landm_t, idx) 72 | if GPU: 73 | loc_t = loc_t.cuda() 74 | conf_t = conf_t.cuda() 75 | landm_t = landm_t.cuda() 76 | 77 | zeros = torch.tensor(0).cuda() 78 | # landm Loss (Smooth L1) 79 | # Shape: [batch,num_priors,10] 80 | pos1 = conf_t > zeros 81 | num_pos_landm = pos1.long().sum(1, keepdim=True) 82 | N1 = max(num_pos_landm.data.sum().float(), 1) 83 | pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data) 84 | landm_p = landm_data[pos_idx1].view(-1, 10) 85 | landm_t = landm_t[pos_idx1].view(-1, 10) 86 | loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum') 87 | 88 | 89 | pos = conf_t != zeros 90 | conf_t[pos] = 1 91 | 92 | # Localization Loss (Smooth L1) 93 | # Shape: [batch,num_priors,4] 94 | pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) 95 | loc_p = loc_data[pos_idx].view(-1, 4) 96 | loc_t = loc_t[pos_idx].view(-1, 4) 97 | loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum') 98 | 99 | # Compute max conf across batch for hard negative mining 100 | batch_conf = conf_data.view(-1, self.num_classes) 101 | loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) 102 | 103 | # Hard Negative Mining 104 | loss_c[pos.view(-1, 1)] = 0 # filter out pos boxes for now 105 | loss_c = loss_c.view(num, -1) 106 | _, loss_idx = loss_c.sort(1, descending=True) 107 | _, idx_rank = loss_idx.sort(1) 108 | num_pos = pos.long().sum(1, keepdim=True) 109 | num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) 110 | neg = idx_rank < num_neg.expand_as(idx_rank) 111 | 112 | # Confidence Loss Including Positive and Negative Examples 113 | pos_idx = pos.unsqueeze(2).expand_as(conf_data) 114 | neg_idx = neg.unsqueeze(2).expand_as(conf_data) 115 | conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes) 116 | targets_weighted = conf_t[(pos+neg).gt(0)] 117 | loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum') 118 | 119 | # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N 120 | N = max(num_pos.data.sum().float(), 1) 121 | loss_l /= N 122 | loss_c /= N 123 | loss_landm /= N1 124 | 125 | return loss_l, loss_c, loss_landm 126 | -------------------------------------------------------------------------------- /pasd/annotator/retinaface/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/annotator/retinaface/utils/__init__.py -------------------------------------------------------------------------------- /pasd/annotator/retinaface/utils/nms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/annotator/retinaface/utils/nms/__init__.py -------------------------------------------------------------------------------- 
/pasd/annotator/retinaface/utils/nms/py_cpu_nms.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | 10 | def py_cpu_nms(dets, thresh): 11 | """Pure Python NMS baseline.""" 12 | x1 = dets[:, 0] 13 | y1 = dets[:, 1] 14 | x2 = dets[:, 2] 15 | y2 = dets[:, 3] 16 | scores = dets[:, 4] 17 | 18 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 19 | order = scores.argsort()[::-1] 20 | 21 | keep = [] 22 | while order.size > 0: 23 | i = order[0] 24 | keep.append(i) 25 | xx1 = np.maximum(x1[i], x1[order[1:]]) 26 | yy1 = np.maximum(y1[i], y1[order[1:]]) 27 | xx2 = np.minimum(x2[i], x2[order[1:]]) 28 | yy2 = np.minimum(y2[i], y2[order[1:]]) 29 | 30 | w = np.maximum(0.0, xx2 - xx1 + 1) 31 | h = np.maximum(0.0, yy2 - yy1 + 1) 32 | inter = w * h 33 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 34 | 35 | inds = np.where(ovr <= thresh)[0] 36 | order = order[inds + 1] 37 | 38 | return keep 39 | -------------------------------------------------------------------------------- /pasd/annotator/retinaface/utils/timer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import time 9 | 10 | 11 | class Timer(object): 12 | """A simple timer.""" 13 | def __init__(self): 14 | self.total_time = 0. 15 | self.calls = 0 16 | self.start_time = 0. 17 | self.diff = 0. 18 | self.average_time = 0. 19 | 20 | def tic(self): 21 | # using time.time instead of time.clock because time time.clock 22 | # does not normalize for multithreading 23 | self.start_time = time.time() 24 | 25 | def toc(self, average=True): 26 | self.diff = time.time() - self.start_time 27 | self.total_time += self.diff 28 | self.calls += 1 29 | self.average_time = self.total_time / self.calls 30 | if average: 31 | return self.average_time 32 | else: 33 | return self.diff 34 | 35 | def clear(self): 36 | self.total_time = 0. 37 | self.calls = 0 38 | self.start_time = 0. 39 | self.diff = 0. 40 | self.average_time = 0. 
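
# Illustrative usage sketch (not part of the upstream Fast R-CNN file);
# `detect_faces` stands in for any workload being timed:
#     _t = Timer()
#     _t.tic()
#     detect_faces(img)
#     print('average time per call: %.3fs' % _t.toc())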
41 | -------------------------------------------------------------------------------- /pasd/annotator/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | 5 | 6 | annotator_path = os.path.join(os.path.dirname(__file__)) 7 | 8 | 9 | def HWC3(x): 10 | assert x.dtype == np.uint8 11 | if x.ndim == 2: 12 | x = x[:, :, None] 13 | assert x.ndim == 3 14 | H, W, C = x.shape 15 | assert C == 1 or C == 3 or C == 4 16 | if C == 3: 17 | return x 18 | if C == 1: 19 | return np.concatenate([x, x, x], axis=2) 20 | if C == 4: 21 | color = x[:, :, 0:3].astype(np.float32) 22 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 23 | y = color * alpha + 255.0 * (1.0 - alpha) 24 | y = y.clip(0, 255).astype(np.uint8) 25 | return y 26 | 27 | 28 | def resize_image(input_image, resolution): 29 | H, W, C = input_image.shape 30 | H = float(H) 31 | W = float(W) 32 | k = float(resolution) / min(H, W) 33 | H *= k 34 | W *= k 35 | H = int(np.round(H / 64.0)) * 64 36 | W = int(np.round(W / 64.0)) * 64 37 | img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 38 | return img 39 | -------------------------------------------------------------------------------- /pasd/annotator/yolo/__init__.py: -------------------------------------------------------------------------------- 1 | from ultralytics import YOLO 2 | 3 | from annotator.util import annotator_path 4 | 5 | class YoLoDetection(object): 6 | def __init__(self, device='cuda'): 7 | MODEL = f'{annotator_path}/ckpts/yolov8n.pt' 8 | self.model = YOLO(MODEL) 9 | 10 | def detect(self, image, imgsz=640): 11 | results = self.model.predict(image, imgsz=imgsz, verbose=False, save=False) 12 | for result in results[:1]: 13 | boxes = result.boxes # Boxes object for bbox outputs 14 | cls = (boxes.cls).cpu().numpy() 15 | conf = (boxes.conf).cpu().numpy() 16 | names = result.names 17 | return cls, conf, names 18 | 19 | -------------------------------------------------------------------------------- /pasd/dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/__init__.py -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/__init__.py -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from .prefetch_dataloader import PrefetchDataLoader 11 | from ..utils import get_root_logger, scandir 12 | from ..utils.dist_util import get_dist_info 13 | from ..utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = 
[osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'loaders.basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. Default: None 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f"Wrong dataset phase: {phase}. 
Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/data/__pycache__/ffhq_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/data/__pycache__/ffhq_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/data/__pycache__/prefetch_dataloader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/data/__pycache__/prefetch_dataloader.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/data/__pycache__/transforms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/data/__pycache__/transforms.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 
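
        Example (illustrative sketch; ``train_set``, ``world_size``, ``rank`` and
            ``num_epochs`` are assumed to come from the surrounding training script):
                sampler = EnlargedSampler(train_set, num_replicas=world_size,
                                          rank=rank, ratio=100)
                loader = torch.utils.data.DataLoader(train_set, batch_size=16,
                                                     sampler=sampler, shuffle=False)
                for epoch in range(num_epochs):
                    sampler.set_epoch(epoch)  # deterministic reshuffle each epoch
                    for batch in loader:
                        ...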
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 11 | 12 | Args: 13 | generator: Python generator. 14 | num_prefetch_queue (int): Number of prefetch queue. 15 | """ 16 | 17 | def __init__(self, generator, num_prefetch_queue): 18 | threading.Thread.__init__(self) 19 | self.queue = Queue.Queue(num_prefetch_queue) 20 | self.generator = generator 21 | self.daemon = True 22 | self.start() 23 | 24 | def run(self): 25 | for item in self.generator: 26 | self.queue.put(item) 27 | self.queue.put(None) 28 | 29 | def __next__(self): 30 | next_item = self.queue.get() 31 | if next_item is None: 32 | raise StopIteration 33 | return next_item 34 | 35 | def __iter__(self): 36 | return self 37 | 38 | 39 | class PrefetchDataLoader(DataLoader): 40 | """Prefetch version of dataloader. 41 | 42 | Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 43 | 44 | TODO: 45 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 46 | ddp. 47 | 48 | Args: 49 | num_prefetch_queue (int): Number of prefetch queue. 50 | kwargs (dict): Other arguments for dataloader. 51 | """ 52 | 53 | def __init__(self, num_prefetch_queue, **kwargs): 54 | self.num_prefetch_queue = num_prefetch_queue 55 | super(PrefetchDataLoader, self).__init__(**kwargs) 56 | 57 | def __iter__(self): 58 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 59 | 60 | 61 | class CPUPrefetcher(): 62 | """CPU prefetcher. 63 | 64 | Args: 65 | loader: Dataloader. 66 | """ 67 | 68 | def __init__(self, loader): 69 | self.ori_loader = loader 70 | self.loader = iter(loader) 71 | 72 | def next(self): 73 | try: 74 | return next(self.loader) 75 | except StopIteration: 76 | return None 77 | 78 | def reset(self): 79 | self.loader = iter(self.ori_loader) 80 | 81 | 82 | class CUDAPrefetcher(): 83 | """CUDA prefetcher. 84 | 85 | Reference: https://github.com/NVIDIA/apex/issues/304# 86 | 87 | It may consume more GPU memory. 88 | 89 | Args: 90 | loader: Dataloader. 91 | opt (dict): Options. 
92 | """ 93 | 94 | def __init__(self, loader, opt): 95 | self.ori_loader = loader 96 | self.loader = iter(loader) 97 | self.opt = opt 98 | self.stream = torch.cuda.Stream() 99 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 100 | self.preload() 101 | 102 | def preload(self): 103 | try: 104 | self.batch = next(self.loader) # self.batch is a dict 105 | except StopIteration: 106 | self.batch = None 107 | return None 108 | # put tensors to gpu 109 | with torch.cuda.stream(self.stream): 110 | for k, v in self.batch.items(): 111 | if torch.is_tensor(v): 112 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 113 | 114 | def next(self): 115 | torch.cuda.current_stream().wait_stream(self.stream) 116 | batch = self.batch 117 | self.preload() 118 | return batch 119 | 120 | def reset(self): 121 | self.loader = iter(self.ori_loader) 122 | self.preload() 123 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import torch 4 | 5 | 6 | def mod_crop(img, scale): 7 | """Mod crop images, used during testing. 8 | 9 | Args: 10 | img (ndarray): Input image. 11 | scale (int): Scale factor. 12 | 13 | Returns: 14 | ndarray: Result image. 15 | """ 16 | img = img.copy() 17 | if img.ndim in (2, 3): 18 | h, w = img.shape[0], img.shape[1] 19 | h_remainder, w_remainder = h % scale, w % scale 20 | img = img[:h - h_remainder, :w - w_remainder, ...] 21 | else: 22 | raise ValueError(f'Wrong img ndim: {img.ndim}.') 23 | return img 24 | 25 | 26 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): 27 | """Paired random crop. Support Numpy array and Tensor inputs. 28 | 29 | It crops lists of lq and gt images with corresponding locations. 30 | 31 | Args: 32 | img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images 33 | should have the same shape. If the input is an ndarray, it will 34 | be transformed to a list containing itself. 35 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 36 | should have the same shape. If the input is an ndarray, it will 37 | be transformed to a list containing itself. 38 | gt_patch_size (int): GT patch size. 39 | scale (int): Scale factor. 40 | gt_path (str): Path to ground-truth. Default: None. 41 | 42 | Returns: 43 | list[ndarray] | ndarray: GT images and LQ images. If returned results 44 | only have one element, just return ndarray. 45 | """ 46 | 47 | if not isinstance(img_gts, list): 48 | img_gts = [img_gts] 49 | if not isinstance(img_lqs, list): 50 | img_lqs = [img_lqs] 51 | 52 | # determine input type: Numpy array or Tensor 53 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' 54 | 55 | if input_type == 'Tensor': 56 | h_lq, w_lq = img_lqs[0].size()[-2:] 57 | h_gt, w_gt = img_gts[0].size()[-2:] 58 | else: 59 | h_lq, w_lq = img_lqs[0].shape[0:2] 60 | h_gt, w_gt = img_gts[0].shape[0:2] 61 | lq_patch_size = gt_patch_size // scale 62 | 63 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 64 | raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 65 | f'multiplication of LQ ({h_lq}, {w_lq}).') 66 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 67 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 68 | f'({lq_patch_size}, {lq_patch_size}). 
' 69 | f'Please remove {gt_path}.') 70 | 71 | # randomly choose top and left coordinates for lq patch 72 | top = random.randint(0, h_lq - lq_patch_size) 73 | left = random.randint(0, w_lq - lq_patch_size) 74 | 75 | # crop lq patch 76 | if input_type == 'Tensor': 77 | img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] 78 | else: 79 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] 80 | 81 | # crop corresponding gt patch 82 | top_gt, left_gt = int(top * scale), int(left * scale) 83 | if input_type == 'Tensor': 84 | img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] 85 | else: 86 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] 87 | if len(img_gts) == 1: 88 | img_gts = img_gts[0] 89 | if len(img_lqs) == 1: 90 | img_lqs = img_lqs[0] 91 | return img_gts, img_lqs 92 | 93 | 94 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): 95 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 96 | 97 | We use vertical flip and transpose for rotation implementation. 98 | All the images in the list use the same augmentation. 99 | 100 | Args: 101 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input 102 | is an ndarray, it will be transformed to a list. 103 | hflip (bool): Horizontal flip. Default: True. 104 | rotation (bool): Ratotation. Default: True. 105 | flows (list[ndarray]: Flows to be augmented. If the input is an 106 | ndarray, it will be transformed to a list. 107 | Dimension is (h, w, 2). Default: None. 108 | return_status (bool): Return the status of flip and rotation. 109 | Default: False. 110 | 111 | Returns: 112 | list[ndarray] | ndarray: Augmented images and flows. If returned 113 | results only have one element, just return ndarray. 114 | 115 | """ 116 | hflip = hflip and random.random() < 0.5 117 | vflip = rotation and random.random() < 0.5 118 | rot90 = rotation and random.random() < 0.5 119 | 120 | def _augment(img): 121 | if hflip: # horizontal 122 | cv2.flip(img, 1, img) 123 | if vflip: # vertical 124 | cv2.flip(img, 0, img) 125 | if rot90: 126 | img = img.transpose(1, 0, 2) 127 | return img 128 | 129 | def _augment_flow(flow): 130 | if hflip: # horizontal 131 | cv2.flip(flow, 1, flow) 132 | flow[:, :, 0] *= -1 133 | if vflip: # vertical 134 | cv2.flip(flow, 0, flow) 135 | flow[:, :, 1] *= -1 136 | if rot90: 137 | flow = flow.transpose(1, 0, 2) 138 | flow = flow[:, :, [1, 0]] 139 | return flow 140 | 141 | if not isinstance(imgs, list): 142 | imgs = [imgs] 143 | imgs = [_augment(img) for img in imgs] 144 | if len(imgs) == 1: 145 | imgs = imgs[0] 146 | 147 | if flows is not None: 148 | if not isinstance(flows, list): 149 | flows = [flows] 150 | flows = [_augment_flow(flow) for flow in flows] 151 | if len(flows) == 1: 152 | flows = flows[0] 153 | return imgs, flows 154 | else: 155 | if return_status: 156 | return imgs, (hflip, vflip, rot90) 157 | else: 158 | return imgs 159 | 160 | 161 | def img_rotate(img, angle, center=None, scale=1.0): 162 | """Rotate image. 163 | 164 | Args: 165 | img (ndarray): Image to be rotated. 166 | angle (float): Rotation angle in degrees. Positive values mean 167 | counter-clockwise rotation. 168 | center (tuple[int]): Rotation center. If the center is None, 169 | initialize it as the center of the image. Default: None. 170 | scale (float): Isotropic scale factor. Default: 1.0. 
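
    Example (illustrative; ``img`` is any HWC image loaded elsewhere):
        rotated = img_rotate(img, angle=30)             # 30 degrees counter-clockwise about the centre
        shrunk = img_rotate(img, angle=-90, scale=0.5)  # clockwise quarter turn at half size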
171 | """ 172 | (h, w) = img.shape[:2] 173 | 174 | if center is None: 175 | center = (w // 2, h // 2) 176 | 177 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 178 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 179 | return rotated_img 180 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb 2 | from .diffjpeg import DiffJPEG 3 | from .file_client import FileClient 4 | from .img_process_util import USMSharp, usm_sharp 5 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 6 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 7 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 8 | from .options import yaml_load 9 | 10 | __all__ = [ 11 | # color_util.py 12 | 'bgr2ycbcr', 13 | 'rgb2ycbcr', 14 | 'rgb2ycbcr_pt', 15 | 'ycbcr2bgr', 16 | 'ycbcr2rgb', 17 | # file_client.py 18 | 'FileClient', 19 | # img_util.py 20 | 'img2tensor', 21 | 'tensor2img', 22 | 'imfrombytes', 23 | 'imwrite', 24 | 'crop_border', 25 | # logger.py 26 | 'MessageLogger', 27 | 'AvgTimer', 28 | 'init_tb_logger', 29 | 'init_wandb_logger', 30 | 'get_root_logger', 31 | 'get_env_info', 32 | # misc.py 33 | 'set_random_seed', 34 | 'get_time_str', 35 | 'mkdir_and_rename', 36 | 'make_exp_dirs', 37 | 'scandir', 38 | 'check_resume', 39 | 'sizeof_fmt', 40 | # diffjpeg 41 | 'DiffJPEG', 42 | # img_process_util 43 | 'USMSharp', 44 | 'usm_sharp', 45 | # options 46 | 'yaml_load' 47 | ] 48 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__pycache__/color_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/utils/__pycache__/color_util.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__pycache__/diffjpeg.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/utils/__pycache__/diffjpeg.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__pycache__/dist_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/utils/__pycache__/dist_util.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__pycache__/file_client.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/utils/__pycache__/file_client.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__pycache__/img_process_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/utils/__pycache__/img_process_util.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__pycache__/img_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/utils/__pycache__/img_util.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__pycache__/logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/utils/__pycache__/logger.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__pycache__/misc.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/utils/__pycache__/misc.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__pycache__/options.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/utils/__pycache__/options.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/__pycache__/registry.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/dataloader/basicsr/utils/__pycache__/registry.cpython-39.pyc -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def 
_init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 13 | 14 | Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive 15 | 16 | Args: 17 | file_id (str): File id. 18 | save_path (str): Save path. 
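
    Example (illustrative; the file id and destination below are placeholders,
        not references to real files):
            download_file_from_google_drive('YOUR_FILE_ID', 'weights/model.pth')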
19 | """ 20 | 21 | session = requests.Session() 22 | URL = 'https://docs.google.com/uc?export=download' 23 | params = {'id': file_id} 24 | 25 | response = session.get(URL, params=params, stream=True) 26 | token = get_confirm_token(response) 27 | if token: 28 | params['confirm'] = token 29 | response = session.get(URL, params=params, stream=True) 30 | 31 | # get file size 32 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 33 | if 'Content-Range' in response_file_size.headers: 34 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 49 | if file_size is not None: 50 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 51 | 52 | readable_file_size = sizeof_fmt(file_size) 53 | else: 54 | pbar = None 55 | 56 | with open(destination, 'wb') as f: 57 | downloaded_size = 0 58 | for chunk in response.iter_content(chunk_size): 59 | downloaded_size += chunk_size 60 | if pbar is not None: 61 | pbar.update(1) 62 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | if pbar is not None: 66 | pbar.close() 67 | 68 | 69 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 70 | """Load file form http url, will download models if necessary. 71 | 72 | Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 73 | 74 | Args: 75 | url (str): URL to be downloaded. 76 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 77 | Default: None. 78 | progress (bool): Whether to show the download progress. Default: True. 79 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 80 | 81 | Returns: 82 | str: The path to the downloaded file. 83 | """ 84 | if model_dir is None: # use the pytorch hub_dir 85 | hub_dir = get_dir() 86 | model_dir = os.path.join(hub_dir, 'checkpoints') 87 | 88 | os.makedirs(model_dir, exist_ok=True) 89 | 90 | parts = urlparse(url) 91 | filename = os.path.basename(parts.path) 92 | if file_name is not None: 93 | filename = file_name 94 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 95 | if not os.path.exists(cached_file): 96 | print(f'Downloading: "{url}" to {cached_file}\n') 97 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 98 | return cached_file 99 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 
11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 | self.server_list_cfg = server_list_cfg 42 | self.client_cfg = client_cfg 43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 | # mc.pyvector servers as a point which points to a memory cache 45 | self._mc_buffer = mc.pyvector() 46 | 47 | def get(self, filepath): 48 | filepath = str(filepath) 49 | import mc 50 | self._client.Get(filepath, self._mc_buffer) 51 | value_buf = mc.ConvertBuffer(self._mc_buffer) 52 | return value_buf 53 | 54 | def get_text(self, filepath): 55 | raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 | """Raw hard disks storage backend.""" 60 | 61 | def get(self, filepath): 62 | filepath = str(filepath) 63 | with open(filepath, 'rb') as f: 64 | value_buf = f.read() 65 | return value_buf 66 | 67 | def get_text(self, filepath): 68 | filepath = str(filepath) 69 | with open(filepath, 'r') as f: 70 | value_buf = f.read() 71 | return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 | """Lmdb storage backend. 76 | 77 | Args: 78 | db_paths (str | list[str]): Lmdb database paths. 79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 | readonly (bool, optional): Lmdb environment parameter. If True, 81 | disallow any write operations. Default: True. 82 | lock (bool, optional): Lmdb environment parameter. If False, when 83 | concurrent access occurs, do not lock the database. Default: False. 84 | readahead (bool, optional): Lmdb environment parameter. If False, 85 | disable the OS filesystem readahead mechanism, which may improve 86 | random read performance when a database is larger than RAM. 87 | Default: False. 88 | 89 | Attributes: 90 | db_paths (list): Lmdb database path. 91 | _client (list): A list of several lmdb envs. 92 | """ 93 | 94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 | try: 96 | import lmdb 97 | except ImportError: 98 | raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 | if isinstance(client_keys, str): 101 | client_keys = [client_keys] 102 | 103 | if isinstance(db_paths, list): 104 | self.db_paths = [str(v) for v in db_paths] 105 | elif isinstance(db_paths, str): 106 | self.db_paths = [str(db_paths)] 107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 | self._client = {} 111 | for client, path in zip(client_keys, self.db_paths): 112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 | def get(self, filepath, client_key): 115 | """Get values according to the filepath from one lmdb named client_key. 
116 | 117 | Args: 118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 | client_key (str): Used for distinguishing different lmdb envs. 120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backend. 134 | 135 | The client loads a file or text in a specified backend from its path 136 | and return it as a binary file. it can also register other backend 137 | accessor with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 2 | import cv2 3 | import numpy as np 4 | import os 5 | 6 | 7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 8 | """Read an optical flow map. 9 | 10 | Args: 11 | flow_path (ndarray or str): Flow path. 12 | quantize (bool): whether to read quantized pair, if set to True, 13 | remaining args will be passed to :func:`dequantize_flow`. 14 | concat_axis (int): The axis that dx and dy are concatenated, 15 | can be either 0 or 1. Ignored if quantize is False. 
16 | 17 | Returns: 18 | ndarray: Optical flow represented as a (h, w, 2) numpy array 19 | """ 20 | if quantize: 21 | assert concat_axis in [0, 1] 22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) 23 | if cat_flow.ndim != 2: 24 | raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') 25 | assert cat_flow.shape[concat_axis] % 2 == 0 26 | dx, dy = np.split(cat_flow, 2, axis=concat_axis) 27 | flow = dequantize_flow(dx, dy, *args, **kwargs) 28 | else: 29 | with open(flow_path, 'rb') as f: 30 | try: 31 | header = f.read(4).decode('utf-8') 32 | except Exception: 33 | raise IOError(f'Invalid flow file: {flow_path}') 34 | else: 35 | if header != 'PIEH': 36 | raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') 37 | 38 | w = np.fromfile(f, np.int32, 1).squeeze() 39 | h = np.fromfile(f, np.int32, 1).squeeze() 40 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) 41 | 42 | return flow.astype(np.float32) 43 | 44 | 45 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): 46 | """Write optical flow to file. 47 | 48 | If the flow is not quantized, it will be saved as a .flo file losslessly, 49 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy 50 | will be concatenated horizontally into a single image if quantize is True.) 51 | 52 | Args: 53 | flow (ndarray): (h, w, 2) array of optical flow. 54 | filename (str): Output filepath. 55 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg 56 | images. If set to True, remaining args will be passed to 57 | :func:`quantize_flow`. 58 | concat_axis (int): The axis that dx and dy are concatenated, 59 | can be either 0 or 1. Ignored if quantize is False. 60 | """ 61 | if not quantize: 62 | with open(filename, 'wb') as f: 63 | f.write('PIEH'.encode('utf-8')) 64 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) 65 | flow = flow.astype(np.float32) 66 | flow.tofile(f) 67 | f.flush() 68 | else: 69 | assert concat_axis in [0, 1] 70 | dx, dy = quantize_flow(flow, *args, **kwargs) 71 | dxdy = np.concatenate((dx, dy), axis=concat_axis) 72 | os.makedirs(os.path.dirname(filename), exist_ok=True) 73 | cv2.imwrite(filename, dxdy) 74 | 75 | 76 | def quantize_flow(flow, max_val=0.02, norm=True): 77 | """Quantize flow to [0, 255]. 78 | 79 | After this step, the size of flow will be much smaller, and can be 80 | dumped as jpeg images. 81 | 82 | Args: 83 | flow (ndarray): (h, w, 2) array of optical flow. 84 | max_val (float): Maximum value of flow, values beyond 85 | [-max_val, max_val] will be truncated. 86 | norm (bool): Whether to divide flow values by image width/height. 87 | 88 | Returns: 89 | tuple[ndarray]: Quantized dx and dy. 90 | """ 91 | h, w, _ = flow.shape 92 | dx = flow[..., 0] 93 | dy = flow[..., 1] 94 | if norm: 95 | dx = dx / w # avoid inplace operations 96 | dy = dy / h 97 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 98 | flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] 99 | return tuple(flow_comps) 100 | 101 | 102 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 103 | """Recover from quantized flow. 104 | 105 | Args: 106 | dx (ndarray): Quantized dx. 107 | dy (ndarray): Quantized dy. 108 | max_val (float): Maximum value used when quantizing. 109 | denorm (bool): Whether to multiply flow values with width/height. 110 | 111 | Returns: 112 | ndarray: Dequantized flow. 
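
    Example (illustrative round trip with the companion ``quantize_flow`` above;
        ``flow`` is an (h, w, 2) float array):
            dx, dy = quantize_flow(flow, max_val=0.02, norm=True)
            flow_restored = dequantize_flow(dx, dy, max_val=0.02, denorm=True)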
113 | """ 114 | assert dx.shape == dy.shape 115 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 116 | 117 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 118 | 119 | if denorm: 120 | dx *= dx.shape[1] 121 | dy *= dx.shape[0] 122 | flow = np.dstack((dx, dy)) 123 | return flow 124 | 125 | 126 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 127 | """Quantize an array of (-inf, inf) to [0, levels-1]. 128 | 129 | Args: 130 | arr (ndarray): Input array. 131 | min_val (scalar): Minimum value to be clipped. 132 | max_val (scalar): Maximum value to be clipped. 133 | levels (int): Quantization levels. 134 | dtype (np.type): The type of the quantized array. 135 | 136 | Returns: 137 | tuple: Quantized array. 138 | """ 139 | if not (isinstance(levels, int) and levels > 1): 140 | raise ValueError(f'levels must be a positive integer, but got {levels}') 141 | if min_val >= max_val: 142 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 143 | 144 | arr = np.clip(arr, min_val, max_val) - min_val 145 | quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 146 | 147 | return quantized_arr 148 | 149 | 150 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 151 | """Dequantize an array. 152 | 153 | Args: 154 | arr (ndarray): Input array. 155 | min_val (scalar): Minimum value to be clipped. 156 | max_val (scalar): Maximum value to be clipped. 157 | levels (int): Quantization levels. 158 | dtype (np.type): The type of the dequantized array. 159 | 160 | Returns: 161 | tuple: Dequantized array. 162 | """ 163 | if not (isinstance(levels, int) and levels > 1): 164 | raise ValueError(f'levels must be a positive integer, but got {levels}') 165 | if min_val >= max_val: 166 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 167 | 168 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val 169 | 170 | return dequantized_arr 171 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/img_process_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def filter2D(img, kernel): 8 | """PyTorch version of cv2.filter2D 9 | 10 | Args: 11 | img (Tensor): (b, c, h, w) 12 | kernel (Tensor): (b, k, k) 13 | """ 14 | k = kernel.size(-1) 15 | b, c, h, w = img.size() 16 | if k % 2 == 1: 17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') 18 | else: 19 | raise ValueError('Wrong kernel size') 20 | 21 | ph, pw = img.size()[-2:] 22 | 23 | if kernel.size(0) == 1: 24 | # apply the same kernel to all batch images 25 | img = img.view(b * c, 1, ph, pw) 26 | kernel = kernel.view(1, 1, k, k) 27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) 28 | else: 29 | img = img.view(1, b * c, ph, pw) 30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) 31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) 32 | 33 | 34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10): 35 | """USM sharpening. 36 | 37 | Input image: I; Blurry image: B. 38 | 1. sharp = I + weight * (I - B) 39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 40 | 3. Blur mask: 41 | 4. Out = Mask * sharp + (1 - Mask) * I 42 | 43 | 44 | Args: 45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 
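filter2D above is a batched PyTorch counterpart of cv2.filter2D and expects an odd kernel size; a minimal sketch with a normalized box blur:

import torch
from pasd.dataloader.basicsr.utils.img_process_util import filter2D

img = torch.rand(2, 3, 32, 32)        # (b, c, h, w)
kernel = torch.ones(1, 5, 5) / 25.0   # (b, k, k); b=1 applies the same kernel to the whole batch
blurred = filter2D(img, kernel)
print(blurred.shape)                  # torch.Size([2, 3, 32, 32]); reflect padding preserves the spatial size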
46 | weight (float): Sharp weight. Default: 1. 47 | radius (float): Kernel size of Gaussian blur. Default: 50. 48 | threshold (int): 49 | """ 50 | if radius % 2 == 0: 51 | radius += 1 52 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 53 | residual = img - blur 54 | mask = np.abs(residual) * 255 > threshold 55 | mask = mask.astype('float32') 56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 57 | 58 | sharp = img + weight * residual 59 | sharp = np.clip(sharp, 0, 1) 60 | return soft_mask * sharp + (1 - soft_mask) * img 61 | 62 | 63 | class USMSharp(torch.nn.Module): 64 | 65 | def __init__(self, radius=50, sigma=0): 66 | super(USMSharp, self).__init__() 67 | if radius % 2 == 0: 68 | radius += 1 69 | self.radius = radius 70 | kernel = cv2.getGaussianKernel(radius, sigma) 71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) 72 | self.register_buffer('kernel', kernel) 73 | 74 | def forward(self, img, weight=0.5, threshold=10): 75 | blur = filter2D(img, self.kernel) 76 | residual = img - blur 77 | 78 | mask = torch.abs(residual) * 255 > threshold 79 | mask = mask.float() 80 | soft_mask = filter2D(mask, self.kernel) 81 | sharp = img + weight * residual 82 | sharp = torch.clip(sharp, 0, 1) 83 | return soft_mask * sharp + (1 - soft_mask) * img 84 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | if img.dtype == 'float64': 25 | img = img.astype('float32') 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = torch.from_numpy(img.transpose(2, 0, 1)) 28 | if float32: 29 | img = img.float() 30 | return img 31 | 32 | if isinstance(imgs, list): 33 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 | else: 35 | return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 | """Convert torch Tensors into image numpy arrays. 40 | 41 | After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 | Args: 44 | tensor (Tensor or list[Tensor]): Accept shapes: 45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 | 2) 3D Tensor of shape (3/1 x H x W); 47 | 3) 2D Tensor of shape (H x W). 48 | Tensor channel should be in RGB order. 49 | rgb2bgr (bool): Whether to change rgb to bgr. 50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 | to uint8 type with range [0, 255]; otherwise, float type with 52 | range [0, 1]. Default: ``np.uint8``. 53 | min_max (tuple[int]): min and max values for clamp. 54 | 55 | Returns: 56 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 | shape (H x W). The channel order is BGR. 
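The module version of USM sharpening above can be dropped into a degradation pipeline like any nn.Module; a short sketch on random data, just to show shapes and value range:

import torch
from pasd.dataloader.basicsr.utils.img_process_util import USMSharp

usm = USMSharp(radius=50)            # an even radius is bumped to 51; the Gaussian kernel is registered as a buffer
img = torch.rand(1, 3, 128, 128)     # (b, c, h, w), values in [0, 1]
sharpened = usm(img, weight=0.5, threshold=10)
print(sharpened.shape)               # torch.Size([1, 3, 128, 128]), still in [0, 1]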
58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32., If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not. 
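img2tensor and tensor2img above are near inverses for float BGR input in [0, 1]; a minimal round trip on random data:

import numpy as np
from pasd.dataloader.basicsr.utils.img_util import img2tensor, tensor2img

img = np.random.rand(64, 64, 3).astype(np.float32)   # HWC, BGR, [0, 1]
t = img2tensor(img, bgr2rgb=True, float32=True)      # -> (3, 64, 64) RGB float tensor
back = tensor2img(t, rgb2bgr=True, min_max=(0, 1))   # -> (64, 64, 3) BGR uint8 in [0, 255]
print(t.shape, back.dtype)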
147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | ok = cv2.imwrite(file_path, img, params) 152 | if not ok: 153 | raise IOError('Failed in writing images.') 154 | 155 | 156 | def crop_border(imgs, crop_border): 157 | """Crop borders of images. 158 | 159 | Args: 160 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 161 | crop_border (int): Crop border for each end of height and weight. 162 | 163 | Returns: 164 | list[ndarray]: Cropped images. 165 | """ 166 | if crop_border == 0: 167 | return imgs 168 | else: 169 | if isinstance(imgs, list): 170 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 171 | else: 172 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 173 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/lmdb_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import lmdb 3 | import sys 4 | from multiprocessing import Pool 5 | from os import path as osp 6 | from tqdm import tqdm 7 | 8 | 9 | def make_lmdb_from_imgs(data_path, 10 | lmdb_path, 11 | img_path_list, 12 | keys, 13 | batch=5000, 14 | compress_level=1, 15 | multiprocessing_read=False, 16 | n_thread=40, 17 | map_size=None): 18 | """Make lmdb from images. 19 | 20 | Contents of lmdb. The file structure is: 21 | 22 | :: 23 | 24 | example.lmdb 25 | ├── data.mdb 26 | ├── lock.mdb 27 | ├── meta_info.txt 28 | 29 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 30 | https://lmdb.readthedocs.io/en/release/ for more details. 31 | 32 | The meta_info.txt is a specified txt file to record the meta information 33 | of our datasets. It will be automatically created when preparing 34 | datasets by our provided dataset tools. 35 | Each line in the txt file records 1)image name (with extension), 36 | 2)image shape, and 3)compression level, separated by a white space. 37 | 38 | For example, the meta information could be: 39 | `000_00000000.png (720,1280,3) 1`, which means: 40 | 1) image name (with extension): 000_00000000.png; 41 | 2) image shape: (720,1280,3); 42 | 3) compression level: 1 43 | 44 | We use the image name without extension as the lmdb key. 45 | 46 | If `multiprocessing_read` is True, it will read all the images to memory 47 | using multiprocessing. Thus, your server needs to have enough memory. 48 | 49 | Args: 50 | data_path (str): Data path for reading images. 51 | lmdb_path (str): Lmdb save path. 52 | img_path_list (str): Image path list. 53 | keys (str): Used for lmdb keys. 54 | batch (int): After processing batch images, lmdb commits. 55 | Default: 5000. 56 | compress_level (int): Compress level when encoding images. Default: 1. 57 | multiprocessing_read (bool): Whether use multiprocessing to read all 58 | the images to memory. Default: False. 59 | n_thread (int): For multiprocessing. 60 | map_size (int | None): Map size for lmdb env. If None, use the 61 | estimated size from images. 
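crop_border and imwrite above are typically combined when saving evaluation results; a small sketch (the output path is a placeholder, and imwrite creates the parent folder automatically):

import numpy as np
from pasd.dataloader.basicsr.utils.img_util import crop_border, imwrite

img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
cropped = crop_border(img, crop_border=4)      # trims 4 pixels from every border -> (56, 56, 3)
imwrite(cropped, '/tmp/results/cropped.png')   # auto_mkdir=True creates /tmp/results if missing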
Default: None 62 | """ 63 | 64 | assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' 65 | f'but got {len(img_path_list)} and {len(keys)}') 66 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...') 67 | print(f'Totoal images: {len(img_path_list)}') 68 | if not lmdb_path.endswith('.lmdb'): 69 | raise ValueError("lmdb_path must end with '.lmdb'.") 70 | if osp.exists(lmdb_path): 71 | print(f'Folder {lmdb_path} already exists. Exit.') 72 | sys.exit(1) 73 | 74 | if multiprocessing_read: 75 | # read all the images to memory (multiprocessing) 76 | dataset = {} # use dict to keep the order for multiprocessing 77 | shapes = {} 78 | print(f'Read images with multiprocessing, #thread: {n_thread} ...') 79 | pbar = tqdm(total=len(img_path_list), unit='image') 80 | 81 | def callback(arg): 82 | """get the image data and update pbar.""" 83 | key, dataset[key], shapes[key] = arg 84 | pbar.update(1) 85 | pbar.set_description(f'Read {key}') 86 | 87 | pool = Pool(n_thread) 88 | for path, key in zip(img_path_list, keys): 89 | pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) 90 | pool.close() 91 | pool.join() 92 | pbar.close() 93 | print(f'Finish reading {len(img_path_list)} images.') 94 | 95 | # create lmdb environment 96 | if map_size is None: 97 | # obtain data size for one image 98 | img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) 99 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 100 | data_size_per_img = img_byte.nbytes 101 | print('Data size per image is: ', data_size_per_img) 102 | data_size = data_size_per_img * len(img_path_list) 103 | map_size = data_size * 10 104 | 105 | env = lmdb.open(lmdb_path, map_size=map_size) 106 | 107 | # write data to lmdb 108 | pbar = tqdm(total=len(img_path_list), unit='chunk') 109 | txn = env.begin(write=True) 110 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 111 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 112 | pbar.update(1) 113 | pbar.set_description(f'Write {key}') 114 | key_byte = key.encode('ascii') 115 | if multiprocessing_read: 116 | img_byte = dataset[key] 117 | h, w, c = shapes[key] 118 | else: 119 | _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) 120 | h, w, c = img_shape 121 | 122 | txn.put(key_byte, img_byte) 123 | # write meta information 124 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') 125 | if idx % batch == 0: 126 | txn.commit() 127 | txn = env.begin(write=True) 128 | pbar.close() 129 | txn.commit() 130 | env.close() 131 | txt_file.close() 132 | print('\nFinish writing lmdb.') 133 | 134 | 135 | def read_img_worker(path, key, compress_level): 136 | """Read image worker. 137 | 138 | Args: 139 | path (str): Image path. 140 | key (str): Image key. 141 | compress_level (int): Compress level when encoding images. 142 | 143 | Returns: 144 | str: Image key. 145 | byte: Image byte. 146 | tuple[int]: Image shape. 147 | """ 148 | 149 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 150 | if img.ndim == 2: 151 | h, w = img.shape 152 | c = 1 153 | else: 154 | h, w, c = img.shape 155 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 156 | return (key, img_byte, (h, w, c)) 157 | 158 | 159 | class LmdbMaker(): 160 | """LMDB Maker. 161 | 162 | Args: 163 | lmdb_path (str): Lmdb save path. 164 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. 
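A hedged sketch of calling make_lmdb_from_imgs above on a folder of PNGs; the dataset folder is a placeholder, and the keys follow the "image name without extension" convention from the docstring:

import os
from pasd.dataloader.basicsr.utils.lmdb_util import make_lmdb_from_imgs

data_path = 'datasets/DIV2K_train_HR'                         # placeholder folder containing only .png files
img_path_list = sorted(os.listdir(data_path))
keys = [os.path.splitext(name)[0] for name in img_path_list]  # lmdb keys without extension
make_lmdb_from_imgs(data_path, 'datasets/DIV2K_train_HR.lmdb', img_path_list, keys,
                    batch=5000, compress_level=1, multiprocessing_read=False)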
165 | batch (int): After processing batch images, lmdb commits. 166 | Default: 5000. 167 | compress_level (int): Compress level when encoding images. Default: 1. 168 | """ 169 | 170 | def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): 171 | if not lmdb_path.endswith('.lmdb'): 172 | raise ValueError("lmdb_path must end with '.lmdb'.") 173 | if osp.exists(lmdb_path): 174 | print(f'Folder {lmdb_path} already exists. Exit.') 175 | sys.exit(1) 176 | 177 | self.lmdb_path = lmdb_path 178 | self.batch = batch 179 | self.compress_level = compress_level 180 | self.env = lmdb.open(lmdb_path, map_size=map_size) 181 | self.txn = self.env.begin(write=True) 182 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 183 | self.counter = 0 184 | 185 | def put(self, img_byte, key, img_shape): 186 | self.counter += 1 187 | key_byte = key.encode('ascii') 188 | self.txn.put(key_byte, img_byte) 189 | # write meta information 190 | h, w, c = img_shape 191 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') 192 | if self.counter % self.batch == 0: 193 | self.txn.commit() 194 | self.txn = self.env.begin(write=True) 195 | 196 | def close(self): 197 | self.txn.commit() 198 | self.env.close() 199 | self.txt_file.close() 200 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class AvgTimer(): 11 | 12 | def __init__(self, window=200): 13 | self.window = window # average window 14 | self.current_time = 0 15 | self.total_time = 0 16 | self.count = 0 17 | self.avg_time = 0 18 | self.start() 19 | 20 | def start(self): 21 | self.start_time = self.tic = time.time() 22 | 23 | def record(self): 24 | self.count += 1 25 | self.toc = time.time() 26 | self.current_time = self.toc - self.tic 27 | self.total_time += self.current_time 28 | # calculate average time 29 | self.avg_time = self.total_time / self.count 30 | 31 | # reset 32 | if self.count > self.window: 33 | self.count = 0 34 | self.total_time = 0 35 | 36 | self.tic = time.time() 37 | 38 | def get_current_time(self): 39 | return self.current_time 40 | 41 | def get_avg_time(self): 42 | return self.avg_time 43 | 44 | 45 | class MessageLogger(): 46 | """Message logger for printing. 47 | 48 | Args: 49 | opt (dict): Config. It contains the following keys: 50 | name (str): Exp name. 51 | logger (dict): Contains 'print_freq' (str) for logger interval. 52 | train (dict): Contains 'total_iter' (int) for total iters. 53 | use_tb_logger (bool): Use tensorboard logger. 54 | start_iter (int): Start iter. Default: 1. 55 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 56 | """ 57 | 58 | def __init__(self, opt, start_iter=1, tb_logger=None): 59 | self.exp_name = opt['name'] 60 | self.interval = opt['logger']['print_freq'] 61 | self.start_iter = start_iter 62 | self.max_iters = opt['train']['total_iter'] 63 | self.use_tb_logger = opt['logger']['use_tb_logger'] 64 | self.tb_logger = tb_logger 65 | self.start_time = time.time() 66 | self.logger = get_root_logger() 67 | 68 | def reset_start_time(self): 69 | self.start_time = time.time() 70 | 71 | @master_only 72 | def __call__(self, log_vars): 73 | """Format logging message. 
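For incremental writing, LmdbMaker above handles the same commit and meta_info bookkeeping one image at a time; a minimal sketch with placeholder paths:

import cv2
import glob
import os.path as osp
from pasd.dataloader.basicsr.utils.lmdb_util import LmdbMaker

maker = LmdbMaker('datasets/custom.lmdb', compress_level=1)   # must not exist yet and must end with .lmdb
for path in sorted(glob.glob('datasets/custom/*.png')):
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, 1])
    h, w, c = img.shape if img.ndim == 3 else (*img.shape, 1)
    maker.put(img_byte, osp.splitext(osp.basename(path))[0], (h, w, c))
maker.close()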
74 | 75 | Args: 76 | log_vars (dict): It contains the following keys: 77 | epoch (int): Epoch number. 78 | iter (int): Current iter. 79 | lrs (list): List for learning rates. 80 | 81 | time (float): Iter time. 82 | data_time (float): Data time for each iter. 83 | """ 84 | # epoch, iter, learning rates 85 | epoch = log_vars.pop('epoch') 86 | current_iter = log_vars.pop('iter') 87 | lrs = log_vars.pop('lrs') 88 | 89 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') 90 | for v in lrs: 91 | message += f'{v:.3e},' 92 | message += ')] ' 93 | 94 | # time and estimated time 95 | if 'time' in log_vars.keys(): 96 | iter_time = log_vars.pop('time') 97 | data_time = log_vars.pop('data_time') 98 | 99 | total_time = time.time() - self.start_time 100 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 101 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 102 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 103 | message += f'[eta: {eta_str}, ' 104 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 105 | 106 | # other items, especially losses 107 | for k, v in log_vars.items(): 108 | message += f'{k}: {v:.4e} ' 109 | # tensorboard logger 110 | if self.use_tb_logger and 'debug' not in self.exp_name: 111 | if k.startswith('l_'): 112 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 113 | else: 114 | self.tb_logger.add_scalar(k, v, current_iter) 115 | self.logger.info(message) 116 | 117 | 118 | @master_only 119 | def init_tb_logger(log_dir): 120 | from torch.utils.tensorboard import SummaryWriter 121 | tb_logger = SummaryWriter(log_dir=log_dir) 122 | return tb_logger 123 | 124 | 125 | @master_only 126 | def init_wandb_logger(opt): 127 | """We now only use wandb to sync tensorboard log.""" 128 | import wandb 129 | logger = get_root_logger() 130 | 131 | project = opt['logger']['wandb']['project'] 132 | resume_id = opt['logger']['wandb'].get('resume_id') 133 | if resume_id: 134 | wandb_id = resume_id 135 | resume = 'allow' 136 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 137 | else: 138 | wandb_id = wandb.util.generate_id() 139 | resume = 'never' 140 | 141 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) 142 | 143 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 144 | 145 | 146 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 147 | """Get the root logger. 148 | 149 | The logger will be initialized if it has not been initialized. By default a 150 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 151 | also be added. 152 | 153 | Args: 154 | logger_name (str): root logger name. Default: 'basicsr'. 155 | log_file (str | None): The log filename. If specified, a FileHandler 156 | will be added to the root logger. 157 | log_level (int): The root logger level. Note that only the process of 158 | rank 0 is affected, while other processes will set the level to 159 | "Error" and be silent most of the time. 160 | 161 | Returns: 162 | logging.Logger: The root logger. 
163 | """ 164 | logger = logging.getLogger(logger_name) 165 | # if the logger has been initialized, just return it 166 | if logger_name in initialized_logger: 167 | return logger 168 | 169 | format_str = '%(asctime)s %(levelname)s: %(message)s' 170 | stream_handler = logging.StreamHandler() 171 | stream_handler.setFormatter(logging.Formatter(format_str)) 172 | logger.addHandler(stream_handler) 173 | logger.propagate = False 174 | rank, _ = get_dist_info() 175 | if rank != 0: 176 | logger.setLevel('ERROR') 177 | elif log_file is not None: 178 | logger.setLevel(log_level) 179 | # add file handler 180 | file_handler = logging.FileHandler(log_file, 'w') 181 | file_handler.setFormatter(logging.Formatter(format_str)) 182 | file_handler.setLevel(log_level) 183 | logger.addHandler(file_handler) 184 | initialized_logger[logger_name] = True 185 | return logger 186 | 187 | 188 | def get_env_info(): 189 | """Get environment information. 190 | 191 | Currently, only log the software version. 192 | """ 193 | import torch 194 | import torchvision 195 | 196 | from basicsr.version import __version__ 197 | msg = r""" 198 | ____ _ _____ ____ 199 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 200 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 201 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 202 | /_____/ \__,_//____//_/ \___//____//_/ |_| 203 | ______ __ __ __ __ 204 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 205 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 206 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 207 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 208 | """ 209 | msg += ('\nVersion Information: ' 210 | f'\n\tBasicSR: {__version__}' 211 | f'\n\tPyTorch: {torch.__version__}' 212 | f'\n\tTorchVision: {torchvision.__version__}') 213 | return msg 214 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/matlab_functions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def cubic(x): 7 | """cubic function used for calculate_weights_indices.""" 8 | absx = torch.abs(x) 9 | absx2 = absx**2 10 | absx3 = absx**3 11 | return (1.5 * absx3 - 2.5 * absx2 + 1) * ( 12 | (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * 13 | (absx <= 2)).type_as(absx)) 14 | 15 | 16 | def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): 17 | """Calculate weights and indices, used for imresize function. 18 | 19 | Args: 20 | in_length (int): Input length. 21 | out_length (int): Output length. 22 | scale (float): Scale factor. 23 | kernel_width (int): Kernel width. 24 | antialisaing (bool): Whether to apply anti-aliasing when downsampling. 25 | """ 26 | 27 | if (scale < 1) and antialiasing: 28 | # Use a modified kernel (larger kernel width) to simultaneously 29 | # interpolate and antialias 30 | kernel_width = kernel_width / scale 31 | 32 | # Output-space coordinates 33 | x = torch.linspace(1, out_length, out_length) 34 | 35 | # Input-space coordinates. Calculate the inverse mapping such that 0.5 36 | # in output space maps to 0.5 in input space, and 0.5 + scale in output 37 | # space maps to 1.5 in input space. 38 | u = x / scale + 0.5 * (1 - 1 / scale) 39 | 40 | # What is the left-most pixel that can be involved in the computation? 
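# (In other words: the cubic kernel centred at u has support of width kernel_width,
# so the left-most input pixel that can receive a non-zero weight lies roughly at
# u - kernel_width / 2; taking floor gives its integer index.)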
41 | left = torch.floor(u - kernel_width / 2) 42 | 43 | # What is the maximum number of pixels that can be involved in the 44 | # computation? Note: it's OK to use an extra pixel here; if the 45 | # corresponding weights are all zero, it will be eliminated at the end 46 | # of this function. 47 | p = math.ceil(kernel_width) + 2 48 | 49 | # The indices of the input pixels involved in computing the k-th output 50 | # pixel are in row k of the indices matrix. 51 | indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( 52 | out_length, p) 53 | 54 | # The weights used to compute the k-th output pixel are in row k of the 55 | # weights matrix. 56 | distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices 57 | 58 | # apply cubic kernel 59 | if (scale < 1) and antialiasing: 60 | weights = scale * cubic(distance_to_center * scale) 61 | else: 62 | weights = cubic(distance_to_center) 63 | 64 | # Normalize the weights matrix so that each row sums to 1. 65 | weights_sum = torch.sum(weights, 1).view(out_length, 1) 66 | weights = weights / weights_sum.expand(out_length, p) 67 | 68 | # If a column in weights is all zero, get rid of it. only consider the 69 | # first and last column. 70 | weights_zero_tmp = torch.sum((weights == 0), 0) 71 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): 72 | indices = indices.narrow(1, 1, p - 2) 73 | weights = weights.narrow(1, 1, p - 2) 74 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): 75 | indices = indices.narrow(1, 0, p - 2) 76 | weights = weights.narrow(1, 0, p - 2) 77 | weights = weights.contiguous() 78 | indices = indices.contiguous() 79 | sym_len_s = -indices.min() + 1 80 | sym_len_e = indices.max() - in_length 81 | indices = indices + sym_len_s - 1 82 | return weights, indices, int(sym_len_s), int(sym_len_e) 83 | 84 | 85 | @torch.no_grad() 86 | def imresize(img, scale, antialiasing=True): 87 | """imresize function same as MATLAB. 88 | 89 | It now only supports bicubic. 90 | The same scale applies for both height and width. 91 | 92 | Args: 93 | img (Tensor | Numpy array): 94 | Tensor: Input image with shape (c, h, w), [0, 1] range. 95 | Numpy: Input image with shape (h, w, c), [0, 1] range. 96 | scale (float): Scale factor. The same scale applies for both height 97 | and width. 98 | antialisaing (bool): Whether to apply anti-aliasing when downsampling. 99 | Default: True. 100 | 101 | Returns: 102 | Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. 
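imresize above follows MATLAB's bicubic convention (antialiasing when downscaling); a quick sketch on a numpy image:

import numpy as np
from pasd.dataloader.basicsr.utils.matlab_functions import imresize

img = np.random.rand(32, 48, 3).astype(np.float32)   # (h, w, c) in [0, 1]
small = imresize(img, scale=0.5)                      # -> (16, 24, 3), numpy in, numpy out
big = imresize(img, scale=2)                          # -> (64, 96, 3)
print(small.shape, big.shape)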
103 | """ 104 | squeeze_flag = False 105 | if type(img).__module__ == np.__name__: # numpy type 106 | numpy_type = True 107 | if img.ndim == 2: 108 | img = img[:, :, None] 109 | squeeze_flag = True 110 | img = torch.from_numpy(img.transpose(2, 0, 1)).float() 111 | else: 112 | numpy_type = False 113 | if img.ndim == 2: 114 | img = img.unsqueeze(0) 115 | squeeze_flag = True 116 | 117 | in_c, in_h, in_w = img.size() 118 | out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) 119 | kernel_width = 4 120 | kernel = 'cubic' 121 | 122 | # get weights and indices 123 | weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, 124 | antialiasing) 125 | weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, 126 | antialiasing) 127 | # process H dimension 128 | # symmetric copying 129 | img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) 130 | img_aug.narrow(1, sym_len_hs, in_h).copy_(img) 131 | 132 | sym_patch = img[:, :sym_len_hs, :] 133 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 134 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 135 | img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) 136 | 137 | sym_patch = img[:, -sym_len_he:, :] 138 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 139 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 140 | img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) 141 | 142 | out_1 = torch.FloatTensor(in_c, out_h, in_w) 143 | kernel_width = weights_h.size(1) 144 | for i in range(out_h): 145 | idx = int(indices_h[i][0]) 146 | for j in range(in_c): 147 | out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) 148 | 149 | # process W dimension 150 | # symmetric copying 151 | out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) 152 | out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) 153 | 154 | sym_patch = out_1[:, :, :sym_len_ws] 155 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 156 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 157 | out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) 158 | 159 | sym_patch = out_1[:, :, -sym_len_we:] 160 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 161 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 162 | out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) 163 | 164 | out_2 = torch.FloatTensor(in_c, out_h, out_w) 165 | kernel_width = weights_w.size(1) 166 | for i in range(out_w): 167 | idx = int(indices_w[i][0]) 168 | for j in range(in_c): 169 | out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) 170 | 171 | if squeeze_flag: 172 | out_2 = out_2.squeeze(0) 173 | if numpy_type: 174 | out_2 = out_2.numpy() 175 | if not squeeze_flag: 176 | out_2 = out_2.transpose(1, 2, 0) 177 | 178 | return out_2 179 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | 10 | 11 | def set_random_seed(seed): 12 | """Set random seeds.""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | 
def get_time_str(): 21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 22 | 23 | 24 | def mkdir_and_rename(path): 25 | """mkdirs. If path exists, rename it with timestamp and create a new one. 26 | 27 | Args: 28 | path (str): Folder path. 29 | """ 30 | if osp.exists(path): 31 | new_name = path + '_archived_' + get_time_str() 32 | print(f'Path already exists. Rename it to {new_name}', flush=True) 33 | os.rename(path, new_name) 34 | os.makedirs(path, exist_ok=True) 35 | 36 | 37 | @master_only 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): 47 | continue 48 | else: 49 | os.makedirs(path, exist_ok=True) 50 | 51 | 52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 53 | """Scan a directory to find the interested files. 54 | 55 | Args: 56 | dir_path (str): Path of the directory. 57 | suffix (str | tuple(str), optional): File suffix that we are 58 | interested in. Default: None. 59 | recursive (bool, optional): If set to True, recursively scan the 60 | directory. Default: False. 61 | full_path (bool, optional): If set to True, include the dir_path. 62 | Default: False. 63 | 64 | Returns: 65 | A generator for all the interested files with relative paths. 66 | """ 67 | 68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 69 | raise TypeError('"suffix" must be a string or tuple of strings') 70 | 71 | root = dir_path 72 | 73 | def _scandir(dir_path, suffix, recursive): 74 | for entry in os.scandir(dir_path): 75 | if not entry.name.startswith('.') and entry.is_file(): 76 | if full_path: 77 | return_path = entry.path 78 | else: 79 | return_path = osp.relpath(entry.path, root) 80 | 81 | if suffix is None: 82 | yield return_path 83 | elif return_path.endswith(suffix): 84 | yield return_path 85 | else: 86 | if recursive: 87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 88 | else: 89 | continue 90 | 91 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 92 | 93 | 94 | def check_resume(opt, resume_iter): 95 | """Check resume states and pretrain_network paths. 96 | 97 | Args: 98 | opt (dict): Options. 99 | resume_iter (int): Resume iteration. 
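scandir above returns a generator, so wrap it in sorted() or list() when a concrete file list is needed; for example, on the Set5 folder shipped under examples/:

from pasd.dataloader.basicsr.utils.misc import scandir

paths = sorted(scandir('examples/Set5', suffix='.png', recursive=False, full_path=False))
print(paths)   # ['baby.png', 'bird.png', 'butterfly.png', 'head.png', 'woman.png']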
100 | """ 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | print('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') 114 | if opt['path'].get('ignore_resume_networks') is None or (network 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | print(f"Set {name} to {opt['path'][name]}") 118 | 119 | # change param_key to params in resume 120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] 121 | for param_key in param_keys: 122 | if opt['path'][param_key] == 'params_ema': 123 | opt['path'][param_key] = 'params' 124 | print(f'Set {param_key} to params') 125 | 126 | 127 | def sizeof_fmt(size, suffix='B'): 128 | """Get human readable file size. 129 | 130 | Args: 131 | size (int): File size. 132 | suffix (str): Suffix. Default: 'B'. 133 | 134 | Return: 135 | str: Formatted file size. 136 | """ 137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 138 | if abs(size) < 1024.0: 139 | return f'{size:3.1f} {unit}{suffix}' 140 | size /= 1024.0 141 | return f'{size:3.1f} Y{suffix}' 142 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import torch 5 | import yaml 6 | from collections import OrderedDict 7 | from os import path as osp 8 | 9 | from .misc import set_random_seed 10 | from .dist_util import get_dist_info, init_dist, master_only 11 | 12 | 13 | def ordered_yaml(): 14 | """Support OrderedDict for yaml. 15 | 16 | Returns: 17 | tuple: yaml Loader and Dumper. 18 | """ 19 | try: 20 | from yaml import CDumper as Dumper 21 | from yaml import CLoader as Loader 22 | except ImportError: 23 | from yaml import Dumper, Loader 24 | 25 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 26 | 27 | def dict_representer(dumper, data): 28 | return dumper.represent_dict(data.items()) 29 | 30 | def dict_constructor(loader, node): 31 | return OrderedDict(loader.construct_pairs(node)) 32 | 33 | Dumper.add_representer(OrderedDict, dict_representer) 34 | Loader.add_constructor(_mapping_tag, dict_constructor) 35 | return Loader, Dumper 36 | 37 | 38 | def yaml_load(f): 39 | """Load yaml file or string. 40 | 41 | Args: 42 | f (str): File path or a python string. 43 | 44 | Returns: 45 | dict: Loaded dict. 46 | """ 47 | if os.path.isfile(f): 48 | with open(f, 'r') as f: 49 | return yaml.load(f, Loader=ordered_yaml()[0]) 50 | else: 51 | return yaml.load(f, Loader=ordered_yaml()[0]) 52 | 53 | 54 | def dict2str(opt, indent_level=1): 55 | """dict to string for printing options. 56 | 57 | Args: 58 | opt (dict): Option dict. 59 | indent_level (int): Indent level. Default: 1. 60 | 61 | Return: 62 | (str): Option string for printing. 
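yaml_load above accepts either a file path or a YAML string and keeps key order via ordered_yaml; a small sketch:

from pasd.dataloader.basicsr.utils.options import yaml_load

opt = yaml_load('name: pasd\nscale: 4\ndatasets:\n  train:\n    batch_size: 8\n')
print(type(opt).__name__)                       # OrderedDict
print(opt['datasets']['train']['batch_size'])   # 8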
63 | """ 64 | msg = '\n' 65 | for k, v in opt.items(): 66 | if isinstance(v, dict): 67 | msg += ' ' * (indent_level * 2) + k + ':[' 68 | msg += dict2str(v, indent_level + 1) 69 | msg += ' ' * (indent_level * 2) + ']\n' 70 | else: 71 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 72 | return msg 73 | 74 | 75 | def _postprocess_yml_value(value): 76 | # None 77 | if value == '~' or value.lower() == 'none': 78 | return None 79 | # bool 80 | if value.lower() == 'true': 81 | return True 82 | elif value.lower() == 'false': 83 | return False 84 | # !!float number 85 | if value.startswith('!!float'): 86 | return float(value.replace('!!float', '')) 87 | # number 88 | if value.isdigit(): 89 | return int(value) 90 | elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: 91 | return float(value) 92 | # list 93 | if value.startswith('['): 94 | return eval(value) 95 | # str 96 | return value 97 | 98 | 99 | def parse_options(root_path, is_train=True): 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') 102 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') 103 | parser.add_argument('--auto_resume', action='store_true') 104 | parser.add_argument('--debug', action='store_true') 105 | parser.add_argument('--local_rank', type=int, default=0) 106 | parser.add_argument( 107 | '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999') 108 | args = parser.parse_args() 109 | 110 | # parse yml to dict 111 | opt = yaml_load(args.opt) 112 | 113 | # distributed settings 114 | if args.launcher == 'none': 115 | opt['dist'] = False 116 | print('Disable distributed.', flush=True) 117 | else: 118 | opt['dist'] = True 119 | if args.launcher == 'slurm' and 'dist_params' in opt: 120 | init_dist(args.launcher, **opt['dist_params']) 121 | else: 122 | init_dist(args.launcher) 123 | opt['rank'], opt['world_size'] = get_dist_info() 124 | 125 | # random seed 126 | seed = opt.get('manual_seed') 127 | if seed is None: 128 | seed = random.randint(1, 10000) 129 | opt['manual_seed'] = seed 130 | set_random_seed(seed + opt['rank']) 131 | 132 | # force to update yml options 133 | if args.force_yml is not None: 134 | for entry in args.force_yml: 135 | # now do not support creating new keys 136 | keys, value = entry.split('=') 137 | keys, value = keys.strip(), value.strip() 138 | value = _postprocess_yml_value(value) 139 | eval_str = 'opt' 140 | for key in keys.split(':'): 141 | eval_str += f'["{key}"]' 142 | eval_str += '=value' 143 | # using exec function 144 | exec(eval_str) 145 | 146 | opt['auto_resume'] = args.auto_resume 147 | opt['is_train'] = is_train 148 | 149 | # debug setting 150 | if args.debug and not opt['name'].startswith('debug'): 151 | opt['name'] = 'debug_' + opt['name'] 152 | 153 | if opt['num_gpu'] == 'auto': 154 | opt['num_gpu'] = torch.cuda.device_count() 155 | 156 | # datasets 157 | for phase, dataset in opt['datasets'].items(): 158 | # for multiple datasets, e.g., val_1, val_2; test_1, test_2 159 | phase = phase.split('_')[0] 160 | dataset['phase'] = phase 161 | if 'scale' in opt: 162 | dataset['scale'] = opt['scale'] 163 | if dataset.get('dataroot_gt') is not None: 164 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 165 | if dataset.get('dataroot_lq') is not None: 166 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 167 | 168 | # paths 169 | for key, val in 
opt['path'].items(): 170 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): 171 | opt['path'][key] = osp.expanduser(val) 172 | 173 | if is_train: 174 | experiments_root = opt['path'].get('experiments_root') 175 | if experiments_root is None: 176 | experiments_root = osp.join(root_path, 'experiments') 177 | experiments_root = osp.join(experiments_root, opt['name']) 178 | 179 | opt['path']['experiments_root'] = experiments_root 180 | opt['path']['models'] = osp.join(experiments_root, 'models') 181 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states') 182 | opt['path']['log'] = experiments_root 183 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization') 184 | 185 | # change some options for debug mode 186 | if 'debug' in opt['name']: 187 | if 'val' in opt: 188 | opt['val']['val_freq'] = 8 189 | opt['logger']['print_freq'] = 1 190 | opt['logger']['save_checkpoint_freq'] = 8 191 | else: # test 192 | results_root = opt['path'].get('results_root') 193 | if results_root is None: 194 | results_root = osp.join(root_path, 'results') 195 | results_root = osp.join(results_root, opt['name']) 196 | 197 | opt['path']['results_root'] = results_root 198 | opt['path']['log'] = results_root 199 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 200 | 201 | return opt, args 202 | 203 | 204 | @master_only 205 | def copy_opt_file(opt_file, experiments_root): 206 | # copy the yml file to the experiment root 207 | import sys 208 | import time 209 | from shutil import copyfile 210 | cmd = ' '.join(sys.argv) 211 | filename = osp.join(experiments_root, osp.basename(opt_file)) 212 | copyfile(opt_file, filename) 213 | 214 | with open(filename, 'r+') as f: 215 | lines = f.readlines() 216 | lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') 217 | f.seek(0) 218 | f.writelines(lines) 219 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/plot_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def read_data_from_tensorboard(log_path, tag): 5 | """Get raw data (steps and values) from tensorboard events. 6 | 7 | Args: 8 | log_path (str): Path to the tensorboard log. 9 | tag (str): tag to be read. 10 | """ 11 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 12 | 13 | # tensorboard event 14 | event_acc = EventAccumulator(log_path) 15 | event_acc.Reload() 16 | scalar_list = event_acc.Tags()['scalars'] 17 | print('tag list: ', scalar_list) 18 | steps = [int(s.step) for s in event_acc.Scalars(tag)] 19 | values = [s.value for s in event_acc.Scalars(tag)] 20 | return steps, values 21 | 22 | 23 | def read_data_from_txt_2v(path, pattern, step_one=False): 24 | """Read data from txt with 2 returned values (usually [step, value]). 25 | 26 | Args: 27 | path (str): path to the txt file. 28 | pattern (str): re (regular expression) pattern. 29 | step_one (bool): add 1 to steps. Default: False. 
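The --force_yml handling in parse_options above turns each 'a:b=value' entry into an assignment to opt['a']['b']; conceptually it does the same as this hypothetical helper (not part of the file), shown only to make the key-path walk explicit:

def apply_force_yml(opt, entry):
    # e.g. entry = 'train:ema_decay=0.999' overwrites opt['train']['ema_decay']
    keys, value = entry.split('=')
    value = _postprocess_yml_value(value.strip())   # reuse the converter defined above
    node = opt
    key_list = keys.strip().split(':')
    for key in key_list[:-1]:
        node = node[key]          # walk down the nested option dict; new keys are not created
    node[key_list[-1]] = value
    return opt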
30 | """ 31 | with open(path) as f: 32 | lines = f.readlines() 33 | lines = [line.strip() for line in lines] 34 | steps = [] 35 | values = [] 36 | 37 | pattern = re.compile(pattern) 38 | for line in lines: 39 | match = pattern.match(line) 40 | if match: 41 | steps.append(int(match.group(1))) 42 | values.append(float(match.group(2))) 43 | if step_one: 44 | steps = [v + 1 for v in steps] 45 | return steps, values 46 | 47 | 48 | def read_data_from_txt_1v(path, pattern): 49 | """Read data from txt with 1 returned values. 50 | 51 | Args: 52 | path (str): path to the txt file. 53 | pattern (str): re (regular expression) pattern. 54 | """ 55 | with open(path) as f: 56 | lines = f.readlines() 57 | lines = [line.strip() for line in lines] 58 | data = [] 59 | 60 | pattern = re.compile(pattern) 61 | for line in lines: 62 | match = pattern.match(line) 63 | if match: 64 | data.append(float(match.group(1))) 65 | return data 66 | 67 | 68 | def smooth_data(values, smooth_weight): 69 | """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does). 70 | 71 | Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501 72 | 73 | Args: 74 | values (list): A list of values to be smoothed. 75 | smooth_weight (float): Smooth weight. 76 | """ 77 | values_sm = [] 78 | last_sm_value = values[0] 79 | for value in values: 80 | value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value 81 | values_sm.append(value_sm) 82 | last_sm_value = value_sm 83 | return values_sm 84 | -------------------------------------------------------------------------------- /pasd/dataloader/basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj, suffix=None): 39 | if isinstance(suffix, str): 40 | name = name + '_' + suffix 41 | 42 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 43 | f"in '{self._name}' registry!") 44 | self._obj_map[name] = obj 45 | 46 | def register(self, obj=None, suffix=None): 47 | """ 48 | Register the given object under the the name `obj.__name__`. 49 | Can be used as either a decorator or not. 50 | See docstring of this class for usage. 
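smooth_data above applies the same first-order IIR smoothing TensorBoard uses for its scalar plots; a tiny worked example:

from pasd.dataloader.basicsr.utils.plot_util import smooth_data

values = [0.0, 1.0, 1.0, 1.0]
print(smooth_data(values, smooth_weight=0.9))
# ~ [0.0, 0.1, 0.19, 0.271]: each point is 0.9 * previous smoothed value + 0.1 * current value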
51 | """ 52 | if obj is None: 53 | # used as a decorator 54 | def deco(func_or_class): 55 | name = func_or_class.__name__ 56 | self._do_register(name, func_or_class, suffix) 57 | return func_or_class 58 | 59 | return deco 60 | 61 | # used as a function call 62 | name = obj.__name__ 63 | self._do_register(name, obj, suffix) 64 | 65 | def get(self, name, suffix='basicsr'): 66 | ret = self._obj_map.get(name) 67 | if ret is None: 68 | ret = self._obj_map.get(name + '_' + suffix) 69 | print(f'Name {name} is not found, use name: {name}_{suffix}!') 70 | if ret is None: 71 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 72 | return ret 73 | 74 | def __contains__(self, name): 75 | return name in self._obj_map 76 | 77 | def __iter__(self): 78 | return iter(self._obj_map.items()) 79 | 80 | def keys(self): 81 | return self._obj_map.keys() 82 | 83 | 84 | DATASET_REGISTRY = Registry('dataset') 85 | ARCH_REGISTRY = Registry('arch') 86 | MODEL_REGISTRY = Registry('model') 87 | LOSS_REGISTRY = Registry('loss') 88 | METRIC_REGISTRY = Registry('metric') 89 | -------------------------------------------------------------------------------- /pasd/dataloader/localdatasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import random 5 | import numpy as np 6 | from PIL import Image 7 | from functools import partial 8 | 9 | from torch import nn 10 | from torchvision import transforms 11 | from torch.utils import data as data 12 | 13 | from .realesrgan import RealESRGAN_degradation 14 | from ..myutils.img_util import convert_image_to_fn 15 | from ..myutils.misc import exists 16 | 17 | class LocalImageDataset(data.Dataset): 18 | def __init__(self, 19 | pngtxt_dir="datasets/pngtxt", 20 | image_size=512, 21 | tokenizer=None, 22 | accelerator=None, 23 | control_type=None, 24 | null_text_ratio=0.0, 25 | center_crop=False, 26 | random_flip=True, 27 | resize_bak=True, 28 | convert_image_to="RGB", 29 | ): 30 | super(LocalImageDataset, self).__init__() 31 | self.tokenizer = tokenizer 32 | self.control_type = control_type 33 | self.resize_bak = resize_bak 34 | self.null_text_ratio = null_text_ratio 35 | 36 | self.degradation = RealESRGAN_degradation('params_realesrgan.yml', device='cpu') 37 | 38 | maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity() 39 | self.crop_preproc = transforms.Compose([ 40 | transforms.Lambda(maybe_convert_fn), 41 | #transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR), 42 | transforms.CenterCrop(image_size) if center_crop else transforms.RandomCrop(image_size), 43 | transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x), 44 | ]) 45 | self.img_preproc = transforms.Compose([ 46 | #transforms.Lambda(maybe_convert_fn), 47 | #transforms.Resize(image_size), 48 | #transforms.CenterCrop(image_size), 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 51 | ]) 52 | 53 | self.img_paths = [] 54 | folders = os.listdir(pngtxt_dir) 55 | for folder in folders: 56 | self.img_paths.extend(sorted(glob.glob(f'{pngtxt_dir}/{folder}/*.png'))[:]) 57 | 58 | def tokenize_caption(self, caption): 59 | if random.random() < self.null_text_ratio: 60 | caption = "" 61 | 62 | inputs = self.tokenizer( 63 | caption, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 64 | ) 65 | 66 | return inputs.input_ids 67 
| 68 | def __getitem__(self, index): 69 | example = dict() 70 | 71 | # load image 72 | img_path = self.img_paths[index] 73 | txt_path = img_path.replace(".png", ".txt") 74 | image = Image.open(img_path).convert('RGB') 75 | 76 | image = self.crop_preproc(image) 77 | 78 | example["pixel_values"] = self.img_preproc(image) 79 | if self.control_type is not None: 80 | if self.control_type == 'realisr': 81 | GT_image_t, LR_image_t = self.degradation.degrade_process(np.asarray(image)/255., resize_bak=self.resize_bak) 82 | example["conditioning_pixel_values"] = LR_image_t.squeeze(0) 83 | example["pixel_values"] = GT_image_t.squeeze(0) * 2.0 - 1.0 84 | elif self.control_type == 'grayscale': 85 | image = np.asarray(image.convert('L').convert('RGB'))/255. 86 | example["conditioning_pixel_values"] = torch.from_numpy(image).permute(2,0,1) 87 | else: 88 | raise NotImplementedError 89 | 90 | fp = open(txt_path, "r") 91 | caption = fp.readlines()[0] 92 | if self.tokenizer is not None: 93 | example["input_ids"] = self.tokenize_caption(caption).squeeze(0) 94 | fp.close() 95 | 96 | return example 97 | 98 | def __len__(self): 99 | return len(self.img_paths) -------------------------------------------------------------------------------- /pasd/dataloader/params_realesrgan.yml: -------------------------------------------------------------------------------- 1 | scale: 4 2 | color_jitter_prob: 0.0 3 | gray_prob: 0.0 4 | 5 | # the first degradation process 6 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep 7 | resize_range: [0.3, 1.5] 8 | gaussian_noise_prob: 0.5 9 | noise_range: [1, 15] 10 | poisson_scale_range: [0.05, 2.0] 11 | gray_noise_prob: 0.4 12 | jpeg_range: [60, 95] 13 | 14 | # the second degradation process 15 | second_blur_prob: 0.5 16 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep 17 | resize_range2: [0.6, 1.2] 18 | gaussian_noise_prob2: 0.5 19 | noise_range2: [1, 12] 20 | poisson_scale_range2: [0.05, 1.0] 21 | gray_noise_prob2: 0.4 22 | jpeg_range2: [60, 100] 23 | 24 | kernel_info: 25 | blur_kernel_size: 21 26 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 27 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 28 | sinc_prob: 0.1 29 | blur_sigma: [0.2, 3] 30 | betag_range: [0.5, 4] 31 | betap_range: [1, 2] 32 | 33 | blur_kernel_size2: 21 34 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] 35 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] 36 | sinc_prob2: 0.1 37 | blur_sigma2: [0.2, 1.5] 38 | betag_range2: [0.5, 4] 39 | betap_range2: [1, 2] 40 | 41 | final_sinc_prob: 0.8 42 | 43 | 44 | -------------------------------------------------------------------------------- /pasd/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/models/__init__.py -------------------------------------------------------------------------------- /pasd/models/pasd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/models/pasd/__init__.py -------------------------------------------------------------------------------- /pasd/models/pasd_light/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/models/pasd_light/__init__.py -------------------------------------------------------------------------------- /pasd/myutils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/myutils/__init__.py -------------------------------------------------------------------------------- /pasd/myutils/convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ Conversion script for the LoRA's safetensors checkpoints. """ 17 | import os 18 | import argparse 19 | 20 | import torch 21 | from safetensors.torch import load_file 22 | 23 | from diffusers import StableDiffusionPipeline 24 | import pdb 25 | from collections import defaultdict 26 | 27 | def convert(base_model_path, checkpoint_path, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=1.0): 28 | # load base model 29 | pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 30 | 31 | # load LoRA weight from .safetensors 32 | state_dict = load_file(checkpoint_path) 33 | 34 | visited = [] 35 | 36 | # directly update weight in diffusers model 37 | for key in state_dict: 38 | # it is suggested to print out the key, it usually will be something like below 39 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 40 | 41 | # as we have set the alpha beforehand, so just skip 42 | if ".alpha" in key or key in visited: 43 | continue 44 | 45 | if "text" in key: 46 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 47 | curr_layer = pipeline.text_encoder 48 | else: 49 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 50 | curr_layer = pipeline.unet 51 | 52 | # find the target layer 53 | temp_name = layer_infos.pop(0) 54 | while len(layer_infos) > -1: 55 | try: 56 | curr_layer = curr_layer.__getattr__(temp_name) 57 | if len(layer_infos) > 0: 58 | temp_name = layer_infos.pop(0) 59 | elif len(layer_infos) == 0: 60 | break 61 | except Exception: 62 | if len(temp_name) > 0: 63 | temp_name += "_" + layer_infos.pop(0) 64 | else: 65 | temp_name = layer_infos.pop(0) 66 | 67 | pair_keys = [] 68 | if "lora_down" in key: 69 | pair_keys.append(key.replace("lora_down", "lora_up")) 70 | pair_keys.append(key) 71 | else: 72 | pair_keys.append(key) 73 | pair_keys.append(key.replace("lora_up", "lora_down")) 74 | 75 | # update weight 76 | if len(state_dict[pair_keys[0]].shape) == 4: 77 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 78 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 79 | 
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) 80 | else: 81 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 82 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 83 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down) 84 | 85 | # update visited list 86 | for item in pair_keys: 87 | visited.append(item) 88 | 89 | return pipeline 90 | 91 | def convert_lora(unet, text_encoder, lora_path, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", multiplier=0.6, device='cuda'): 92 | # the given unet and text_encoder are updated in place 93 | #pipeline.to(device) 94 | # load LoRA weight from .safetensors 95 | state_dict = load_file(lora_path, device=device) if isinstance(lora_path, str) else lora_path 96 | visited = [] 97 | 98 | # directly update weight in diffusers model 99 | for key in state_dict: 100 | # it helps to print out the key; it usually looks something like below 101 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 102 | 103 | # alpha has already been applied beforehand, so just skip these entries 104 | if ".alpha" in key or key in visited: 105 | continue 106 | 107 | if "text" in key: 108 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 109 | #curr_layer = pipeline.text_encoder 110 | curr_layer = text_encoder 111 | else: 112 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 113 | #curr_layer = pipeline.unet 114 | curr_layer = unet 115 | 116 | # find the target layer (attribute names may themselves contain underscores) 117 | temp_name = layer_infos.pop(0) 118 | while len(layer_infos) > -1: 119 | try: 120 | curr_layer = curr_layer.__getattr__(temp_name) 121 | if len(layer_infos) > 0: 122 | temp_name = layer_infos.pop(0) 123 | elif len(layer_infos) == 0: 124 | break 125 | except Exception: 126 | if len(temp_name) > 0: 127 | temp_name += "_" + layer_infos.pop(0) 128 | else: 129 | temp_name = layer_infos.pop(0) 130 | 131 | pair_keys = [] 132 | if "lora_down" in key: 133 | pair_keys.append(key.replace("lora_down", "lora_up")) 134 | pair_keys.append(key) 135 | else: 136 | pair_keys.append(key) 137 | pair_keys.append(key.replace("lora_up", "lora_down")) 138 | 139 | # update weight 140 | if len(state_dict[pair_keys[0]].shape) == 4: 141 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 142 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 143 | curr_layer.weight.data += multiplier * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) 144 | else: 145 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 146 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 147 | curr_layer.weight.data += multiplier * torch.mm(weight_up, weight_down) 148 | 149 | # update visited list 150 | for item in pair_keys: 151 | visited.append(item) 152 | 153 | return unet, text_encoder 154 | 155 | 156 | if __name__ == "__main__": 157 | parser = argparse.ArgumentParser() 158 | 159 | parser.add_argument( 160 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." 161 | ) 162 | parser.add_argument( 163 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
164 | ) 165 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 166 | parser.add_argument( 167 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" 168 | ) 169 | parser.add_argument( 170 | "--lora_prefix_text_encoder", 171 | default="lora_te", 172 | type=str, 173 | help="The prefix of text encoder weight in safetensors", 174 | ) 175 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") 176 | parser.add_argument( 177 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." 178 | ) 179 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") 180 | 181 | args = parser.parse_args() 182 | 183 | base_model_path = args.base_model_path 184 | checkpoint_path = args.checkpoint_path 185 | dump_path = args.dump_path 186 | lora_prefix_unet = args.lora_prefix_unet 187 | lora_prefix_text_encoder = args.lora_prefix_text_encoder 188 | alpha = args.alpha 189 | 190 | pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) 191 | 192 | pipe = pipe.to(args.device) 193 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 194 | -------------------------------------------------------------------------------- /pasd/myutils/devices.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import contextlib 3 | from functools import lru_cache 4 | 5 | import torch 6 | #from modules import errors 7 | 8 | if sys.platform == "darwin": 9 | # from modules import mac_specific 10 | raise NotImplementedError("Mac is not yet supported") 11 | 12 | 13 | def has_mps() -> bool: 14 | if sys.platform != "darwin": 15 | return False 16 | else: 17 | return mac_specific.has_mps 18 | 19 | 20 | def get_cuda_device_string(): 21 | return "cuda" 22 | 23 | 24 | def get_optimal_device_name(): 25 | if torch.cuda.is_available(): 26 | return get_cuda_device_string() 27 | 28 | if has_mps(): 29 | return "mps" 30 | 31 | return "cpu" 32 | 33 | 34 | def get_optimal_device(): 35 | return torch.device(get_optimal_device_name()) 36 | 37 | 38 | def get_device_for(task): 39 | return get_optimal_device() 40 | 41 | 42 | def torch_gc(): 43 | 44 | if torch.cuda.is_available(): 45 | with torch.cuda.device(get_cuda_device_string()): 46 | torch.cuda.empty_cache() 47 | torch.cuda.ipc_collect() 48 | 49 | if has_mps(): 50 | mac_specific.torch_mps_gc() 51 | 52 | 53 | def enable_tf32(): 54 | if torch.cuda.is_available(): 55 | 56 | # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't 57 | # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 58 | if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())): 59 | torch.backends.cudnn.benchmark = True 60 | 61 | torch.backends.cuda.matmul.allow_tf32 = True 62 | torch.backends.cudnn.allow_tf32 = True 63 | 64 | 65 | enable_tf32() 66 | #errors.run(enable_tf32, "Enabling TF32") 67 | 68 | cpu = torch.device("cpu") 69 | device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda") 70 | dtype = torch.float16 71 | dtype_vae = torch.float16 72 | dtype_unet = torch.float16 73 | unet_needs_upcast = False 74 | 75 | 76 | def cond_cast_unet(input): 77 | return input.to(dtype_unet) if unet_needs_upcast else input 78 | 79 | 
80 | def cond_cast_float(input): 81 | return input.float() if unet_needs_upcast else input 82 | 83 | 84 | def randn(seed, shape): 85 | torch.manual_seed(seed) 86 | return torch.randn(shape, device=device) 87 | 88 | 89 | def randn_without_seed(shape): 90 | return torch.randn(shape, device=device) 91 | 92 | 93 | def autocast(disable=False): 94 | if disable: 95 | return contextlib.nullcontext() 96 | 97 | return torch.autocast("cuda") 98 | 99 | 100 | def without_autocast(disable=False): 101 | return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() 102 | 103 | 104 | class NansException(Exception): 105 | pass 106 | 107 | 108 | def test_for_nans(x, where): 109 | if not torch.all(torch.isnan(x)).item(): 110 | return 111 | 112 | if where == "unet": 113 | message = "A tensor with all NaNs was produced in Unet." 114 | 115 | elif where == "vae": 116 | message = "A tensor with all NaNs was produced in VAE." 117 | 118 | else: 119 | message = "A tensor with all NaNs was produced." 120 | 121 | message += " Use --disable-nan-check commandline argument to disable this check." 122 | 123 | raise NansException(message) 124 | 125 | 126 | @lru_cache 127 | def first_time_calculation(): 128 | """ 129 | Just do any calculation with PyTorch layers - the first time this is done it allocates about 700MB of memory and 130 | spends about 2.7 seconds doing that, at least with NVIDIA. 131 | """ 132 | 133 | x = torch.zeros((1, 1)).to(device, dtype) 134 | linear = torch.nn.Linear(1, 1).to(device, dtype) 135 | linear(x) 136 | 137 | x = torch.zeros((1, 1, 3, 3)).to(device, dtype) 138 | conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) 139 | conv2d(x) 140 | -------------------------------------------------------------------------------- /pasd/myutils/img_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import PIL 3 | import cv2 4 | import math 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | import imageio 9 | 10 | from einops import rearrange 11 | 12 | def save_videos_grid(videos, path=None, rescale=True, n_rows=4, fps=8, discardN=0): 13 | videos = rearrange(videos, "b c t h w -> t b c h w").cpu() 14 | outputs = [] 15 | for x in videos: 16 | x = torchvision.utils.make_grid(x, nrow=n_rows) 17 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 18 | if rescale: 19 | x = (x / 2.0 + 0.5).clamp(0, 1) # -1,1 -> 0,1 20 | x = (x * 255).numpy().astype(np.uint8) 21 | #x = adjust_gamma(x, 0.5) 22 | outputs.append(x) 23 | 24 | outputs = outputs[discardN:] 25 | 26 | if path is not None: 27 | #os.makedirs(os.path.dirname(path), exist_ok=True) 28 | imageio.mimsave(path, outputs, duration=1000/fps, loop=0) 29 | 30 | return outputs 31 | 32 | def convert_image_to_fn(img_type, image, minsize=512, eps=0.02): 33 | width, height = image.size 34 | if min(width, height) < minsize: 35 | scale = minsize/min(width, height) + eps 36 | image = image.resize((math.ceil(width*scale), math.ceil(height*scale))) 37 | 38 | if image.mode != img_type: 39 | return image.convert(img_type) 40 | return image 41 | 42 | def colorful_loss(pred): 43 | colorfulness_loss = 0 44 | for i in range(pred.shape[0]): 45 | (R, G, B) = pred[i][0], pred[i][1], pred[i][2] 46 | rg = torch.abs(R - G) 47 | yb = torch.abs(0.5 * (R+G) - B) 48 | (rbMean, rbStd) = (torch.mean(rg), torch.std(rg)) 49 | (ybMean, ybStd) = (torch.mean(yb), torch.std(yb)) 50 | stdRoot = torch.sqrt((rbStd ** 2) + (ybStd ** 2)) 51 | meanRoot = torch.sqrt((rbMean ** 2)
+ (ybMean ** 2)) 52 | colorfulness = stdRoot + (0.3 * meanRoot) 53 | colorfulness_loss += (1 - colorfulness) 54 | return colorfulness_loss -------------------------------------------------------------------------------- /pasd/myutils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import binascii 3 | from safetensors import safe_open 4 | 5 | import torch 6 | 7 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, convert_ldm_clip_checkpoint 8 | from .convert_lora_safetensor_to_diffusers import convert_lora 9 | 10 | def rand_name(length=8, suffix=''): 11 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') 12 | if suffix: 13 | if not suffix.startswith('.'): 14 | suffix = '.' + suffix 15 | name += suffix 16 | return name 17 | 18 | def cycle(dl): 19 | while True: 20 | for data in dl: 21 | yield data 22 | 23 | def exists(x): 24 | return x is not None 25 | 26 | def identity(x): 27 | return x 28 | 29 | def load_dreambooth_lora(unet, vae=None, text_encoder=None, model_path=None, blending_alpha=1.0, multiplier=0.6, model_base=None): 30 | if model_path is None: return unet, vae, text_encoder  # keep the return signature consistent 31 | is_lora, base_state_dict = False, {}  # defaults so the .ckpt and unknown-extension paths below cannot hit undefined names 32 | if model_path.endswith(".ckpt"): 33 | base_state_dict = torch.load(model_path)['state_dict'] 34 | elif model_path.endswith(".safetensors"): 35 | state_dict = {} 36 | with safe_open(model_path, framework="pt", device="cpu") as f: 37 | for key in f.keys(): 38 | state_dict[key] = f.get_tensor(key) 39 | 40 | is_lora = all("lora" in k for k in state_dict.keys()) 41 | if not is_lora: 42 | base_state_dict = state_dict 43 | else: 44 | base_state_dict = {} 45 | if model_base is not None: 46 | with safe_open(model_base, framework="pt", device="cpu") as f: 47 | for key in f.keys(): 48 | base_state_dict[key] = f.get_tensor(key) 49 | 50 | if base_state_dict: 51 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, unet.config) 52 | 53 | unet_state_dict = unet.state_dict() 54 | for key in converted_unet_checkpoint: 55 | if key in unet_state_dict: 56 | converted_unet_checkpoint[key] = converted_unet_checkpoint[key] * blending_alpha + unet_state_dict[key] * (1.0 - blending_alpha) 57 | else: 58 | print(key) 59 | 60 | unet.load_state_dict(converted_unet_checkpoint, strict=False) 61 | 62 | if vae is not None: 63 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, vae.config) 64 | vae.load_state_dict(converted_vae_checkpoint) 65 | 66 | if text_encoder is not None: 67 | text_encoder = convert_ldm_clip_checkpoint(base_state_dict) 68 | 69 | if is_lora: 70 | unet, text_encoder = convert_lora(unet, text_encoder, state_dict, multiplier=multiplier) 71 | 72 | return unet, vae, text_encoder -------------------------------------------------------------------------------- /pasd/myutils/wavelet_color_fix.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # -------------------------------------------------------------------------------- 3 | # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py) 4 | # -------------------------------------------------------------------------------- 5 | ''' 6 | 7 | import torch 8 | from PIL import Image 9 | from torch import Tensor 10 | from torch.nn import functional as F 11 | 12 | from torchvision.transforms import ToTensor, ToPILImage 13 | 14 | def adain_color_fix(target: Image, source: Image): 15 | # Convert images to tensors 16 |
to_tensor = ToTensor() 17 | target_tensor = to_tensor(target).unsqueeze(0) 18 | source_tensor = to_tensor(source).unsqueeze(0) 19 | 20 | # Apply adaptive instance normalization 21 | result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) 22 | 23 | # Convert tensor back to image 24 | to_image = ToPILImage() 25 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) 26 | 27 | return result_image 28 | 29 | def wavelet_color_fix(target: Image, source: Image): 30 | # Convert images to tensors 31 | to_tensor = ToTensor() 32 | target_tensor = to_tensor(target).unsqueeze(0) 33 | source_tensor = to_tensor(source).unsqueeze(0) 34 | 35 | # Apply wavelet reconstruction 36 | result_tensor = wavelet_reconstruction(target_tensor, source_tensor) 37 | 38 | # Convert tensor back to image 39 | to_image = ToPILImage() 40 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) 41 | 42 | return result_image 43 | 44 | def calc_mean_std(feat: Tensor, eps=1e-5): 45 | """Calculate mean and std for adaptive_instance_normalization. 46 | Args: 47 | feat (Tensor): 4D tensor. 48 | eps (float): A small value added to the variance to avoid 49 | divide-by-zero. Default: 1e-5. 50 | """ 51 | size = feat.size() 52 | assert len(size) == 4, 'The input feature should be a 4D tensor.' 53 | b, c = size[:2] 54 | feat_var = feat.reshape(b, c, -1).var(dim=2) + eps 55 | feat_std = feat_var.sqrt().reshape(b, c, 1, 1) 56 | feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1) 57 | return feat_mean, feat_std 58 | 59 | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): 60 | """Adaptive instance normalization. 61 | Adjust the reference features so that they have similar color and illumination 62 | to those of the degraded features. 63 | Args: 64 | content_feat (Tensor): The reference feature. 65 | style_feat (Tensor): The degraded features. 66 | """ 67 | size = content_feat.size() 68 | style_mean, style_std = calc_mean_std(style_feat) 69 | content_mean, content_std = calc_mean_std(content_feat) 70 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 71 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 72 | 73 | def wavelet_blur(image: Tensor, radius: int): 74 | """ 75 | Apply wavelet blur to the input tensor. 76 | """ 77 | # input shape: (1, 3, H, W) 78 | # convolution kernel 79 | kernel_vals = [ 80 | [0.0625, 0.125, 0.0625], 81 | [0.125, 0.25, 0.125], 82 | [0.0625, 0.125, 0.0625], 83 | ] 84 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) 85 | # add channel dimensions to the kernel to make it a 4D tensor 86 | kernel = kernel[None, None] 87 | # repeat the kernel across all input channels 88 | kernel = kernel.repeat(3, 1, 1, 1) 89 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate') 90 | # apply convolution 91 | output = F.conv2d(image, kernel, groups=3, dilation=radius) 92 | return output 93 | 94 | def wavelet_decomposition(image: Tensor, levels=5): 95 | """ 96 | Apply wavelet decomposition to the input tensor. 97 | This function only returns the low frequency & the high frequency.
98 | """ 99 | high_freq = torch.zeros_like(image) 100 | for i in range(levels): 101 | radius = 2 ** i 102 | low_freq = wavelet_blur(image, radius) 103 | high_freq += (image - low_freq) 104 | image = low_freq 105 | 106 | return high_freq, low_freq 107 | 108 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): 109 | """ 110 | Apply wavelet decomposition, so that the content will have the same color as the style. 111 | """ 112 | # calculate the wavelet decomposition of the content feature 113 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat) 114 | del content_low_freq 115 | # calculate the wavelet decomposition of the style feature 116 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat) 117 | del style_high_freq 118 | # reconstruct the content feature with the style's high frequency 119 | return content_high_freq + style_low_freq 120 | -------------------------------------------------------------------------------- /pasd/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/pasd/pipelines/__init__.py -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | salesforce-lavis==1.0.2 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.29.2 2 | accelerate 3 | transformers 4 | xformers; sys_platform != 'darwin' 5 | ultralytics 6 | webdataset 7 | open_clip_torch 8 | einops 9 | gradio<4.0.0 10 | pytorch_lightning 11 | -------------------------------------------------------------------------------- /runs/pasd/README.md: -------------------------------------------------------------------------------- 1 | Please put our pretrained pasd models here. 2 | 3 | [pasd](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/PASD/pasd.zip) -------------------------------------------------------------------------------- /runs/pasd_light/README.md: -------------------------------------------------------------------------------- 1 | Please put our pretrained pasd_light models here. 2 | 3 | [pasd_light](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/PASD/pasd_light.zip) -------------------------------------------------------------------------------- /runs/pasd_light_rrdb/README.md: -------------------------------------------------------------------------------- 1 | Please put our pretrained pasd_light_rrdb models here. 2 | 3 | [pasd_light_rrdb](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/PASD/pasd_light_rrdb.zip) -------------------------------------------------------------------------------- /runs/pasd_rrdb/README.md: -------------------------------------------------------------------------------- 1 | Please put our pretrained pasd_rrdb models here. 
2 | 3 | [pasd_rrdb](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/PASD/pasd_rrdb.zip) -------------------------------------------------------------------------------- /samples/000001x2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000001x2.gif -------------------------------------------------------------------------------- /samples/000001x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000001x2.png -------------------------------------------------------------------------------- /samples/000001x2_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000001x2_comp.png -------------------------------------------------------------------------------- /samples/000001x2_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000001x2_out.png -------------------------------------------------------------------------------- /samples/000004x2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000004x2.gif -------------------------------------------------------------------------------- /samples/000004x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000004x2.png -------------------------------------------------------------------------------- /samples/000004x2_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000004x2_comp.png -------------------------------------------------------------------------------- /samples/000004x2_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000004x2_out.png -------------------------------------------------------------------------------- /samples/000020x2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000020x2.gif -------------------------------------------------------------------------------- /samples/000020x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000020x2.png -------------------------------------------------------------------------------- /samples/000020x2_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000020x2_comp.png -------------------------------------------------------------------------------- /samples/000020x2_out.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000020x2_out.png -------------------------------------------------------------------------------- /samples/000030x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000030x2.png -------------------------------------------------------------------------------- /samples/000030x2_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000030x2_comp.png -------------------------------------------------------------------------------- /samples/000030x2_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000030x2_out.png -------------------------------------------------------------------------------- /samples/000067x2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000067x2.gif -------------------------------------------------------------------------------- /samples/000067x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000067x2.png -------------------------------------------------------------------------------- /samples/000067x2_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000067x2_comp.png -------------------------------------------------------------------------------- /samples/000067x2_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000067x2_out.png -------------------------------------------------------------------------------- /samples/000080x2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000080x2.gif -------------------------------------------------------------------------------- /samples/000080x2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000080x2.png -------------------------------------------------------------------------------- /samples/000080x2_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000080x2_comp.png -------------------------------------------------------------------------------- /samples/000080x2_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/000080x2_out.png -------------------------------------------------------------------------------- 
/samples/0c74bd2420d532c2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/0c74bd2420d532c2.png -------------------------------------------------------------------------------- /samples/0c74bd2420d532c2_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/0c74bd2420d532c2_out.png -------------------------------------------------------------------------------- /samples/1125e119c19065f3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/1125e119c19065f3.png -------------------------------------------------------------------------------- /samples/1125e119c19065f3_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/1125e119c19065f3_out.png -------------------------------------------------------------------------------- /samples/1965223411271f2f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/1965223411271f2f.png -------------------------------------------------------------------------------- /samples/1965223411271f2f_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/1965223411271f2f_out.png -------------------------------------------------------------------------------- /samples/27d38eeb2dbbe7c9.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/27d38eeb2dbbe7c9.gif -------------------------------------------------------------------------------- /samples/27d38eeb2dbbe7c9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/27d38eeb2dbbe7c9.png -------------------------------------------------------------------------------- /samples/27d38eeb2dbbe7c9_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/27d38eeb2dbbe7c9_comp.png -------------------------------------------------------------------------------- /samples/27d38eeb2dbbe7c9_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/27d38eeb2dbbe7c9_out.png -------------------------------------------------------------------------------- /samples/2e512b688ef48a43.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/2e512b688ef48a43.gif -------------------------------------------------------------------------------- /samples/2e512b688ef48a43.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/2e512b688ef48a43.png -------------------------------------------------------------------------------- /samples/2e512b688ef48a43_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/2e512b688ef48a43_comp.png -------------------------------------------------------------------------------- /samples/2e512b688ef48a43_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/2e512b688ef48a43_out.png -------------------------------------------------------------------------------- /samples/2e753d77bca91095.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/2e753d77bca91095.png -------------------------------------------------------------------------------- /samples/2e753d77bca91095_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/2e753d77bca91095_out.png -------------------------------------------------------------------------------- /samples/38fa6c25e210c3a2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/38fa6c25e210c3a2.png -------------------------------------------------------------------------------- /samples/38fa6c25e210c3a2_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/38fa6c25e210c3a2_out.png -------------------------------------------------------------------------------- /samples/461a96f62b724eab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/461a96f62b724eab.png -------------------------------------------------------------------------------- /samples/461a96f62b724eab_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/461a96f62b724eab_out.png -------------------------------------------------------------------------------- /samples/629e4da70703193b.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/629e4da70703193b.gif -------------------------------------------------------------------------------- /samples/629e4da70703193b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/629e4da70703193b.png -------------------------------------------------------------------------------- /samples/629e4da70703193b_comp.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/629e4da70703193b_comp.png -------------------------------------------------------------------------------- /samples/629e4da70703193b_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/629e4da70703193b_out.png -------------------------------------------------------------------------------- /samples/8f5cb2715536eef0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/8f5cb2715536eef0.png -------------------------------------------------------------------------------- /samples/8f5cb2715536eef0_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/8f5cb2715536eef0_out.png -------------------------------------------------------------------------------- /samples/Lincoln.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/Lincoln.gif -------------------------------------------------------------------------------- /samples/Lincoln.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/Lincoln.png -------------------------------------------------------------------------------- /samples/Lincoln_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/Lincoln_comp.png -------------------------------------------------------------------------------- /samples/Lincoln_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/Lincoln_out.png -------------------------------------------------------------------------------- /samples/RealPhoto60_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/RealPhoto60_06.png -------------------------------------------------------------------------------- /samples/RealPhoto60_09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/RealPhoto60_09.png -------------------------------------------------------------------------------- /samples/RealPhoto60_22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/RealPhoto60_22.png -------------------------------------------------------------------------------- /samples/RealPhoto60_56.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/RealPhoto60_56.png -------------------------------------------------------------------------------- 
/samples/building.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/building.gif -------------------------------------------------------------------------------- /samples/building.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/building.png -------------------------------------------------------------------------------- /samples/building_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/building_comp.png -------------------------------------------------------------------------------- /samples/building_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/building_out.png -------------------------------------------------------------------------------- /samples/d4f59e89c1011bc4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/d4f59e89c1011bc4.png -------------------------------------------------------------------------------- /samples/d4f59e89c1011bc4_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/d4f59e89c1011bc4_out.png -------------------------------------------------------------------------------- /samples/ed2ec7d15fcbe80e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/ed2ec7d15fcbe80e.png -------------------------------------------------------------------------------- /samples/ed2ec7d15fcbe80e_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/ed2ec7d15fcbe80e_comp.png -------------------------------------------------------------------------------- /samples/ed2ec7d15fcbe80e_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/ed2ec7d15fcbe80e_out.png -------------------------------------------------------------------------------- /samples/ed84982626af1f44.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/ed84982626af1f44.png -------------------------------------------------------------------------------- /samples/ed84982626af1f44_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/ed84982626af1f44_comp.png -------------------------------------------------------------------------------- /samples/ed84982626af1f44_out.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/ed84982626af1f44_out.png -------------------------------------------------------------------------------- /samples/f125ee5838471073.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/f125ee5838471073.gif -------------------------------------------------------------------------------- /samples/f125ee5838471073.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/f125ee5838471073.png -------------------------------------------------------------------------------- /samples/f125ee5838471073_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/f125ee5838471073_comp.png -------------------------------------------------------------------------------- /samples/f125ee5838471073_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/f125ee5838471073_out.png -------------------------------------------------------------------------------- /samples/fe1ade76596bffdf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/fe1ade76596bffdf.png -------------------------------------------------------------------------------- /samples/fe1ade76596bffdf_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/fe1ade76596bffdf_out.png -------------------------------------------------------------------------------- /samples/frog.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/frog.gif -------------------------------------------------------------------------------- /samples/frog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/frog.png -------------------------------------------------------------------------------- /samples/frog_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/frog_out.png -------------------------------------------------------------------------------- /samples/house.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/house.gif -------------------------------------------------------------------------------- /samples/house.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/house.jpg -------------------------------------------------------------------------------- /samples/house_out.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/house_out.png -------------------------------------------------------------------------------- /samples/pasd_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangxy/PASD/396f9ac24f9fa2b9787658bd9eea31729e51f264/samples/pasd_arch.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | def read_requirements(): 5 | with open("requirements.txt") as f: 6 | return f.readlines() 7 | 8 | 9 | setup( 10 | name="pasd", 11 | version="0.0.1", 12 | url="https://github.com/yangxy/PASD.git", 13 | description=( 14 | "[ECCV2024] Pixel-Aware Stable Diffusion for Realistic " 15 | "Image Super-Resolution and Personalized Stylization" 16 | ), 17 | packages=find_packages(), 18 | install_requires=read_requirements(), 19 | ) -------------------------------------------------------------------------------- /train_pasd.sh: -------------------------------------------------------------------------------- 1 | TORCH_DISTRIBUTED_DEBUG=DETAIL accelerate launch \ 2 | train_pasd.py \ 3 | --dataset_name="pasd" \ 4 | --pretrained_model_name_or_path="checkpoints/stable-diffusion-v1-5" \ 5 | --output_dir="runs/pasd" \ 6 | --resolution=512 \ 7 | --learning_rate=5e-5 \ 8 | --gradient_accumulation_steps=2 \ 9 | --train_batch_size=4 \ 10 | --num_train_epochs=1000 \ 11 | --max_train_samples=10000000 \ 12 | --tracker_project_name="pasd" \ 13 | --enable_xformers_memory_efficient_attention \ 14 | --checkpointing_steps=10000 \ 15 | --control_type="realisr" \ 16 | --mixed_precision="fp16" \ 17 | --dataloader_num_workers=8 \ 18 | # --multi_gpu --num_processes=8 --gpu_ids '0,1,2,3,4,5,6,7' \ 19 | 20 | --------------------------------------------------------------------------------
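As a usage illustration for the color-correction utilities in `pasd/myutils/wavelet_color_fix.py` above, here is a minimal, hypothetical sketch of post-processing a super-resolved result so that its colors match the original input. The image paths are placeholders taken from the `samples/` folder, and the resize step is an assumption made so that both images have identical sizes before the pixel-wise fix; this is only a sketch of how the two public functions can be called on PIL images, not the repository's official inference path.

```python
# Minimal sketch (assumptions: placeholder image paths; package installed via setup.py
# so that pasd.myutils.wavelet_color_fix is importable).
from PIL import Image

from pasd.myutils.wavelet_color_fix import adain_color_fix, wavelet_color_fix

# Super-resolved output whose colors may have drifted, and the input to borrow colors from.
sr_image = Image.open("samples/house_out.png").convert("RGB")
src_image = Image.open("samples/house.jpg").convert("RGB")

# Both fixes operate on equally sized (1, 3, H, W) tensors, so match the SR resolution first.
src_image = src_image.resize(sr_image.size, Image.BICUBIC)

# Keep the SR image's high-frequency detail, replace its low-frequency color/illumination.
fixed = wavelet_color_fix(target=sr_image, source=src_image)
# Alternative: match channel-wise mean/std (AdaIN) instead of wavelet low frequencies.
# fixed = adain_color_fix(target=sr_image, source=src_image)

fixed.save("house_out_colorfix.png")
```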