├── .idea
│   ├── CCSR.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── ADD
│   ├── dnnlib
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   └── util.cpython-310.pyc
│   │   └── util.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── attention.cpython-310.pyc
│   │   │   ├── block.cpython-310.pyc
│   │   │   ├── dino_head.cpython-310.pyc
│   │   │   ├── drop_path.cpython-310.pyc
│   │   │   ├── layer_scale.cpython-310.pyc
│   │   │   ├── mlp.cpython-310.pyc
│   │   │   ├── patch_embed.cpython-310.pyc
│   │   │   └── swiglu_ffn.cpython-310.pyc
│   │   ├── attention.py
│   │   ├── block.py
│   │   ├── dino_head.py
│   │   ├── drop_path.py
│   │   ├── layer_scale.py
│   │   ├── mlp.py
│   │   ├── patch_embed.py
│   │   └── swiglu_ffn.py
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── discriminator.cpython-310.pyc
│   │   │   └── vit.cpython-310.pyc
│   │   ├── discriminator.py
│   │   └── vit.py
│   ├── th_utils
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── custom_ops.cpython-310.pyc
│   │   │   └── misc.cpython-310.pyc
│   │   ├── custom_ops.py
│   │   ├── misc.py
│   │   └── ops
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-310.pyc
│   │       │   └── bias_act.cpython-310.pyc
│   │       ├── bias_act.cpp
│   │       ├── bias_act.cu
│   │       ├── bias_act.h
│   │       ├── bias_act.py
│   │       ├── conv2d_gradfix.py
│   │       ├── conv2d_resample.py
│   │       ├── filtered_lrelu.cpp
│   │       ├── filtered_lrelu.cu
│   │       ├── filtered_lrelu.h
│   │       ├── filtered_lrelu.py
│   │       ├── filtered_lrelu_ns.cu
│   │       ├── filtered_lrelu_rd.cu
│   │       ├── filtered_lrelu_wr.cu
│   │       ├── fma.py
│   │       ├── grid_sample_gradfix.py
│   │       ├── upfirdn2d.cpp
│   │       ├── upfirdn2d.cu
│   │       ├── upfirdn2d.h
│   │       └── upfirdn2d.py
│   └── utils
│       ├── __pycache__
│       │   └── util_net.cpython-310.pyc
│       └── util_net.py
├── README.md
├── dataloaders
│   ├── __pycache__
│   │   ├── paired_dataset_txt.cpython-310.pyc
│   │   └── realesrgan.cpython-310.pyc
│   ├── paired_dataset_txt.py
│   ├── params_ccsr.yml
│   └── realesrgan.py
├── figs
│   ├── compare_1.png
│   ├── compare_2.png
│   ├── compare_3.png
│   ├── compare_4.png
│   ├── compare_efficient.png
│   ├── compare_standard.png
│   ├── fig.png
│   ├── framework.png
│   ├── logo.png
│   └── table.png
├── models
│   ├── DiffAugment.py
│   ├── __pycache__
│   │   ├── DiffAugment.cpython-310.pyc
│   │   ├── controlnet.cpython-310.pyc
│   │   ├── shared.cpython-310.pyc
│   │   ├── unet_2d_blocks.cpython-310.pyc
│   │   ├── unet_2d_condition.cpython-310.pyc
│   │   └── vit_utils.cpython-310.pyc
│   ├── controlnet.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── contperceptual.cpython-310.pyc
│   │   │   ├── contperceptual.cpython-37.pyc
│   │   │   └── contperceptual.cpython-38.pyc
│   │   ├── contperceptual.py
│   │   └── vqperceptual.py
│   ├── shared.py
│   ├── unet_2d_blocks.py
│   ├── unet_2d_condition.py
│   └── vit_utils.py
├── myutils
│   ├── __pycache__
│   │   ├── devices.cpython-310.pyc
│   │   └── wavelet_color_fix.cpython-310.pyc
│   ├── devices.py
│   ├── img_util.py
│   ├── misc.py
│   ├── vaehook.py
│   └── wavelet_color_fix.py
├── pipelines
│   ├── __pycache__
│   │   └── pipeline_ccsr.cpython-310.pyc
│   └── pipeline_ccsr.py
├── requirements.txt
├── scripts
│   ├── get_path.py
│   ├── test
│   │   ├── test_ccsr_multistep.sh
│   │   ├── test_ccsr_onestep.sh
│   │   └── test_ccsr_tile.sh
│   └── train
│       ├── train_ccsr_stage1.sh
│       ├── train_ccsr_stage2.sh
│       └── train_controlnet.sh
├── test_ccsr_tile.py
├── train_ccsr_stage1.py
├── train_ccsr_stage2.py
├── train_controlnet.py
└── utils
    ├── __pycache__
    │   └── vaehook.cpython-310.pyc
    ├── devices.py
    ├── img_util.py
    ├── misc.py
    ├── vaehook.py
    └── wavelet_color_fix.py
/.idea/CCSR.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/ADD/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | from .util import EasyDict, make_cache_dir_path
10 |
--------------------------------------------------------------------------------
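
A small usage sketch for the two helpers re-exported here, assuming the repository root is on PYTHONPATH so that ADD.dnnlib is importable; the argument values are illustrative.

import ADD.dnnlib as dnnlib

# EasyDict gives attribute-style access to plain dict entries.
cfg = dnnlib.EasyDict(lr=1e-4, batch_size=16)
cfg.lr = 5e-5
print(cfg.lr, cfg['batch_size'])

# make_cache_dir_path joins path components under dnnlib's cache location
# (typically ~/.cache/dnnlib unless a cache dir is configured explicitly).
print(dnnlib.make_cache_dir_path('downloads'))
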
/ADD/dnnlib/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/dnnlib/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/dnnlib/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/dnnlib/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | from .dino_head import DINOHead
7 | from .mlp import Mlp
8 | from .patch_embed import PatchEmbed
9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10 | from .block import NestedTensorBlock
11 | from .attention import MemEffAttention
12 |
--------------------------------------------------------------------------------
/ADD/layers/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/layers/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/layers/__pycache__/attention.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/layers/__pycache__/attention.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/layers/__pycache__/block.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/layers/__pycache__/block.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/layers/__pycache__/dino_head.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/layers/__pycache__/dino_head.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/layers/__pycache__/drop_path.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/layers/__pycache__/drop_path.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/layers/__pycache__/layer_scale.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/layers/__pycache__/layer_scale.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/layers/__pycache__/mlp.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/layers/__pycache__/mlp.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/layers/__pycache__/patch_embed.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/layers/__pycache__/patch_embed.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/layers/__pycache__/swiglu_ffn.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/layers/__pycache__/swiglu_ffn.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9 |
10 | import logging
11 | import os
12 | import warnings
13 |
14 | from torch import Tensor
15 | from torch import nn
16 |
17 |
18 | logger = logging.getLogger("dinov2")
19 |
20 |
21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
22 | try:
23 | if XFORMERS_ENABLED:
24 | from xformers.ops import memory_efficient_attention, unbind
25 |
26 | XFORMERS_AVAILABLE = True
27 | warnings.warn("xFormers is available (Attention)")
28 | else:
29 | warnings.warn("xFormers is disabled (Attention)")
30 | raise ImportError
31 | except ImportError:
32 | XFORMERS_AVAILABLE = False
33 | warnings.warn("xFormers is not available (Attention)")
34 |
35 |
36 | class Attention(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | num_heads: int = 8,
41 | qkv_bias: bool = False,
42 | proj_bias: bool = True,
43 | attn_drop: float = 0.0,
44 | proj_drop: float = 0.0,
45 | ) -> None:
46 | super().__init__()
47 | self.num_heads = num_heads
48 | head_dim = dim // num_heads
49 | self.scale = head_dim**-0.5
50 |
51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52 | self.attn_drop = nn.Dropout(attn_drop)
53 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
54 | self.proj_drop = nn.Dropout(proj_drop)
55 |
56 | def forward(self, x: Tensor) -> Tensor:
57 | B, N, C = x.shape
58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
59 |
60 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61 | attn = q @ k.transpose(-2, -1)
62 |
63 | attn = attn.softmax(dim=-1)
64 | attn = self.attn_drop(attn)
65 |
66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67 | x = self.proj(x)
68 | x = self.proj_drop(x)
69 | return x
70 |
71 |
72 | class MemEffAttention(Attention):
73 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74 | if not XFORMERS_AVAILABLE:
75 | if attn_bias is not None:
76 | raise AssertionError("xFormers is required for using nested tensors")
77 | return super().forward(x)
78 |
79 | B, N, C = x.shape
80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
81 |
82 | q, k, v = unbind(qkv, 2)
83 |
84 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
85 | x = x.reshape([B, N, C])
86 |
87 | x = self.proj(x)
88 | x = self.proj_drop(x)
89 | return x
90 |
--------------------------------------------------------------------------------
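
A minimal shape check for the two attention variants above, assuming the repository root is on PYTHONPATH; the ViT-S style sizes (384-dim tokens, 6 heads) are illustrative. MemEffAttention uses the xFormers kernel when it is installed and otherwise falls back to the parent Attention.forward.

import torch
from ADD.layers.attention import Attention, MemEffAttention

# Token sequence: batch of 2, 197 tokens (196 patches + CLS), embedding dim 384.
x = torch.randn(2, 197, 384)

attn = Attention(dim=384, num_heads=6, qkv_bias=True)
mem_attn = MemEffAttention(dim=384, num_heads=6, qkv_bias=True)

with torch.no_grad():
    y = attn(x)       # standard softmax attention
    z = mem_attn(x)   # xFormers kernel if available, otherwise the parent forward

assert y.shape == z.shape == x.shape  # attention preserves (B, N, C)
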
/ADD/layers/block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9 |
10 | import logging
11 | import os
12 | from typing import Callable, List, Any, Tuple, Dict
13 | import warnings
14 |
15 | import torch
16 | from torch import nn, Tensor
17 |
18 | from .attention import Attention, MemEffAttention
19 | from .drop_path import DropPath
20 | from .layer_scale import LayerScale
21 | from .mlp import Mlp
22 |
23 |
24 | logger = logging.getLogger("dinov2")
25 |
26 |
27 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
28 | try:
29 | if XFORMERS_ENABLED:
30 | from xformers.ops import fmha, scaled_index_add, index_select_cat
31 |
32 | XFORMERS_AVAILABLE = True
33 | warnings.warn("xFormers is available (Block)")
34 | else:
35 | warnings.warn("xFormers is disabled (Block)")
36 | raise ImportError
37 | except ImportError:
38 | XFORMERS_AVAILABLE = False
39 |
40 | warnings.warn("xFormers is not available (Block)")
41 |
42 |
43 | class Block(nn.Module):
44 | def __init__(
45 | self,
46 | dim: int,
47 | num_heads: int,
48 | mlp_ratio: float = 4.0,
49 | qkv_bias: bool = False,
50 | proj_bias: bool = True,
51 | ffn_bias: bool = True,
52 | drop: float = 0.0,
53 | attn_drop: float = 0.0,
54 | init_values=None,
55 | drop_path: float = 0.0,
56 | act_layer: Callable[..., nn.Module] = nn.GELU,
57 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
58 | attn_class: Callable[..., nn.Module] = Attention,
59 | ffn_layer: Callable[..., nn.Module] = Mlp,
60 | ) -> None:
61 | super().__init__()
62 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
63 | self.norm1 = norm_layer(dim)
64 | self.attn = attn_class(
65 | dim,
66 | num_heads=num_heads,
67 | qkv_bias=qkv_bias,
68 | proj_bias=proj_bias,
69 | attn_drop=attn_drop,
70 | proj_drop=drop,
71 | )
72 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
73 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
74 |
75 | self.norm2 = norm_layer(dim)
76 | mlp_hidden_dim = int(dim * mlp_ratio)
77 | self.mlp = ffn_layer(
78 | in_features=dim,
79 | hidden_features=mlp_hidden_dim,
80 | act_layer=act_layer,
81 | drop=drop,
82 | bias=ffn_bias,
83 | )
84 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
85 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
86 |
87 | self.sample_drop_ratio = drop_path
88 |
89 | def forward(self, x: Tensor) -> Tensor:
90 | def attn_residual_func(x: Tensor) -> Tensor:
91 | return self.ls1(self.attn(self.norm1(x)))
92 |
93 | def ffn_residual_func(x: Tensor) -> Tensor:
94 | return self.ls2(self.mlp(self.norm2(x)))
95 |
96 | if self.training and self.sample_drop_ratio > 0.1:
97 | # the overhead is compensated only for a drop path rate larger than 0.1
98 | x = drop_add_residual_stochastic_depth(
99 | x,
100 | residual_func=attn_residual_func,
101 | sample_drop_ratio=self.sample_drop_ratio,
102 | )
103 | x = drop_add_residual_stochastic_depth(
104 | x,
105 | residual_func=ffn_residual_func,
106 | sample_drop_ratio=self.sample_drop_ratio,
107 | )
108 | elif self.training and self.sample_drop_ratio > 0.0:
109 | x = x + self.drop_path1(attn_residual_func(x))
110 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
111 | else:
112 | x = x + attn_residual_func(x)
113 | x = x + ffn_residual_func(x)
114 | return x
115 |
116 |
117 | def drop_add_residual_stochastic_depth(
118 | x: Tensor,
119 | residual_func: Callable[[Tensor], Tensor],
120 | sample_drop_ratio: float = 0.0,
121 | ) -> Tensor:
122 | # 1) extract subset using permutation
123 | b, n, d = x.shape
124 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
125 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
126 | x_subset = x[brange]
127 |
128 | # 2) apply residual_func to get residual
129 | residual = residual_func(x_subset)
130 |
131 | x_flat = x.flatten(1)
132 | residual = residual.flatten(1)
133 |
134 | residual_scale_factor = b / sample_subset_size
135 |
136 | # 3) add the residual
137 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
138 | return x_plus_residual.view_as(x)
139 |
140 |
141 | def get_branges_scales(x, sample_drop_ratio=0.0):
142 | b, n, d = x.shape
143 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
144 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
145 | residual_scale_factor = b / sample_subset_size
146 | return brange, residual_scale_factor
147 |
148 |
149 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
150 | if scaling_vector is None:
151 | x_flat = x.flatten(1)
152 | residual = residual.flatten(1)
153 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
154 | else:
155 | x_plus_residual = scaled_index_add(
156 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
157 | )
158 | return x_plus_residual
159 |
160 |
161 | attn_bias_cache: Dict[Tuple, Any] = {}
162 |
163 |
164 | def get_attn_bias_and_cat(x_list, branges=None):
165 | """
166 | this will perform the index select, cat the tensors, and provide the attn_bias from cache
167 | """
168 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
169 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
170 | if all_shapes not in attn_bias_cache.keys():
171 | seqlens = []
172 | for b, x in zip(batch_sizes, x_list):
173 | for _ in range(b):
174 | seqlens.append(x.shape[1])
175 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
176 | attn_bias._batch_sizes = batch_sizes
177 | attn_bias_cache[all_shapes] = attn_bias
178 |
179 | if branges is not None:
180 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
181 | else:
182 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
183 | cat_tensors = torch.cat(tensors_bs1, dim=1)
184 |
185 | return attn_bias_cache[all_shapes], cat_tensors
186 |
187 |
188 | def drop_add_residual_stochastic_depth_list(
189 | x_list: List[Tensor],
190 | residual_func: Callable[[Tensor, Any], Tensor],
191 | sample_drop_ratio: float = 0.0,
192 | scaling_vector=None,
193 | ) -> Tensor:
194 | # 1) generate random set of indices for dropping samples in the batch
195 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
196 | branges = [s[0] for s in branges_scales]
197 | residual_scale_factors = [s[1] for s in branges_scales]
198 |
199 | # 2) get attention bias and index+concat the tensors
200 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
201 |
202 | # 3) apply residual_func to get residual, and split the result
203 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
204 |
205 | outputs = []
206 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
207 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
208 | return outputs
209 |
210 |
211 | class NestedTensorBlock(Block):
212 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
213 | """
214 | x_list contains a list of tensors to nest together and run
215 | """
216 | assert isinstance(self.attn, MemEffAttention)
217 |
218 | if self.training and self.sample_drop_ratio > 0.0:
219 |
220 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
221 | return self.attn(self.norm1(x), attn_bias=attn_bias)
222 |
223 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
224 | return self.mlp(self.norm2(x))
225 |
226 | x_list = drop_add_residual_stochastic_depth_list(
227 | x_list,
228 | residual_func=attn_residual_func,
229 | sample_drop_ratio=self.sample_drop_ratio,
230 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
231 | )
232 | x_list = drop_add_residual_stochastic_depth_list(
233 | x_list,
234 | residual_func=ffn_residual_func,
235 | sample_drop_ratio=self.sample_drop_ratio,
236 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
237 | )
238 | return x_list
239 | else:
240 |
241 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
242 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
243 |
244 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
245 | return self.ls2(self.mlp(self.norm2(x)))
246 |
247 | attn_bias, x = get_attn_bias_and_cat(x_list)
248 | x = x + attn_residual_func(x, attn_bias=attn_bias)
249 | x = x + ffn_residual_func(x)
250 | return attn_bias.split(x)
251 |
252 | def forward(self, x_or_x_list):
253 | if isinstance(x_or_x_list, Tensor):
254 | return super().forward(x_or_x_list)
255 | elif isinstance(x_or_x_list, list):
256 | if not XFORMERS_AVAILABLE:
257 | raise AssertionError("xFormers is required for using nested tensors")
258 | return self.forward_nested(x_or_x_list)
259 | else:
260 | raise AssertionError
261 |
--------------------------------------------------------------------------------
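
A short sketch of how Block behaves in training vs. evaluation mode, assuming the repository root is on PYTHONPATH; the dimensions are illustrative. With drop_path above 0.1, the training forward switches to the batch-subset stochastic-depth path implemented by drop_add_residual_stochastic_depth.

import torch
from ADD.layers.block import Block

x = torch.randn(4, 197, 384)

# drop_path=0.2 > 0.1, so in training mode the block applies the residual to a
# random subset of the batch and rescales it by batch_size / subset_size.
block = Block(dim=384, num_heads=6, mlp_ratio=4.0, drop_path=0.2, init_values=1e-5)

block.train()
y_train = block(x)   # stochastic-depth path
block.eval()
y_eval = block(x)    # plain residual path

assert y_train.shape == y_eval.shape == x.shape
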
/ADD/layers/dino_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.init import trunc_normal_
9 | from torch.nn.utils import weight_norm
10 |
11 |
12 | class DINOHead(nn.Module):
13 | def __init__(
14 | self,
15 | in_dim,
16 | out_dim,
17 | use_bn=False,
18 | nlayers=3,
19 | hidden_dim=2048,
20 | bottleneck_dim=256,
21 | mlp_bias=True,
22 | ):
23 | super().__init__()
24 | nlayers = max(nlayers, 1)
25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
26 | self.apply(self._init_weights)
27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
28 | self.last_layer.weight_g.data.fill_(1)
29 |
30 | def _init_weights(self, m):
31 | if isinstance(m, nn.Linear):
32 | trunc_normal_(m.weight, std=0.02)
33 | if isinstance(m, nn.Linear) and m.bias is not None:
34 | nn.init.constant_(m.bias, 0)
35 |
36 | def forward(self, x):
37 | x = self.mlp(x)
38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12
39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
40 | x = self.last_layer(x)
41 | return x
42 |
43 |
44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45 | if nlayers == 1:
46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47 | else:
48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49 | if use_bn:
50 | layers.append(nn.BatchNorm1d(hidden_dim))
51 | layers.append(nn.GELU())
52 | for _ in range(nlayers - 2):
53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54 | if use_bn:
55 | layers.append(nn.BatchNorm1d(hidden_dim))
56 | layers.append(nn.GELU())
57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58 | return nn.Sequential(*layers)
59 |
--------------------------------------------------------------------------------
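
A quick instantiation of DINOHead, assuming the repository root is on PYTHONPATH; the sizes follow common DINO settings but are only illustrative here.

import torch
from ADD.layers.dino_head import DINOHead

# Project 384-d backbone features to 65536 prototype logits through a
# 3-layer MLP bottleneck and a weight-normalised last layer.
head = DINOHead(in_dim=384, out_dim=65536, nlayers=3, hidden_dim=2048, bottleneck_dim=256)

feats = torch.randn(8, 384)
logits = head(feats)
assert logits.shape == (8, 65536)
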
/ADD/layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9 |
10 |
11 | from torch import nn
12 |
13 |
14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15 | if drop_prob == 0.0 or not training:
16 | return x
17 | keep_prob = 1 - drop_prob
18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20 | if keep_prob > 0.0:
21 | random_tensor.div_(keep_prob)
22 | output = x * random_tensor
23 | return output
24 |
25 |
26 | class DropPath(nn.Module):
27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28 |
29 | def __init__(self, drop_prob=None):
30 | super(DropPath, self).__init__()
31 | self.drop_prob = drop_prob
32 |
33 | def forward(self, x):
34 | return drop_path(x, self.drop_prob, self.training)
35 |
--------------------------------------------------------------------------------
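
A minimal sketch of the stochastic-depth behaviour, assuming the repository root is on PYTHONPATH; tensor sizes are illustrative.

import torch
from ADD.layers.drop_path import DropPath

x = torch.randn(16, 197, 384)
dp = DropPath(drop_prob=0.3)

dp.train()
y = dp(x)   # roughly 30% of samples are zeroed, survivors are scaled by 1/0.7
dp.eval()
z = dp(x)   # identity at inference time

assert torch.equal(z, x)
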
/ADD/layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7 |
8 | from typing import Union
9 |
10 | import torch
11 | from torch import Tensor
12 | from torch import nn
13 |
14 |
15 | class LayerScale(nn.Module):
16 | def __init__(
17 | self,
18 | dim: int,
19 | init_values: Union[float, Tensor] = 1e-5,
20 | inplace: bool = False,
21 | ) -> None:
22 | super().__init__()
23 | self.inplace = inplace
24 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
25 |
26 | def forward(self, x: Tensor) -> Tensor:
27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
28 |
--------------------------------------------------------------------------------
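
A small sketch of LayerScale's per-channel scaling, assuming the repository root is on PYTHONPATH; the dimension is illustrative.

import torch
from ADD.layers.layer_scale import LayerScale

ls = LayerScale(dim=384, init_values=1e-5)
x = torch.randn(2, 197, 384)
y = ls(x)

# Each channel is scaled by a learnable gamma, initialised to 1e-5.
assert torch.allclose(y, x * 1e-5)
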
/ADD/layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9 |
10 |
11 | from typing import Callable, Optional
12 |
13 | from torch import Tensor, nn
14 |
15 |
16 | class Mlp(nn.Module):
17 | def __init__(
18 | self,
19 | in_features: int,
20 | hidden_features: Optional[int] = None,
21 | out_features: Optional[int] = None,
22 | act_layer: Callable[..., nn.Module] = nn.GELU,
23 | drop: float = 0.0,
24 | bias: bool = True,
25 | ) -> None:
26 | super().__init__()
27 | out_features = out_features or in_features
28 | hidden_features = hidden_features or in_features
29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30 | self.act = act_layer()
31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32 | self.drop = nn.Dropout(drop)
33 |
34 | def forward(self, x: Tensor) -> Tensor:
35 | x = self.fc1(x)
36 | x = self.act(x)
37 | x = self.drop(x)
38 | x = self.fc2(x)
39 | x = self.drop(x)
40 | return x
41 |
--------------------------------------------------------------------------------
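
A one-liner shape check for the MLP, assuming the repository root is on PYTHONPATH; the 4x hidden expansion mirrors the default mlp_ratio used in Block.

import torch
from ADD.layers.mlp import Mlp

# 384 -> 1536 -> 384 with GELU and dropout.
mlp = Mlp(in_features=384, hidden_features=4 * 384, drop=0.1)
x = torch.randn(2, 197, 384)
assert mlp(x).shape == x.shape
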
/ADD/layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | # References:
7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9 |
10 | from typing import Callable, Optional, Tuple, Union
11 |
12 | from torch import Tensor
13 | import torch.nn as nn
14 |
15 |
16 | def make_2tuple(x):
17 | if isinstance(x, tuple):
18 | assert len(x) == 2
19 | return x
20 |
21 | assert isinstance(x, int)
22 | return (x, x)
23 |
24 |
25 | class PatchEmbed(nn.Module):
26 | """
27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28 |
29 | Args:
30 | img_size: Image size.
31 | patch_size: Patch token size.
32 | in_chans: Number of input image channels.
33 | embed_dim: Number of linear projection output channels.
34 | norm_layer: Normalization layer.
35 | """
36 |
37 | def __init__(
38 | self,
39 | img_size: Union[int, Tuple[int, int]] = 224,
40 | patch_size: Union[int, Tuple[int, int]] = 16,
41 | in_chans: int = 3,
42 | embed_dim: int = 768,
43 | norm_layer: Optional[Callable] = None,
44 | flatten_embedding: bool = True,
45 | ) -> None:
46 | super().__init__()
47 |
48 | image_HW = make_2tuple(img_size)
49 | patch_HW = make_2tuple(patch_size)
50 | patch_grid_size = (
51 | image_HW[0] // patch_HW[0],
52 | image_HW[1] // patch_HW[1],
53 | )
54 |
55 | self.img_size = image_HW
56 | self.patch_size = patch_HW
57 | self.patches_resolution = patch_grid_size
58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59 |
60 | #self.in_chans = in_chans
61 | self.embed_dim = embed_dim
62 |
63 | self.flatten_embedding = flatten_embedding
64 |
65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67 |
68 | def forward(self, x: Tensor) -> Tensor:
69 | _, _, H, W = x.shape
70 | patch_H, patch_W = self.patch_size
71 |
72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
74 |
75 | x = self.proj(x) # B C H W
76 | H, W = x.size(2), x.size(3)
77 | x = x.flatten(2).transpose(1, 2) # B HW C
78 | x = self.norm(x)
79 | if not self.flatten_embedding:
80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81 | return x
82 |
83 | #def flops(self) -> float:
84 | #Ho, Wo = self.patches_resolution
85 | #flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86 | #if self.norm is not None:
87 | # flops += Ho * Wo * self.embed_dim
88 | #return flops
89 |
--------------------------------------------------------------------------------
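
A small sketch of the (B,C,H,W) -> (B,N,D) conversion described in the docstring above, assuming the repository root is on PYTHONPATH; the 224/16/768 configuration is illustrative.

import torch
from ADD.layers.patch_embed import PatchEmbed

# 224x224 RGB image, 16x16 patches -> 14*14 = 196 tokens of dim 768.
embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
img = torch.randn(1, 3, 224, 224)
tokens = embed(img)

assert embed.num_patches == 196
assert tokens.shape == (1, 196, 768)
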
/ADD/layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the Apache License, Version 2.0
4 | # found in the LICENSE file in the root directory of this source tree.
5 |
6 | import os
7 | from typing import Callable, Optional
8 | import warnings
9 |
10 | from torch import Tensor, nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class SwiGLUFFN(nn.Module):
15 | def __init__(
16 | self,
17 | in_features: int,
18 | hidden_features: Optional[int] = None,
19 | out_features: Optional[int] = None,
20 | act_layer: Callable[..., nn.Module] = None,
21 | drop: float = 0.0,
22 | bias: bool = True,
23 | ) -> None:
24 | super().__init__()
25 | out_features = out_features or in_features
26 | hidden_features = hidden_features or in_features
27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29 |
30 | def forward(self, x: Tensor) -> Tensor:
31 | x12 = self.w12(x)
32 | x1, x2 = x12.chunk(2, dim=-1)
33 | hidden = F.silu(x1) * x2
34 | return self.w3(hidden)
35 |
36 |
37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38 | try:
39 | if XFORMERS_ENABLED:
40 | from xformers.ops import SwiGLU
41 |
42 | XFORMERS_AVAILABLE = True
43 | warnings.warn("xFormers is available (SwiGLU)")
44 | else:
45 | warnings.warn("xFormers is disabled (SwiGLU)")
46 | raise ImportError
47 | except ImportError:
48 | SwiGLU = SwiGLUFFN
49 | XFORMERS_AVAILABLE = False
50 |
51 | warnings.warn("xFormers is not available (SwiGLU)")
52 |
53 |
54 | class SwiGLUFFNFused(SwiGLU):
55 | def __init__(
56 | self,
57 | in_features: int,
58 | hidden_features: Optional[int] = None,
59 | out_features: Optional[int] = None,
60 | act_layer: Callable[..., nn.Module] = None,
61 | drop: float = 0.0,
62 | bias: bool = True,
63 | ) -> None:
64 | out_features = out_features or in_features
65 | hidden_features = hidden_features or in_features
66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67 | super().__init__(
68 | in_features=in_features,
69 | hidden_features=hidden_features,
70 | out_features=out_features,
71 | bias=bias,
72 | )
73 |
--------------------------------------------------------------------------------
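
A short sketch contrasting the plain and fused SwiGLU variants, assuming the repository root is on PYTHONPATH; the sizes are illustrative.

import torch
from ADD.layers.swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused

x = torch.randn(2, 197, 384)

# Plain SwiGLU: w12 produces 2*hidden features, split into a gate and a value branch.
ffn = SwiGLUFFN(in_features=384, hidden_features=4 * 384)
assert ffn(x).shape == x.shape

# Fused variant shrinks the hidden size to 2/3 and rounds up to a multiple of 8
# (here 4*384*2/3 = 1024), keeping the parameter count close to a standard MLP.
fused = SwiGLUFFNFused(in_features=384, hidden_features=4 * 384)
assert fused(x).shape == x.shape
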
/ADD/models/__pycache__/discriminator.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/models/__pycache__/discriminator.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/models/__pycache__/vit.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/models/__pycache__/vit.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/models/discriminator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """
10 | Projected discriminator architecture from
11 | "StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis".
12 | """
13 |
14 | import numpy as np
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | from torch.nn.utils.spectral_norm import SpectralNorm
19 | from torchvision.transforms import RandomCrop, Normalize
20 | import timm
21 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
22 |
23 | from ADD.th_utils import misc
24 | from models.shared import ResidualBlock, FullyConnectedLayer
25 | from models.vit_utils import make_vit_backbone, forward_vit, make_sd_backbone
26 | from models.DiffAugment import DiffAugment
27 | from ADD.utils.util_net import reload_model_
28 |
29 | from functools import partial
30 |
31 | class SpectralConv1d(nn.Conv1d):
32 | def __init__(self, *args, **kwargs):
33 | super().__init__(*args, **kwargs)
34 | SpectralNorm.apply(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12)
35 |
36 |
37 | class BatchNormLocal(nn.Module):
38 | def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 3, eps: float = 1e-5):
39 | super().__init__()
40 | self.virtual_bs = virtual_bs
41 | self.eps = eps
42 | self.affine = affine
43 |
44 | if self.affine:
45 | self.weight = nn.Parameter(torch.ones(num_features))
46 | self.bias = nn.Parameter(torch.zeros(num_features))
47 |
48 | def forward(self, x: torch.Tensor) -> torch.Tensor:
49 | shape = x.size()
50 |
51 | # Reshape batch into groups.
52 | G = np.ceil(x.size(0)/self.virtual_bs).astype(int)
53 | x = x.view(G, -1, x.size(-2), x.size(-1))
54 |
55 | # Calculate stats.
56 | mean = x.mean([1, 3], keepdim=True)
57 | var = x.var([1, 3], keepdim=True, unbiased=False)
58 | x = (x - mean) / (torch.sqrt(var + self.eps))
59 |
60 | if self.affine:
61 | x = x * self.weight[None, :, None] + self.bias[None, :, None]
62 |
63 | return x.view(shape)
64 |
65 |
66 | def make_block(channels: int, kernel_size: int) -> nn.Module:
67 | return nn.Sequential(
68 | SpectralConv1d(
69 | channels,
70 | channels,
71 | kernel_size = kernel_size,
72 | padding = kernel_size//2,
73 | padding_mode = 'circular',
74 | ),
75 | #BatchNormLocal(channels),
76 | nn.GroupNorm(4, channels),
77 | nn.LeakyReLU(0.2, True),
78 | )
79 |
80 | class DiscHead(nn.Module):
81 | def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64):
82 | super().__init__()
83 | self.channels = channels
84 | self.c_dim = c_dim
85 | self.cmap_dim = cmap_dim
86 |
87 | self.main = nn.Sequential(
88 | make_block(channels, kernel_size=1),
89 | ResidualBlock(make_block(channels, kernel_size=9))
90 | )
91 |
92 | if self.c_dim > 0:
93 | self.cmapper = FullyConnectedLayer(self.c_dim, cmap_dim)
94 | self.cls = SpectralConv1d(channels, cmap_dim, kernel_size=1, padding=0)
95 | else:
96 | self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0)
97 |
98 | def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
99 | h = self.main(x)
100 | out = self.cls(h)
101 |
102 | if self.c_dim > 0:
103 | cmap = self.cmapper(c).unsqueeze(-1)
104 | out = (out * cmap).sum(1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
105 |
106 | return out
107 |
108 | class DINO(torch.nn.Module):
109 | def __init__(self, hooks: list[int] = [2,5,8,11], hook_patch: bool = True):
110 | super().__init__()
111 | self.n_hooks = len(hooks) + int(hook_patch)
112 |
113 | self.model = make_vit_backbone(
114 | timm.create_model('vit_small_patch16_224.dino', pretrained=False),
115 | patch_size=[16,16], hooks=hooks, hook_patch=hook_patch,
116 | )
117 | reload_model_(self.model, torch.load('preset/models/dino/dino_deitsmall16_pretrain.pth'))
118 | self.model = self.model.eval().requires_grad_(False)
119 |
120 |
121 | self.img_resolution = self.model.model.patch_embed.img_size[0]
122 | self.embed_dim = self.model.model.embed_dim
123 | self.norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
124 |
125 | def forward(self, x: torch.Tensor) -> torch.Tensor:
126 | ''' input: x in [0, 1]; output: dict of activations '''
127 | x = F.interpolate(x, self.img_resolution, mode='area')
128 | x = self.norm(x)
129 | features = forward_vit(self.model, x)
130 | return features
131 |
132 |
133 | class ProjectedDiscriminator(nn.Module):
134 | def __init__(self, c_dim: int, diffaug: bool = True, p_crop: float = 0.5):
135 | super().__init__()
136 | self.c_dim = c_dim
137 | self.diffaug = diffaug
138 | self.p_crop = p_crop
139 |
140 | self.dino = DINO()
141 |
142 | heads = []
143 | for i in range(self.dino.n_hooks):
144 | heads += [str(i), DiscHead(self.dino.embed_dim, c_dim)],
145 | self.heads = nn.ModuleDict(heads)
146 |
147 | def train(self, mode: bool = True):
148 | self.dino = self.dino.train(False)
149 | self.heads = self.heads.train(mode)
150 | return self
151 |
152 | def eval(self):
153 | return self.train(False)
154 |
155 | def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
156 | # Apply augmentation (x in [-1, 1]).
157 | if self.diffaug:
158 | x = DiffAugment(x, policy='translation,cutout')
159 |
160 | # Transform to [0, 1].
161 | x = x.add(1).div(2)
162 |
163 | # Take crops with probability p_crop if the image is larger.
164 | if x.size(-1) > self.dino.img_resolution and np.random.random() < self.p_crop:
165 | x = RandomCrop(self.dino.img_resolution)(x)
166 |
167 | # Forward pass through DINO ViT.
168 | features = self.dino(x)
169 |
170 | # Apply discriminator heads.
171 | logits = []
172 | for k, head in self.heads.items():
173 | features[k].requires_grad_(True)
174 | logits.append(head(features[k], c).view(x.size(0), -1))
175 | #logits = torch.cat(logits, dim=1)
176 |
177 | return logits, features
178 |
179 |
--------------------------------------------------------------------------------
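
A hedged sketch of a single DiscHead in isolation, since ProjectedDiscriminator additionally needs the pretrained DINO weights under preset/models/dino/. Importing this module pulls in models.shared, models.vit_utils, models.DiffAugment, timm, and torchvision, so the repository root must be on PYTHONPATH with those dependencies installed; the feature sizes below are illustrative.

import torch
from ADD.models.discriminator import DiscHead

# Per-hook discriminator head: operates on ViT token features laid out as
# (batch, embed_dim, num_tokens) for the 1-D spectral convolutions.
head = DiscHead(channels=384, c_dim=0)

feats = torch.randn(4, 384, 197)
logits = head(feats, c=None)   # c is unused when c_dim == 0
assert logits.shape == (4, 1, 197)  # one logit per token
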
/ADD/th_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/ADD/th_utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/th_utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/th_utils/__pycache__/custom_ops.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/th_utils/__pycache__/custom_ops.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/th_utils/__pycache__/misc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/th_utils/__pycache__/misc.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/th_utils/custom_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import glob
10 | import hashlib
11 | import importlib
12 | import os
13 | import re
14 | import shutil
15 | import uuid
16 |
17 | import torch
18 | import torch.utils.cpp_extension
19 | from torch.utils.file_baton import FileBaton
20 |
21 | #----------------------------------------------------------------------------
22 | # Global options.
23 |
24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
25 |
26 | #----------------------------------------------------------------------------
27 | # Internal helper funcs.
28 |
29 | def _find_compiler_bindir():
30 | patterns = [
31 | 'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
32 | 'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
33 | 'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
34 | 'C:/Program Files*/Microsoft Visual Studio */vc/bin',
35 | ]
36 | for pattern in patterns:
37 | matches = sorted(glob.glob(pattern))
38 | if len(matches):
39 | return matches[-1]
40 | return None
41 |
42 | #----------------------------------------------------------------------------
43 |
44 | def _get_mangled_gpu_name():
45 | name = torch.cuda.get_device_name().lower()
46 | out = []
47 | for c in name:
48 | if re.match('[a-z0-9_-]+', c):
49 | out.append(c)
50 | else:
51 | out.append('-')
52 | return ''.join(out)
53 |
54 | #----------------------------------------------------------------------------
55 | # Main entry point for compiling and loading C++/CUDA plugins.
56 |
57 | _cached_plugins = dict()
58 |
59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
60 | assert verbosity in ['none', 'brief', 'full']
61 | if headers is None:
62 | headers = []
63 | if source_dir is not None:
64 | sources = [os.path.join(source_dir, fname) for fname in sources]
65 | headers = [os.path.join(source_dir, fname) for fname in headers]
66 |
67 | # Already cached?
68 | if module_name in _cached_plugins:
69 | return _cached_plugins[module_name]
70 |
71 | # Print status.
72 | if verbosity == 'full':
73 | print(f'Setting up PyTorch plugin "{module_name}"...')
74 | elif verbosity == 'brief':
75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
76 | verbose_build = (verbosity == 'full')
77 |
78 | # Compile and load.
79 | try: # pylint: disable=too-many-nested-blocks
80 | # Make sure we can find the necessary compiler binaries.
81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
82 | compiler_bindir = _find_compiler_bindir()
83 | if compiler_bindir is None:
84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
85 | os.environ['PATH'] += ';' + compiler_bindir
86 |
87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
88 | # break the build or unnecessarily restrict what's available to nvcc.
89 | # Unset it to let nvcc decide based on what's available on the
90 | # machine.
91 | os.environ['TORCH_CUDA_ARCH_LIST'] = ''
92 |
93 | # Incremental build md5sum trickery. Copies all the input source files
94 | # into a cached build directory under a combined md5 digest of the input
95 | # source files. Copying is done only if the combined digest has changed.
96 | # This keeps input file timestamps and filenames the same as in previous
97 | # extension builds, allowing for fast incremental rebuilds.
98 | #
99 | # This optimization is done only in case all the source files reside in
100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
101 | # environment variable is set (we take this as a signal that the user
102 | # actually cares about this.)
103 | #
104 | # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
105 | # around the *.cu dependency bug in ninja config.
106 | #
107 | all_source_files = sorted(sources + headers)
108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
110 |
111 | # Compute combined hash digest for all source files.
112 | hash_md5 = hashlib.md5()
113 | for src in all_source_files:
114 | with open(src, 'rb') as f:
115 | hash_md5.update(f.read())
116 |
117 | # Select cached build directory name.
118 | source_digest = hash_md5.hexdigest()
119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
121 |
122 | if not os.path.isdir(cached_build_dir):
123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
124 | os.makedirs(tmpdir)
125 | for src in all_source_files:
126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
127 | try:
128 | os.replace(tmpdir, cached_build_dir) # atomic
129 | except OSError:
130 | # source directory already exists, delete tmpdir and its contents.
131 | shutil.rmtree(tmpdir)
132 | if not os.path.isdir(cached_build_dir): raise
133 |
134 | # Compile.
135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
137 | verbose=verbose_build, sources=cached_sources, **build_kwargs)
138 | else:
139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
140 |
141 | # Load.
142 | module = importlib.import_module(module_name)
143 |
144 | except:
145 | if verbosity == 'brief':
146 | print('Failed!')
147 | raise
148 |
149 | # Print status and add to cache dict.
150 | if verbosity == 'full':
151 | print(f'Done setting up PyTorch plugin "{module_name}".')
152 | elif verbosity == 'brief':
153 | print('Done.')
154 | _cached_plugins[module_name] = module
155 | return module
156 |
157 | #----------------------------------------------------------------------------
158 |
--------------------------------------------------------------------------------
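
A hedged sketch of how get_plugin is typically invoked by the wrappers under ADD/th_utils/ops (for example the bias_act extension). It assumes a working CUDA toolchain, that paths are resolved from the repository root, and that the repository root is on PYTHONPATH; the exact call used by this repo's op wrappers may differ.

import os
from ADD.th_utils import custom_ops

# Compile-and-load sketch for a CUDA extension; requires nvcc and a GPU.
# Source files live next to the Python wrapper in ADD/th_utils/ops/.
source_dir = os.path.join('ADD', 'th_utils', 'ops')
plugin = custom_ops.get_plugin(
    module_name='bias_act_plugin',
    sources=['bias_act.cpp', 'bias_act.cu'],
    headers=['bias_act.h'],
    source_dir=source_dir,
    extra_cuda_cflags=['--use_fast_math'],
)
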
/ADD/th_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | import re
10 | import contextlib
11 | import numpy as np
12 | import torch
13 | import warnings
14 | import ADD.dnnlib as dnnlib
15 |
16 | #----------------------------------------------------------------------------
17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18 | # same constant is used multiple times.
19 |
20 | _constant_cache = dict()
21 |
22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23 | value = np.asarray(value)
24 | if shape is not None:
25 | shape = tuple(shape)
26 | if dtype is None:
27 | dtype = torch.get_default_dtype()
28 | if device is None:
29 | device = torch.device('cpu')
30 | if memory_format is None:
31 | memory_format = torch.contiguous_format
32 |
33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34 | tensor = _constant_cache.get(key, None)
35 | if tensor is None:
36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37 | if shape is not None:
38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39 | tensor = tensor.contiguous(memory_format=memory_format)
40 | _constant_cache[key] = tensor
41 | return tensor
42 |
43 | #----------------------------------------------------------------------------
44 | # Replace NaN/Inf with specified numerical values.
45 |
46 | try:
47 | nan_to_num = torch.nan_to_num # 1.8.0a0
48 | except AttributeError:
49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50 | assert isinstance(input, torch.Tensor)
51 | if posinf is None:
52 | posinf = torch.finfo(input.dtype).max
53 | if neginf is None:
54 | neginf = torch.finfo(input.dtype).min
55 | assert nan == 0
56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57 |
58 | #----------------------------------------------------------------------------
59 | # Symbolic assert.
60 |
61 | try:
62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
63 | except AttributeError:
64 | symbolic_assert = torch.Assert # 1.7.0
65 |
66 | #----------------------------------------------------------------------------
67 | # Context manager to temporarily suppress known warnings in torch.jit.trace().
68 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
69 |
70 | @contextlib.contextmanager
71 | def suppress_tracer_warnings():
72 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
73 | warnings.filters.insert(0, flt)
74 | yield
75 | warnings.filters.remove(flt)
76 |
77 | #----------------------------------------------------------------------------
78 | # Assert that the shape of a tensor matches the given list of integers.
79 | # None indicates that the size of a dimension is allowed to vary.
80 | # Performs symbolic assertion when used in torch.jit.trace().
81 |
82 | def assert_shape(tensor, ref_shape):
83 | if tensor.ndim != len(ref_shape):
84 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
85 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
86 | if ref_size is None:
87 | pass
88 | elif isinstance(ref_size, torch.Tensor):
89 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
90 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
91 | elif isinstance(size, torch.Tensor):
92 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
93 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
94 | elif size != ref_size:
95 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
96 |
97 | #----------------------------------------------------------------------------
98 | # Function decorator that calls torch.autograd.profiler.record_function().
99 |
100 | def profiled_function(fn):
101 | def decorator(*args, **kwargs):
102 | with torch.autograd.profiler.record_function(fn.__name__):
103 | return fn(*args, **kwargs)
104 | decorator.__name__ = fn.__name__
105 | return decorator
106 |
107 | #----------------------------------------------------------------------------
108 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
109 | # indefinitely, shuffling items as it goes.
110 |
111 | class InfiniteSampler(torch.utils.data.Sampler):
112 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
113 | assert len(dataset) > 0
114 | assert num_replicas > 0
115 | assert 0 <= rank < num_replicas
116 | assert 0 <= window_size <= 1
117 | super().__init__(dataset)
118 | self.dataset = dataset
119 | self.rank = rank
120 | self.num_replicas = num_replicas
121 | self.shuffle = shuffle
122 | self.seed = seed
123 | self.window_size = window_size
124 |
125 | def __iter__(self):
126 | order = np.arange(len(self.dataset))
127 | rnd = None
128 | window = 0
129 | if self.shuffle:
130 | rnd = np.random.RandomState(self.seed)
131 | rnd.shuffle(order)
132 | window = int(np.rint(order.size * self.window_size))
133 |
134 | idx = 0
135 | while True:
136 | i = idx % order.size
137 | if idx % self.num_replicas == self.rank:
138 | yield order[i]
139 | if window >= 2:
140 | j = (i - rnd.randint(window)) % order.size
141 | order[i], order[j] = order[j], order[i]
142 | idx += 1
143 |
144 | #----------------------------------------------------------------------------
145 | # Utilities for operating with torch.nn.Module parameters and buffers.
146 | def spectral_to_cpu(model: torch.nn.Module):
147 | def wrapped_in_spectral(m): return hasattr(m, 'weight_v')
148 | children = get_children(model)
149 | for child in children:
150 | if wrapped_in_spectral(child):
151 | child.weight = child.weight.cpu()
152 | return model
153 |
154 | def get_children(model: torch.nn.Module):
155 | children = list(model.children())
156 | flatt_children = []
157 | if children == []:
158 | return model
159 | else:
160 | for child in children:
161 | try:
162 | flatt_children.extend(get_children(child))
163 | except TypeError:
164 | flatt_children.append(get_children(child))
165 | return flatt_children
166 |
167 | def params_and_buffers(module):
168 | assert isinstance(module, torch.nn.Module)
169 | return list(module.parameters()) + list(module.buffers())
170 |
171 | def named_params_and_buffers(module):
172 | assert isinstance(module, torch.nn.Module)
173 | return list(module.named_parameters()) + list(module.named_buffers())
174 |
175 | def copy_params_and_buffers(src_module, dst_module, require_all=False):
176 | assert isinstance(src_module, torch.nn.Module)
177 | assert isinstance(dst_module, torch.nn.Module)
178 | src_tensors = dict(named_params_and_buffers(src_module))
179 | for name, tensor in named_params_and_buffers(dst_module):
180 | assert (name in src_tensors) or (not require_all)
181 | if name in src_tensors:
182 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
183 |
184 | #----------------------------------------------------------------------------
185 | # Context manager for easily enabling/disabling DistributedDataParallel
186 | # synchronization.
187 |
188 | @contextlib.contextmanager
189 | def ddp_sync(module, sync):
190 | assert isinstance(module, torch.nn.Module)
191 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
192 | yield
193 | else:
194 | with module.no_sync():
195 | yield
196 |
197 | #----------------------------------------------------------------------------
198 | # Check DistributedDataParallel consistency across processes.
199 |
200 | def check_ddp_consistency(module, ignore_regex=None):
201 | assert isinstance(module, torch.nn.Module)
202 | for name, tensor in named_params_and_buffers(module):
203 | fullname = type(module).__name__ + '.' + name
204 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
205 | continue
206 | tensor = tensor.detach()
207 | if tensor.is_floating_point():
208 | tensor = nan_to_num(tensor)
209 | other = tensor.clone()
210 | torch.distributed.broadcast(tensor=other, src=0)
211 | assert (tensor == other).all(), fullname
212 |
213 | #----------------------------------------------------------------------------
214 | # Print summary table of module hierarchy.
215 |
216 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
217 | assert isinstance(module, torch.nn.Module)
218 | assert not isinstance(module, torch.jit.ScriptModule)
219 | assert isinstance(inputs, (tuple, list))
220 |
221 | # Register hooks.
222 | entries = []
223 | nesting = [0]
224 | def pre_hook(_mod, _inputs):
225 | nesting[0] += 1
226 | def post_hook(mod, _inputs, outputs):
227 | nesting[0] -= 1
228 | if nesting[0] <= max_nesting:
229 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
230 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
231 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
232 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
233 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
234 |
235 | # Run module.
236 | outputs = module(*inputs)
237 | for hook in hooks:
238 | hook.remove()
239 |
240 | # Identify unique outputs, parameters, and buffers.
241 | tensors_seen = set()
242 | for e in entries:
243 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
244 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
245 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
246 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
247 |
248 | # Filter out redundant entries.
249 | if skip_redundant:
250 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
251 |
252 | # Construct table.
253 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
254 | rows += [['---'] * len(rows[0])]
255 | param_total = 0
256 | buffer_total = 0
257 | submodule_names = {mod: name for name, mod in module.named_modules()}
258 | for e in entries:
259 | name = '' if e.mod is module else submodule_names[e.mod]
260 | param_size = sum(t.numel() for t in e.unique_params)
261 | buffer_size = sum(t.numel() for t in e.unique_buffers)
262 | output_shapes = [str(list(t.shape)) for t in e.outputs]
263 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
264 | rows += [[
265 | name + (':0' if len(e.outputs) >= 2 else ''),
266 | str(param_size) if param_size else '-',
267 | str(buffer_size) if buffer_size else '-',
268 | (output_shapes + ['-'])[0],
269 | (output_dtypes + ['-'])[0],
270 | ]]
271 | for idx in range(1, len(e.outputs)):
272 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
273 | param_total += param_size
274 | buffer_total += buffer_size
275 | rows += [['---'] * len(rows[0])]
276 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
277 |
278 | # Print table.
279 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
280 | print()
281 | for row in rows:
282 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
283 | print()
284 | return outputs
285 |
--------------------------------------------------------------------------------
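Usage sketch for the sampler defined above (not taken from the repository; the import path ADD.th_utils.misc is assumed from the tree layout): InfiniteSampler feeds a DataLoader with an endless, locally shuffled stream of indices, so iteration never raises StopIteration.

# Minimal sketch, assuming the module is importable as ADD.th_utils.misc.
import torch
from ADD.th_utils.misc import InfiniteSampler

dataset = torch.utils.data.TensorDataset(torch.randn(16, 3, 64, 64))
sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=4))
batch = next(loader)  # the sampler loops over the dataset indefinitely, reshuffling within a local window
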
/ADD/th_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/ADD/th_utils/ops/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/th_utils/ops/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/th_utils/ops/__pycache__/bias_act.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/th_utils/ops/__pycache__/bias_act.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/th_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 |
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 |
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 |
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 |
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 |
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78 | {
79 | kernel = choose_bias_act_kernel(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 |
83 | // Launch CUDA kernel.
84 | p.loopX = 4;
85 | int blockSize = 4 * 32;
86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 | void* args[] = {&p};
88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 | return y;
90 | }
91 |
92 | //------------------------------------------------------------------------
93 |
94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95 | {
96 | m.def("bias_act", &bias_act);
97 | }
98 |
99 | //------------------------------------------------------------------------
100 |
--------------------------------------------------------------------------------
/ADD/th_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double>    { typedef double scalar_t; };
17 | template <> struct InternalType<float>     { typedef float scalar_t; };
18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
/ADD/th_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/ADD/th_utils/ops/bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom PyTorch ops for efficient bias and activation."""
10 |
11 | import os
12 | import numpy as np
13 | import torch
14 | import ADD.dnnlib as dnnlib
15 |
16 | from .. import custom_ops
17 | from .. import misc
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | activation_funcs = {
22 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
23 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
24 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
25 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
26 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
27 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
28 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
29 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
30 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
31 | }
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | _plugin = None
36 | _null_tensor = torch.empty([0])
37 |
38 | def _init():
39 | global _plugin
40 | if _plugin is None:
41 | _plugin = custom_ops.get_plugin(
42 | module_name='bias_act_plugin',
43 | sources=['bias_act.cpp', 'bias_act.cu'],
44 | headers=['bias_act.h'],
45 | source_dir=os.path.dirname(__file__),
46 | extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
47 | )
48 | return True
49 |
50 | #----------------------------------------------------------------------------
51 |
52 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
53 | r"""Fused bias and activation function.
54 |
55 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
56 | and scales the result by `gain`. Each of the steps is optional. In most cases,
57 | the fused op is considerably more efficient than performing the same calculation
58 | using standard PyTorch ops. It supports first and second order gradients,
59 | but not third order gradients.
60 |
61 | Args:
62 | x: Input activation tensor. Can be of any shape.
63 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
64 | as `x`. The shape must be known, and it must match the dimension of `x`
65 | corresponding to `dim`.
66 | dim: The dimension in `x` corresponding to the elements of `b`.
67 | The value of `dim` is ignored if `b` is not specified.
68 | act: Name of the activation function to evaluate, or `"linear"` to disable.
69 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
70 | See `activation_funcs` for a full list. `None` is not allowed.
71 | alpha: Shape parameter for the activation function, or `None` to use the default.
72 | gain: Scaling factor for the output tensor, or `None` to use default.
73 | See `activation_funcs` for the default scaling of each activation function.
74 | If unsure, consider specifying 1.
75 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
76 | the clamping (default).
77 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
78 |
79 | Returns:
80 | Tensor of the same shape and datatype as `x`.
81 | """
82 | assert isinstance(x, torch.Tensor)
83 | assert impl in ['ref', 'cuda']
84 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
85 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
86 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
87 |
88 | #----------------------------------------------------------------------------
89 |
90 | @misc.profiled_function
91 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
92 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
93 | """
94 | assert isinstance(x, torch.Tensor)
95 | assert clamp is None or clamp >= 0
96 | spec = activation_funcs[act]
97 | alpha = float(alpha if alpha is not None else spec.def_alpha)
98 | gain = float(gain if gain is not None else spec.def_gain)
99 | clamp = float(clamp if clamp is not None else -1)
100 |
101 | # Add bias.
102 | if b is not None:
103 | assert isinstance(b, torch.Tensor) and b.ndim == 1
104 | assert 0 <= dim < x.ndim
105 | assert b.shape[0] == x.shape[dim]
106 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
107 |
108 | # Evaluate activation function.
109 | alpha = float(alpha)
110 | x = spec.func(x, alpha=alpha)
111 |
112 | # Scale by gain.
113 | gain = float(gain)
114 | if gain != 1:
115 | x = x * gain
116 |
117 | # Clamp.
118 | if clamp >= 0:
119 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
120 | return x
121 |
122 | #----------------------------------------------------------------------------
123 |
124 | _bias_act_cuda_cache = dict()
125 |
126 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
127 | """Fast CUDA implementation of `bias_act()` using custom ops.
128 | """
129 | # Parse arguments.
130 | assert clamp is None or clamp >= 0
131 | spec = activation_funcs[act]
132 | alpha = float(alpha if alpha is not None else spec.def_alpha)
133 | gain = float(gain if gain is not None else spec.def_gain)
134 | clamp = float(clamp if clamp is not None else -1)
135 |
136 | # Lookup from cache.
137 | key = (dim, act, alpha, gain, clamp)
138 | if key in _bias_act_cuda_cache:
139 | return _bias_act_cuda_cache[key]
140 |
141 | # Forward op.
142 | class BiasActCuda(torch.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, x, b): # pylint: disable=arguments-differ
145 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
146 | x = x.contiguous(memory_format=ctx.memory_format)
147 | b = b.contiguous() if b is not None else _null_tensor
148 | y = x
149 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
150 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
151 | ctx.save_for_backward(
152 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
153 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
154 | y if 'y' in spec.ref else _null_tensor)
155 | return y
156 |
157 | @staticmethod
158 | def backward(ctx, dy): # pylint: disable=arguments-differ
159 | dy = dy.contiguous(memory_format=ctx.memory_format)
160 | x, b, y = ctx.saved_tensors
161 | dx = None
162 | db = None
163 |
164 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
165 | dx = dy
166 | if act != 'linear' or gain != 1 or clamp >= 0:
167 | dx = BiasActCudaGrad.apply(dy, x, b, y)
168 |
169 | if ctx.needs_input_grad[1]:
170 | db = dx.sum([i for i in range(dx.ndim) if i != dim])
171 |
172 | return dx, db
173 |
174 | # Backward op.
175 | class BiasActCudaGrad(torch.autograd.Function):
176 | @staticmethod
177 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
178 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
179 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
180 | ctx.save_for_backward(
181 | dy if spec.has_2nd_grad else _null_tensor,
182 | x, b, y)
183 | return dx
184 |
185 | @staticmethod
186 | def backward(ctx, d_dx): # pylint: disable=arguments-differ
187 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
188 | dy, x, b, y = ctx.saved_tensors
189 | d_dy = None
190 | d_x = None
191 | d_b = None
192 | d_y = None
193 |
194 | if ctx.needs_input_grad[0]:
195 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
196 |
197 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
198 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
199 |
200 | if spec.has_2nd_grad and ctx.needs_input_grad[2]:
201 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
202 |
203 | return d_dy, d_x, d_b, d_y
204 |
205 | # Add to cache.
206 | _bias_act_cuda_cache[key] = BiasActCuda
207 | return BiasActCuda
208 |
209 | #----------------------------------------------------------------------------
210 |
--------------------------------------------------------------------------------
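A minimal sketch of calling bias_act() through the reference path on CPU (the import path ADD.th_utils.ops.bias_act is assumed from the layout above): bias is added along dim, the activation is applied, and the default gain for that activation is used.

# Minimal sketch, assuming the package is importable as ADD.th_utils.ops.
import torch
from ADD.th_utils.ops.bias_act import bias_act

x = torch.randn(8, 16, 32, 32)                 # activations, channels along dim=1
b = torch.zeros(16, requires_grad=True)        # per-channel bias, same length as x.shape[1]
y = bias_act(x, b, dim=1, act='lrelu', impl='ref')  # bias + leaky ReLU + default gain sqrt(2)
y.sum().backward()                             # first-order gradients flow back to b
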
/ADD/th_utils/ops/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.conv2d` that supports
10 | arbitrarily high order gradients with zero performance penalty."""
11 |
12 | import contextlib
13 | import torch
14 | from pkg_resources import parse_version
15 |
16 | # pylint: disable=redefined-builtin
17 | # pylint: disable=arguments-differ
18 | # pylint: disable=protected-access
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | enabled = False # Enable the custom op by setting this to true.
23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
25 |
26 | @contextlib.contextmanager
27 | def no_weight_gradients(disable=True):
28 | global weight_gradients_disabled
29 | old = weight_gradients_disabled
30 | if disable:
31 | weight_gradients_disabled = True
32 | yield
33 | weight_gradients_disabled = old
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
38 | if _should_use_custom_op(input):
39 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
40 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
41 |
42 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
43 | if _should_use_custom_op(input):
44 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
45 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _should_use_custom_op(input):
50 | assert isinstance(input, torch.Tensor)
51 | if (not enabled) or (not torch.backends.cudnn.enabled):
52 | return False
53 | if _use_pytorch_1_11_api:
54 | # The work-around code doesn't work on PyTorch 1.11.0 onwards
55 | return False
56 | if input.device.type != 'cuda':
57 | return False
58 | return True
59 |
60 | def _tuple_of_ints(xs, ndim):
61 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
62 | assert len(xs) == ndim
63 | assert all(isinstance(x, int) for x in xs)
64 | return xs
65 |
66 | #----------------------------------------------------------------------------
67 |
68 | _conv2d_gradfix_cache = dict()
69 | _null_tensor = torch.empty([0])
70 |
71 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
72 | # Parse arguments.
73 | ndim = 2
74 | weight_shape = tuple(weight_shape)
75 | stride = _tuple_of_ints(stride, ndim)
76 | padding = _tuple_of_ints(padding, ndim)
77 | output_padding = _tuple_of_ints(output_padding, ndim)
78 | dilation = _tuple_of_ints(dilation, ndim)
79 |
80 | # Lookup from cache.
81 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
82 | if key in _conv2d_gradfix_cache:
83 | return _conv2d_gradfix_cache[key]
84 |
85 | # Validate arguments.
86 | assert groups >= 1
87 | assert len(weight_shape) == ndim + 2
88 | assert all(stride[i] >= 1 for i in range(ndim))
89 | assert all(padding[i] >= 0 for i in range(ndim))
90 | assert all(dilation[i] >= 0 for i in range(ndim))
91 | if not transpose:
92 | assert all(output_padding[i] == 0 for i in range(ndim))
93 | else: # transpose
94 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
95 |
96 | # Helpers.
97 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
98 | def calc_output_padding(input_shape, output_shape):
99 | if transpose:
100 | return [0, 0]
101 | return [
102 | input_shape[i + 2]
103 | - (output_shape[i + 2] - 1) * stride[i]
104 | - (1 - 2 * padding[i])
105 | - dilation[i] * (weight_shape[i + 2] - 1)
106 | for i in range(ndim)
107 | ]
108 |
109 | # Forward & backward.
110 | class Conv2d(torch.autograd.Function):
111 | @staticmethod
112 | def forward(ctx, input, weight, bias):
113 | assert weight.shape == weight_shape
114 | ctx.save_for_backward(
115 | input if weight.requires_grad else _null_tensor,
116 | weight if input.requires_grad else _null_tensor,
117 | )
118 | ctx.input_shape = input.shape
119 |
120 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
121 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
122 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
123 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
124 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
125 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
126 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
127 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
128 |
129 | # General case => cuDNN.
130 | if transpose:
131 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
132 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
133 |
134 | @staticmethod
135 | def backward(ctx, grad_output):
136 | input, weight = ctx.saved_tensors
137 | input_shape = ctx.input_shape
138 | grad_input = None
139 | grad_weight = None
140 | grad_bias = None
141 |
142 | if ctx.needs_input_grad[0]:
143 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
144 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
145 | grad_input = op.apply(grad_output, weight, None)
146 | assert grad_input.shape == input_shape
147 |
148 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
149 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
150 | assert grad_weight.shape == weight_shape
151 |
152 | if ctx.needs_input_grad[2]:
153 | grad_bias = grad_output.sum([0, 2, 3])
154 |
155 | return grad_input, grad_weight, grad_bias
156 |
157 | # Gradient with respect to the weights.
158 | class Conv2dGradWeight(torch.autograd.Function):
159 | @staticmethod
160 | def forward(ctx, grad_output, input):
161 | ctx.save_for_backward(
162 | grad_output if input.requires_grad else _null_tensor,
163 | input if grad_output.requires_grad else _null_tensor,
164 | )
165 | ctx.grad_output_shape = grad_output.shape
166 | ctx.input_shape = input.shape
167 |
168 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
169 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
170 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
171 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
172 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
173 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
174 |
175 | # General case => cuDNN.
176 | name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
177 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
178 | return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
179 |
180 | @staticmethod
181 | def backward(ctx, grad2_grad_weight):
182 | grad_output, input = ctx.saved_tensors
183 | grad_output_shape = ctx.grad_output_shape
184 | input_shape = ctx.input_shape
185 | grad2_grad_output = None
186 | grad2_input = None
187 |
188 | if ctx.needs_input_grad[0]:
189 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
190 | assert grad2_grad_output.shape == grad_output_shape
191 |
192 | if ctx.needs_input_grad[1]:
193 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
194 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
195 | grad2_input = op.apply(grad_output, grad2_grad_weight, None)
196 | assert grad2_input.shape == input_shape
197 |
198 | return grad2_grad_output, grad2_input
199 |
200 | _conv2d_gradfix_cache[key] = Conv2d
201 | return Conv2d
202 |
203 | #----------------------------------------------------------------------------
204 |
--------------------------------------------------------------------------------
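A minimal sketch of the conv2d_gradfix API (import path assumed from the layout above): unless `enabled` is set and the work-around applies (CUDA input, PyTorch older than 1.11), both calls simply fall back to torch.nn.functional.conv2d.

# Minimal sketch, assuming the package is importable as ADD.th_utils.ops.
import torch
from ADD.th_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True
x = torch.randn(1, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)     # CPU tensor here, so this is the plain conv2d fallback

with conv2d_gradfix.no_weight_gradients():
    # Within this context the custom op (when active) skips gradients w.r.t. the weights.
    y2 = conv2d_gradfix.conv2d(x, w, padding=1)
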
/ADD/th_utils/ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """2D convolution with optional up/downsampling."""
10 |
11 | import torch
12 |
13 | from .. import misc
14 | from . import conv2d_gradfix
15 | from . import upfirdn2d
16 | from .upfirdn2d import _parse_padding
17 | from .upfirdn2d import _get_filter_size
18 |
19 | #----------------------------------------------------------------------------
20 |
21 | def _get_weight_shape(w):
22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23 | shape = [int(sz) for sz in w.shape]
24 | misc.assert_shape(w, shape)
25 | return shape
26 |
27 | #----------------------------------------------------------------------------
28 |
29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31 | """
32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
33 |
34 | # Flip weight if requested.
35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36 | if not flip_weight and (kw > 1 or kh > 1):
37 | w = w.flip([2, 3])
38 |
39 | # Execute using conv2d_gradfix.
40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
41 | return op(x, w, stride=stride, padding=padding, groups=groups)
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | @misc.profiled_function
46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
47 | r"""2D convolution with optional up/downsampling.
48 |
49 | Padding is performed only once at the beginning, not between the operations.
50 |
51 | Args:
52 | x: Input tensor of shape
53 | `[batch_size, in_channels, in_height, in_width]`.
54 | w: Weight tensor of shape
55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
57 | calling upfirdn2d.setup_filter(). None = identity (default).
58 | up: Integer upsampling factor (default: 1).
59 | down: Integer downsampling factor (default: 1).
60 | padding: Padding with respect to the upsampled image. Can be a single number
61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
62 | (default: 0).
63 | groups: Split input channels into N groups (default: 1).
64 | flip_weight: False = convolution, True = correlation (default: True).
65 | flip_filter: False = convolution, True = correlation (default: False).
66 |
67 | Returns:
68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
69 | """
70 | # Validate arguments.
71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
74 | assert isinstance(up, int) and (up >= 1)
75 | assert isinstance(down, int) and (down >= 1)
76 | assert isinstance(groups, int) and (groups >= 1)
77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
78 | fw, fh = _get_filter_size(f)
79 | px0, px1, py0, py1 = _parse_padding(padding)
80 |
81 | # Adjust padding to account for up/downsampling.
82 | if up > 1:
83 | px0 += (fw + up - 1) // 2
84 | px1 += (fw - up) // 2
85 | py0 += (fh + up - 1) // 2
86 | py1 += (fh - up) // 2
87 | if down > 1:
88 | px0 += (fw - down + 1) // 2
89 | px1 += (fw - down) // 2
90 | py0 += (fh - down + 1) // 2
91 | py1 += (fh - down) // 2
92 |
93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
94 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
97 | return x
98 |
99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
100 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
103 | return x
104 |
105 | # Fast path: downsampling only => use strided convolution.
106 | if down > 1 and up == 1:
107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
109 | return x
110 |
111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
112 | if up > 1:
113 | if groups == 1:
114 | w = w.transpose(0, 1)
115 | else:
116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
117 | w = w.transpose(1, 2)
118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
119 | px0 -= kw - 1
120 | px1 -= kw - up
121 | py0 -= kh - 1
122 | py1 -= kh - up
123 | pxt = max(min(-px0, -px1), 0)
124 | pyt = max(min(-py0, -py1), 0)
125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
127 | if down > 1:
128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
129 | return x
130 |
131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
132 | if up == 1 and down == 1:
133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
135 |
136 | # Fallback: Generic reference implementation.
137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
139 | if down > 1:
140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
141 | return x
142 |
143 | #----------------------------------------------------------------------------
144 |
--------------------------------------------------------------------------------
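A minimal sketch of conv2d_resample() with 2x upsampling, using a low-pass filter prepared by upfirdn2d.setup_filter() as the docstring requires (import paths assumed from the layout above); on CPU the filtering and convolution fall back to the reference implementations.

# Minimal sketch, assuming the package is importable as ADD.th_utils.ops.
import torch
from ADD.th_utils.ops import upfirdn2d, conv2d_resample

x = torch.randn(1, 8, 16, 16)                   # [batch, in_channels, H, W]
w = torch.randn(4, 8, 3, 3)                     # [out_channels, in_channels, kH, kW]
f = upfirdn2d.setup_filter([1, 3, 3, 1])        # separable FIR low-pass filter
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)                                  # torch.Size([1, 4, 32, 32]) with these settings
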
/ADD/th_utils/ops/filtered_lrelu.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct filtered_lrelu_kernel_params
15 | {
16 | // These parameters decide which kernel to use.
17 | int up; // upsampling ratio (1, 2, 4)
18 | int down; // downsampling ratio (1, 2, 4)
19 | int2 fuShape; // [size, 1] | [size, size]
20 | int2 fdShape; // [size, 1] | [size, size]
21 |
22 | int _dummy; // Alignment.
23 |
24 | // Rest of the parameters.
25 | const void* x; // Input tensor.
26 | void* y; // Output tensor.
27 | const void* b; // Bias tensor.
28 | unsigned char* s; // Sign tensor in/out. NULL if unused.
29 | const float* fu; // Upsampling filter.
30 | const float* fd; // Downsampling filter.
31 |
32 | int2 pad0; // Left/top padding.
33 | float gain; // Additional gain factor.
34 | float slope; // Leaky ReLU slope on negative side.
35 | float clamp; // Clamp after nonlinearity.
36 | int flip; // Filter kernel flip for gradient computation.
37 |
38 | int tilesXdim; // Original number of horizontal output tiles.
39 | int tilesXrep; // Number of horizontal tiles per CTA.
40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41 |
42 | int4 xShape; // [width, height, channel, batch]
43 | int4 yShape; // [width, height, channel, batch]
44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46 | int swLimit; // Active width of sign tensor in bytes.
47 |
48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49 | longlong4 yStride; //
50 | int64_t bStride; //
51 | longlong3 fuStride; //
52 | longlong3 fdStride; //
53 | };
54 |
55 | struct filtered_lrelu_act_kernel_params
56 | {
57 | void* x; // Input/output, modified in-place.
58 | unsigned char* s; // Sign tensor in/out. NULL if unused.
59 |
60 | float gain; // Additional gain factor.
61 | float slope; // Leaky ReLU slope on negative side.
62 | float clamp; // Clamp after nonlinearity.
63 |
64 | int4 xShape; // [width, height, channel, batch]
65 | longlong4 xStride; // Input/output tensor strides, same order as in shape.
66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68 | };
69 |
70 | //------------------------------------------------------------------------
71 | // CUDA kernel specialization.
72 |
73 | struct filtered_lrelu_kernel_spec
74 | {
75 | void* setup; // Function for filter kernel setup.
76 | void* exec; // Function for main operation.
77 | int2 tileOut; // Width/height of launch tile.
78 | int numWarps; // Number of warps per thread block, determines launch block size.
79 | int xrep; // For processing multiple horizontal tiles per thread block.
80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81 | };
82 |
83 | //------------------------------------------------------------------------
84 | // CUDA kernel selection.
85 |
86 | template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87 | template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88 | template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89 |
90 | //------------------------------------------------------------------------
91 |
--------------------------------------------------------------------------------
/ADD/th_utils/ops/filtered_lrelu_ns.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for no signs mode (no gradients required).
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, false>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/ADD/th_utils/ops/filtered_lrelu_rd.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign read mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<false, true>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/ADD/th_utils/ops/filtered_lrelu_wr.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign write mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23 | template void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);
24 | template void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters<true, false>(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/ADD/th_utils/ops/fma.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10 |
11 | import torch
12 |
13 | #----------------------------------------------------------------------------
14 |
15 | def fma(a, b, c): # => a * b + c
16 | return _FusedMultiplyAdd.apply(a, b, c)
17 |
18 | #----------------------------------------------------------------------------
19 |
20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21 | @staticmethod
22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23 | out = torch.addcmul(c, a, b)
24 | ctx.save_for_backward(a, b)
25 | ctx.c_shape = c.shape
26 | return out
27 |
28 | @staticmethod
29 | def backward(ctx, dout): # pylint: disable=arguments-differ
30 | a, b = ctx.saved_tensors
31 | c_shape = ctx.c_shape
32 | da = None
33 | db = None
34 | dc = None
35 |
36 | if ctx.needs_input_grad[0]:
37 | da = _unbroadcast(dout * b, a.shape)
38 |
39 | if ctx.needs_input_grad[1]:
40 | db = _unbroadcast(dout * a, b.shape)
41 |
42 | if ctx.needs_input_grad[2]:
43 | dc = _unbroadcast(dout, c_shape)
44 |
45 | return da, db, dc
46 |
47 | #----------------------------------------------------------------------------
48 |
49 | def _unbroadcast(x, shape):
50 | extra_dims = x.ndim - len(shape)
51 | assert extra_dims >= 0
52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53 | if len(dim):
54 | x = x.sum(dim=dim, keepdim=True)
55 | if extra_dims:
56 | x = x.reshape(-1, *x.shape[extra_dims+1:])
57 | assert x.shape == shape
58 | return x
59 |
60 | #----------------------------------------------------------------------------
61 |
--------------------------------------------------------------------------------
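A minimal sketch of fma() (import path assumed from the layout above): the result equals a * b + c under broadcasting, and the backward pass un-broadcasts each gradient to the original shape of its operand.

# Minimal sketch, assuming the package is importable as ADD.th_utils.ops.
import torch
from ADD.th_utils.ops.fma import fma

a = torch.randn(4, 1, 8, requires_grad=True)
b = torch.randn(1, 6, 8, requires_grad=True)
c = torch.randn(8, requires_grad=True)
out = fma(a, b, c)            # broadcasts to shape [4, 6, 8]
out.sum().backward()
print(a.grad.shape, b.grad.shape, c.grad.shape)  # [4, 1, 8] [1, 6, 8] [8]
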
/ADD/th_utils/ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom replacement for `torch.nn.functional.grid_sample` that
10 | supports arbitrarily high order gradients between the input and output.
11 | Only works on 2D images and assumes
12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13 |
14 | import torch
15 | from pkg_resources import parse_version
16 |
17 | # pylint: disable=redefined-builtin
18 | # pylint: disable=arguments-differ
19 | # pylint: disable=protected-access
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | enabled = False # Enable the custom op by setting this to true.
24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
25 |
26 | #----------------------------------------------------------------------------
27 |
28 | def grid_sample(input, grid):
29 | if _should_use_custom_op():
30 | return _GridSample2dForward.apply(input, grid)
31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def _should_use_custom_op():
36 | return enabled
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | class _GridSample2dForward(torch.autograd.Function):
41 | @staticmethod
42 | def forward(ctx, input, grid):
43 | assert input.ndim == 4
44 | assert grid.ndim == 4
45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
46 | ctx.save_for_backward(input, grid)
47 | return output
48 |
49 | @staticmethod
50 | def backward(ctx, grad_output):
51 | input, grid = ctx.saved_tensors
52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
53 | return grad_input, grad_grid
54 |
55 | #----------------------------------------------------------------------------
56 |
57 | class _GridSample2dBackward(torch.autograd.Function):
58 | @staticmethod
59 | def forward(ctx, grad_output, input, grid):
60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
61 | if _use_pytorch_1_11_api:
62 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
63 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
64 | else:
65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66 | ctx.save_for_backward(grid)
67 | return grad_input, grad_grid
68 |
69 | @staticmethod
70 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
71 | _ = grad2_grad_grid # unused
72 | grid, = ctx.saved_tensors
73 | grad2_grad_output = None
74 | grad2_input = None
75 | grad2_grid = None
76 |
77 | if ctx.needs_input_grad[0]:
78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79 |
80 | assert not ctx.needs_input_grad[2]
81 | return grad2_grad_output, grad2_input, grad2_grid
82 |
83 | #----------------------------------------------------------------------------
84 |
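A minimal sketch of the intended usage (not part of the repository; assumes the repo root is on PYTHONPATH and a PyTorch version compatible with this wrapper): flip the module-level `enabled` flag, sample with the wrapper, and the backward pass itself stays differentiable, which is what R1-style gradient penalties need:

import torch
import torch.nn.functional as F
from ADD.th_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True   # opt in to the custom double-backward path

x = torch.randn(1, 3, 8, 8, requires_grad=True)
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])           # identity transform
grid = F.affine_grid(theta, size=(1, 3, 8, 8), align_corners=False)  # [N, H, W, 2] in [-1, 1]

y = grid_sample_gradfix.grid_sample(x, grid)
g, = torch.autograd.grad(y.square().sum(), x, create_graph=True)     # first-order grad, graph kept
g.square().sum().backward()                                          # second-order grad flows back to x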
--------------------------------------------------------------------------------
/ADD/th_utils/ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "upfirdn2d.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 | TORCH_CHECK(x.numel() > 0, "x has zero size");
25 | TORCH_CHECK(f.numel() > 0, "f has zero size");
26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
32 |
33 | // Create output tensor.
34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
41 |
42 | // Initialize CUDA kernel parameters.
43 | upfirdn2d_kernel_params p;
44 | p.x = x.data_ptr();
45 | p.f = f.data_ptr();
46 | p.y = y.data_ptr();
47 | p.up = make_int2(upx, upy);
48 | p.down = make_int2(downx, downy);
49 | p.pad0 = make_int2(padx0, pady0);
50 | p.flip = (flip) ? 1 : 0;
51 | p.gain = gain;
52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
60 |
61 | // Choose CUDA kernel.
62 | upfirdn2d_kernel_spec spec;
63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
64 | {
65 |         spec = choose_upfirdn2d_kernel<scalar_t>(p);
66 | });
67 |
68 | // Set looping options.
69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
70 | p.loopMinor = spec.loopMinor;
71 | p.loopX = spec.loopX;
72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
74 |
75 | // Compute grid size.
76 | dim3 blockSize, gridSize;
77 | if (spec.tileOutW < 0) // large
78 | {
79 | blockSize = dim3(4, 32, 1);
80 | gridSize = dim3(
81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
83 | p.launchMajor);
84 | }
85 | else // small
86 | {
87 | blockSize = dim3(256, 1, 1);
88 | gridSize = dim3(
89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
91 | p.launchMajor);
92 | }
93 |
94 | // Launch CUDA kernel.
95 | void* args[] = {&p};
96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
97 | return y;
98 | }
99 |
100 | //------------------------------------------------------------------------
101 |
102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
103 | {
104 | m.def("upfirdn2d", &upfirdn2d);
105 | }
106 |
107 | //------------------------------------------------------------------------
108 |
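The outW/outH computation above follows the usual upfirdn sizing rule (upsample, pad, filter, downsample); a small Python transcription of the same formula, for illustration only:

def upfirdn2d_output_size(in_h, in_w, up=(1, 1), down=(1, 1),
                          pad=(0, 0, 0, 0), filter_hw=(1, 1)):
    # Mirrors the outW/outH expressions in upfirdn2d.cpp.
    upx, upy = up
    downx, downy = down
    padx0, padx1, pady0, pady1 = pad
    fh, fw = filter_hw
    out_w = (in_w * upx + padx0 + padx1 - fw + downx) // downx
    out_h = (in_h * upy + pady0 + pady1 - fh + downy) // downy
    return out_h, out_w

# 2x upsampling of a 16x16 input with a 4-tap filter and 1-pixel padding on each side:
print(upfirdn2d_output_size(16, 16, up=(2, 2), pad=(1, 1, 1, 1), filter_hw=(4, 4)))  # (31, 31)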
--------------------------------------------------------------------------------
/ADD/th_utils/ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <cuda_runtime.h>
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct upfirdn2d_kernel_params
15 | {
16 | const void* x;
17 | const float* f;
18 | void* y;
19 |
20 | int2 up;
21 | int2 down;
22 | int2 pad0;
23 | int flip;
24 | float gain;
25 |
26 | int4 inSize; // [width, height, channel, batch]
27 | int4 inStride;
28 | int2 filterSize; // [width, height]
29 | int2 filterStride;
30 | int4 outSize; // [width, height, channel, batch]
31 | int4 outStride;
32 | int sizeMinor;
33 | int sizeMajor;
34 |
35 | int loopMinor;
36 | int loopMajor;
37 | int loopX;
38 | int launchMinor;
39 | int launchMajor;
40 | };
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
45 | struct upfirdn2d_kernel_spec
46 | {
47 | void* kernel;
48 | int tileOutW;
49 | int tileOutH;
50 | int loopMinor;
51 | int loopX;
52 | };
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/ADD/utils/__pycache__/util_net.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/ADD/utils/__pycache__/util_net.cpython-310.pyc
--------------------------------------------------------------------------------
/ADD/utils/util_net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Power by Zongsheng Yue 2021-11-24 20:29:36
4 |
5 | import math
6 | import torch
7 | from pathlib import Path
8 | from collections import OrderedDict
9 | import torch.nn.functional as F
10 | from copy import deepcopy
11 |
12 | def calculate_parameters(net):
13 | out = 0
14 | for param in net.parameters():
15 | out += param.numel()
16 | return out
17 |
18 | def pad_input(x, mod):
19 | h, w = x.shape[-2:]
20 | bottom = int(math.ceil(h/mod)*mod -h)
21 | right = int(math.ceil(w/mod)*mod - w)
22 | x_pad = F.pad(x, pad=(0, right, 0, bottom), mode='reflect')
23 | return x_pad
24 |
25 | def forward_chop(net, x, net_kwargs=None, scale=1, shave=10, min_size=160000):
26 | n_GPUs = 1
27 | b, c, h, w = x.size()
28 | h_half, w_half = h // 2, w // 2
29 | h_size, w_size = h_half + shave, w_half + shave
30 | lr_list = [
31 | x[:, :, 0:h_size, 0:w_size],
32 | x[:, :, 0:h_size, (w - w_size):w],
33 | x[:, :, (h - h_size):h, 0:w_size],
34 | x[:, :, (h - h_size):h, (w - w_size):w]]
35 |
36 | if w_size * h_size < min_size:
37 | sr_list = []
38 | for i in range(0, 4, n_GPUs):
39 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
40 | if net_kwargs is None:
41 | sr_batch = net(lr_batch)
42 | else:
43 | sr_batch = net(lr_batch, **net_kwargs)
44 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
45 | else:
46 |         sr_list = [
47 |             forward_chop(net, patch, net_kwargs=net_kwargs, scale=scale, shave=shave, min_size=min_size) \
48 |             for patch in lr_list
49 |         ]
50 |
51 | h, w = scale * h, scale * w
52 | h_half, w_half = scale * h_half, scale * w_half
53 | h_size, w_size = scale * h_size, scale * w_size
54 | shave *= scale
55 |
56 | output = x.new(b, c, h, w)
57 | output[:, :, 0:h_half, 0:w_half] \
58 | = sr_list[0][:, :, 0:h_half, 0:w_half]
59 | output[:, :, 0:h_half, w_half:w] \
60 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
61 | output[:, :, h_half:h, 0:w_half] \
62 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
63 | output[:, :, h_half:h, w_half:w] \
64 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
65 |
66 | return output
67 |
68 | def measure_time(net, inputs, num_forward=100):
69 | '''
70 |     Measure the running time (seconds) of num_forward forward passes:
71 | out = net(*inputs)
72 | '''
73 | start = torch.cuda.Event(enable_timing=True)
74 | end = torch.cuda.Event(enable_timing=True)
75 |
76 | start.record()
77 | with torch.set_grad_enabled(False):
78 | for _ in range(num_forward):
79 | out = net(*inputs)
80 | end.record()
81 |
82 | torch.cuda.synchronize()
83 |
84 | return start.elapsed_time(end) / 1000
85 |
86 | def reload_model(model, ckpt):
87 | if list(model.state_dict().keys())[0].startswith('module.'):
88 | if list(ckpt.keys())[0].startswith('module.'):
89 | ckpt = ckpt
90 | else:
91 | ckpt = OrderedDict({f'module.{key}':value for key, value in ckpt.items()})
92 | else:
93 | if list(ckpt.keys())[0].startswith('module.'):
94 | ckpt = OrderedDict({key[7:]:value for key, value in ckpt.items()})
95 | else:
96 | ckpt = ckpt
97 | model.load_state_dict(ckpt, True)
98 |
99 | def compute_hinge_loss(real_output, fake_output, x_start_, r1_lambda):
100 | if r1_lambda == 0:
101 | real_loss_total = torch.relu(torch.ones_like(real_output) - real_output).mean()
102 | fake_loss_total = torch.relu(torch.ones_like(fake_output) + fake_output).mean()
103 |
104 | else:
105 | real_loss_ = torch.relu(torch.ones_like(real_output) - real_output).mean()
106 |
107 |         # Gradients of the discriminator output w.r.t. the real samples
108 | grad_real = torch.autograd.grad(outputs=real_output.sum(), inputs=x_start_, create_graph=True)[0]
109 |
110 |         # R1 gradient penalty
111 | grad_penalty = (grad_real.contiguous().view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean() * r1_lambda
112 |
113 | real_loss_total = real_loss_ + grad_penalty
114 | fake_loss_total = torch.relu(torch.ones_like(fake_output) + fake_output).mean()
115 |
116 | real_loss = real_loss_total
117 | fake_loss = fake_loss_total
118 |
119 | loss_d = real_loss + fake_loss
120 |
121 | return loss_d
122 |
123 |
124 |
125 | def reload_model_(model, ckpt):
126 | if list(model.state_dict().keys())[0].startswith('model.'):
127 | if list(ckpt.keys())[0].startswith('model.'):
128 | ckpt = ckpt
129 | else:
130 | ckpt = OrderedDict({f'model.{key}':value for key, value in ckpt.items()})
131 | else:
132 | if list(ckpt.keys())[0].startswith('model.'):
133 | ckpt = OrderedDict({key[7:]:value for key, value in ckpt.items()})
134 | else:
135 | ckpt = ckpt
136 | model.load_state_dict(ckpt, True)
137 |
138 |
139 |
140 | def reload_model_IDE(model, ckpt):
141 | extracted_dict = OrderedDict()
142 | for key, value in ckpt.items():
143 | if key.startswith('E_st'):
144 | new_key = key.replace('E_st.', '')
145 | extracted_dict[new_key] = value
146 |
147 | model.load_state_dict(extracted_dict, True)
148 |
149 |
150 |
151 | class EMA():
152 | def __init__(self, model, decay):
153 | self.model = model
154 | self.decay = decay
155 | self.shadow = {}
156 | self.backup = {}
157 |
158 | def register(self):
159 | for name, param in self.model.named_parameters():
160 | if param.requires_grad:
161 | self.shadow[name] = param.data.clone()
162 |
163 | def update(self):
164 | for name, param in self.model.named_parameters():
165 | if param.requires_grad:
166 | assert name in self.shadow
167 | new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
168 | self.shadow[name] = new_average.clone()
169 |
170 | def apply_shadow(self):
171 | for name, param in self.model.named_parameters():
172 | if param.requires_grad:
173 | assert name in self.shadow
174 | self.backup[name] = param.data
175 | param.data = self.shadow[name]
176 |
177 | def restore(self):
178 | for name, param in self.model.named_parameters():
179 | if param.requires_grad:
180 | assert name in self.backup
181 | param.data = self.backup[name]
182 | self.backup = {}
183 |
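A minimal sketch of driving the EMA helper above during a training loop (not part of the repository; the linear model and optimizer are placeholders, and the repo root is assumed to be on PYTHONPATH):

import torch
import torch.nn as nn
from ADD.utils.util_net import EMA

net = nn.Linear(8, 8)
opt = torch.optim.Adam(net.parameters(), lr=1e-4)

ema = EMA(net, decay=0.999)
ema.register()                       # snapshot the current weights as the shadow copy

for _ in range(100):
    loss = net(torch.randn(4, 8)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update()                     # shadow <- decay * shadow + (1 - decay) * param

ema.apply_shadow()                   # evaluate with the smoothed weights
with torch.no_grad():
    _ = net(torch.randn(4, 8))
ema.restore()                        # put the raw training weights back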
--------------------------------------------------------------------------------
/dataloaders/__pycache__/paired_dataset_txt.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/dataloaders/__pycache__/paired_dataset_txt.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/realesrgan.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/dataloaders/__pycache__/realesrgan.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/paired_dataset_txt.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from PIL import Image
4 | import random
5 | import numpy as np
6 |
7 | from torch import nn
8 | from torchvision import transforms
9 | from torch.utils import data as data
10 | import torch.nn.functional as F
11 |
12 | from .realesrgan import RealESRGAN_degradation
13 |
14 | class PairedCaptionDataset(data.Dataset):
15 | def __init__(
16 | self,
17 | root_folders=None,
18 | tokenizer=None,
19 |         gt_ratio=0,   # probability of using the GT image directly as the LQ input
20 | ):
21 | super(PairedCaptionDataset, self).__init__()
22 |
23 | self.gt_ratio = gt_ratio
24 | with open(root_folders, 'r') as f:
25 | self.gt_list = [line.strip() for line in f.readlines()]
26 |
27 | self.img_preproc = transforms.Compose([
28 | transforms.RandomCrop((512, 512)),
29 | transforms.Resize((512, 512)),
30 | transforms.RandomHorizontalFlip(),
31 | ])
32 |
33 | self.degradation = RealESRGAN_degradation('dataloaders/params_ccsr.yml', device='cuda')
34 | self.tokenizer = tokenizer
35 |
36 |
37 | def tokenize_caption(self, caption=""):
38 | inputs = self.tokenizer(
39 | caption, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
40 | )
41 |
42 | return inputs.input_ids
43 |
44 | def __getitem__(self, index):
45 |
46 | gt_path = self.gt_list[index]
47 | gt_img = Image.open(gt_path).convert('RGB')
48 | gt_img = self.img_preproc(gt_img)
49 |
50 | gt_img, img_t = self.degradation.degrade_process(np.asarray(gt_img)/255., resize_bak=True)
51 |
52 | if random.random() < self.gt_ratio:
53 | lq_img = gt_img
54 | else:
55 | lq_img = img_t
56 |
57 | # no caption used
58 | lq_caption = ''
59 |
60 | example = dict()
61 | example["conditioning_pixel_values"] = lq_img.squeeze(0) # [0, 1]
62 | example["pixel_values"] = gt_img.squeeze(0) * 2.0 - 1.0 # [-1, 1]
63 | example["input_caption"] = self.tokenize_caption(caption=lq_caption).squeeze(0)
64 |
65 | lq_img = lq_img.squeeze()
66 |
67 | return example
68 |
69 | def __len__(self):
70 | return len(self.gt_list)
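A minimal sketch of wiring this dataset into a DataLoader (not part of the repository). The text file path and the SD-2.1 tokenizer are illustrative assumptions, the training scripts pass their own arguments, a CUDA device is needed because the degradation pipeline is built with device='cuda', and num_workers=0 keeps that pipeline in the main process:

from torch.utils.data import DataLoader
from transformers import CLIPTokenizer
from dataloaders.paired_dataset_txt import PairedCaptionDataset

tokenizer = CLIPTokenizer.from_pretrained('stabilityai/stable-diffusion-2-1-base', subfolder='tokenizer')
dataset = PairedCaptionDataset(root_folders='gt_paths.txt',   # one GT image path per line
                               tokenizer=tokenizer, gt_ratio=0.0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch = next(iter(loader))
# batch['pixel_values']               GT images in [-1, 1]
# batch['conditioning_pixel_values']  degraded LQ images in [0, 1]
# batch['input_caption']              tokenized (empty) captions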
--------------------------------------------------------------------------------
/dataloaders/params_ccsr.yml:
--------------------------------------------------------------------------------
1 | scale: 4
2 | color_jitter_prob: 0.0
3 | gray_prob: 0.0
4 |
5 | # the first degradation process
6 | resize_prob: [0.2, 0.7, 0.1] # up, down, keep
7 | resize_range: [0.3, 1.5]
8 | gaussian_noise_prob: 0.5
9 | noise_range: [1, 15]
10 | poisson_scale_range: [0.05, 2.0]
11 | gray_noise_prob: 0.4
12 | jpeg_range: [60, 95]
13 |
14 |
15 | # the second degradation process
16 | second_blur_prob: 0.5
17 | resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
18 | resize_range2: [0.6, 1.2]
19 | gaussian_noise_prob2: 0.5
20 | noise_range2: [1, 12]
21 | poisson_scale_range2: [0.05, 1.0]
22 | gray_noise_prob2: 0.4
23 | jpeg_range2: [60, 100]
24 |
25 | kernel_info:
26 | blur_kernel_size: 21
27 | kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
28 | kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
29 | sinc_prob: 0.1
30 | blur_sigma: [0.2, 1.5]
31 | betag_range: [0.5, 2.0]
32 | betap_range: [1, 1.5]
33 |
34 | blur_kernel_size2: 11
35 | kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
36 | kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
37 | sinc_prob2: 0.1
38 | blur_sigma2: [0.2, 1.0]
39 | betag_range2: [0.5, 2.0]
40 | betap_range2: [1, 1.5]
41 |
42 | final_sinc_prob: 0.8
43 |
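A short sketch of how resize_prob/resize_range pairs like the ones above are typically consumed by Real-ESRGAN-style degradation code (illustrative only, not this repository's implementation): one of up/down/keep is drawn, then a scale factor from the matching half of the range:

import random
import yaml

with open('dataloaders/params_ccsr.yml') as f:
    opt = yaml.safe_load(f)

mode = random.choices(['up', 'down', 'keep'], weights=opt['resize_prob'])[0]
if mode == 'up':
    scale = random.uniform(1.0, opt['resize_range'][1])
elif mode == 'down':
    scale = random.uniform(opt['resize_range'][0], 1.0)
else:
    scale = 1.0
print(mode, scale)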
--------------------------------------------------------------------------------
/figs/compare_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/figs/compare_1.png
--------------------------------------------------------------------------------
/figs/compare_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/figs/compare_2.png
--------------------------------------------------------------------------------
/figs/compare_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/figs/compare_3.png
--------------------------------------------------------------------------------
/figs/compare_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/figs/compare_4.png
--------------------------------------------------------------------------------
/figs/compare_efficient.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/figs/compare_efficient.png
--------------------------------------------------------------------------------
/figs/compare_standard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/figs/compare_standard.png
--------------------------------------------------------------------------------
/figs/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/figs/fig.png
--------------------------------------------------------------------------------
/figs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/figs/framework.png
--------------------------------------------------------------------------------
/figs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/figs/logo.png
--------------------------------------------------------------------------------
/figs/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/figs/table.png
--------------------------------------------------------------------------------
/models/DiffAugment.py:
--------------------------------------------------------------------------------
1 | # BSD 2-Clause "Simplified" License
2 | # Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
3 | # All rights reserved.
4 | #
5 | # Redistribution and use in source and binary forms, with or without
6 | # modification, are permitted provided that the following conditions are met:
7 | #
8 | # * Redistributions of source code must retain the above copyright notice, this
9 | # list of conditions and the following disclaimer.
10 | #
11 | # * Redistributions in binary form must reproduce the above copyright notice,
12 | # this list of conditions and the following disclaimer in the documentation
13 | # and/or other materials provided with the distribution.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | #
26 | # Code from https://github.com/mit-han-lab/data-efficient-gans
27 |
28 | """Training GANs with DiffAugment."""
29 |
30 | import numpy as np
31 | import torch
32 | import torch.nn.functional as F
33 |
34 |
35 | def DiffAugment(x: torch.Tensor, policy: str = '', channels_first: bool = True) -> torch.Tensor:
36 | if policy:
37 | if not channels_first:
38 | x = x.permute(0, 3, 1, 2)
39 | for p in policy.split(','):
40 | for f in AUGMENT_FNS[p]:
41 | x = f(x)
42 | if not channels_first:
43 | x = x.permute(0, 2, 3, 1)
44 | x = x.contiguous()
45 | return x
46 |
47 |
48 | def rand_brightness(x: torch.Tensor) -> torch.Tensor:
49 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
50 | return x
51 |
52 |
53 | def rand_saturation(x: torch.Tensor) -> torch.Tensor:
54 | x_mean = x.mean(dim=1, keepdim=True)
55 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
56 | return x
57 |
58 |
59 | def rand_contrast(x: torch.Tensor) -> torch.Tensor:
60 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
61 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
62 | return x
63 |
64 |
65 | def rand_translation(x: torch.Tensor, ratio: float = 0.125) -> torch.Tensor:
66 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
67 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
68 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
69 | grid_batch, grid_x, grid_y = torch.meshgrid(
70 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
71 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
72 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
73 | )
74 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
75 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
76 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
77 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
78 | return x
79 |
80 |
81 | def rand_cutout(x: torch.Tensor, ratio: float = 0.2) -> torch.Tensor:
82 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
83 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
84 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
85 | grid_batch, grid_x, grid_y = torch.meshgrid(
86 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
87 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
88 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
89 | )
90 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
91 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
92 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
93 | mask[grid_batch, grid_x, grid_y] = 0
94 | x = x * mask.unsqueeze(1)
95 | return x
96 |
97 |
98 | def rand_resize(x: torch.Tensor, min_ratio: float = 0.8, max_ratio: float = 1.2) -> torch.Tensor:
99 | resize_ratio = np.random.rand()*(max_ratio-min_ratio) + min_ratio
100 | resized_img = F.interpolate(x, size=int(resize_ratio*x.shape[3]), mode='bilinear')
101 | org_size = x.shape[3]
102 | if int(resize_ratio*x.shape[3]) < x.shape[3]:
103 | left_pad = (x.shape[3]-int(resize_ratio*x.shape[3]))/2.
104 | left_pad = int(left_pad)
105 | right_pad = x.shape[3] - left_pad - resized_img.shape[3]
106 | x = F.pad(resized_img, (left_pad, right_pad, left_pad, right_pad), "constant", 0.)
107 | else:
108 | left = (int(resize_ratio*x.shape[3])-x.shape[3])/2.
109 | left = int(left)
110 | x = resized_img[:, :, left:(left+x.shape[3]), left:(left+x.shape[3])]
111 | assert x.shape[2] == org_size
112 | assert x.shape[3] == org_size
113 | return x
114 |
115 |
116 | AUGMENT_FNS = {
117 | 'color': [rand_brightness, rand_saturation, rand_contrast],
118 | 'translation': [rand_translation],
119 | 'resize': [rand_resize],
120 | 'cutout': [rand_cutout],
121 | }
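A minimal sketch of applying these augmentations before a discriminator call (not part of the repository, run from the repo root); the policy string selects which groups from AUGMENT_FNS run, and every op is differentiable and shape-preserving:

import torch
from models.DiffAugment import DiffAugment

imgs = torch.rand(8, 3, 64, 64, requires_grad=True)        # NCHW, values in [0, 1]
aug = DiffAugment(imgs, policy='color,translation,cutout')
assert aug.shape == imgs.shape
aug.mean().backward()                                       # gradients flow back to the generator output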
--------------------------------------------------------------------------------
/models/__pycache__/DiffAugment.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/__pycache__/DiffAugment.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/controlnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/__pycache__/controlnet.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/shared.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/__pycache__/shared.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_2d_blocks.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/__pycache__/unet_2d_blocks.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/unet_2d_condition.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/__pycache__/unet_2d_condition.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/vit_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/__pycache__/vit_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from models.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/models/losses/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/losses/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/models/losses/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/losses/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/losses/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/losses/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/losses/__pycache__/contperceptual.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/losses/__pycache__/contperceptual.cpython-310.pyc
--------------------------------------------------------------------------------
/models/losses/__pycache__/contperceptual.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/losses/__pycache__/contperceptual.cpython-37.pyc
--------------------------------------------------------------------------------
/models/losses/__pycache__/contperceptual.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/models/losses/__pycache__/contperceptual.cpython-38.pyc
--------------------------------------------------------------------------------
/models/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 | from diffusers.models.modeling_utils import ModelMixin
6 | from diffusers.configuration_utils import ConfigMixin, register_to_config
7 | from diffusers.loaders import FromOriginalControlnetMixin
8 |
9 | class LPIPSWithDiscriminator(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
10 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
11 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
12 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
13 | disc_loss="hinge"):
14 |
15 | super().__init__()
16 | assert disc_loss in ["hinge", "vanilla"]
17 | self.kl_weight = kl_weight
18 | self.pixel_weight = pixelloss_weight
19 | self.perceptual_loss = LPIPS().eval()
20 | self.perceptual_weight = perceptual_weight
21 | # output log variance
22 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
23 |
24 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
25 | n_layers=disc_num_layers,
26 | use_actnorm=use_actnorm
27 | ).apply(weights_init)
28 | self.discriminator_iter_start = disc_start
29 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
30 | self.disc_factor = disc_factor
31 | self.discriminator_weight = disc_weight
32 | self.disc_conditional = disc_conditional
33 |
34 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
35 | if last_layer is not None:
36 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
37 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
38 | else:
39 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
40 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
41 |
42 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
43 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
44 | d_weight = d_weight * self.discriminator_weight
45 | return d_weight
46 |
47 | def forward(self, inputs, reconstructions, optimizer_idx,
48 | global_step, posteriors=None, last_layer=None, cond=None, split="train",
49 | weights=None, return_dic=False):
50 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
51 | if self.perceptual_weight > 0:
52 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
53 | rec_loss = rec_loss + self.perceptual_weight * p_loss
54 |
55 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
56 | weighted_nll_loss = nll_loss
57 | if weights is not None:
58 | weighted_nll_loss = weights*nll_loss
59 | weighted_nll_loss = torch.mean(weighted_nll_loss) / weighted_nll_loss.shape[0]
60 | nll_loss = torch.mean(nll_loss) / nll_loss.shape[0]
61 | if self.kl_weight>0:
62 | kl_loss = posteriors.kl()
63 | kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
64 |
65 | # now the GAN part
66 | if optimizer_idx == 0:
67 | # generator update
68 | if cond is None:
69 | assert not self.disc_conditional
70 | logits_fake = self.discriminator(reconstructions.contiguous())
71 | else:
72 | assert self.disc_conditional
73 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
74 | g_loss = -torch.mean(logits_fake)
75 |
76 | if self.disc_factor > 0.0:
77 | try:
78 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
79 | except RuntimeError:
80 | # assert not self.training
81 | d_weight = torch.tensor(1.0) * self.discriminator_weight
82 | else:
83 | # d_weight = torch.tensor(0.0)
84 | d_weight = torch.tensor(0.0)
85 |
86 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
87 | if self.kl_weight>0:
88 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
89 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
90 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
91 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
92 | "{}/d_weight".format(split): d_weight.detach(),
93 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
94 | "{}/g_loss".format(split): g_loss.detach().mean(),
95 | }
96 | if return_dic:
97 | loss_dic = {}
98 | loss_dic['total_loss'] = loss.clone().detach().mean()
99 | loss_dic['logvar'] = self.logvar.detach()
100 | loss_dic['kl_loss'] = kl_loss.detach().mean()
101 | loss_dic['nll_loss'] = nll_loss.detach().mean()
102 | loss_dic['rec_loss'] = rec_loss.detach().mean()
103 | loss_dic['d_weight'] = d_weight.detach()
104 | loss_dic['disc_factor'] = torch.tensor(disc_factor)
105 | loss_dic['g_loss'] = g_loss.detach().mean()
106 | else:
107 | loss = weighted_nll_loss + d_weight * disc_factor * g_loss
108 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
109 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
110 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
111 | "{}/d_weight".format(split): d_weight.detach(),
112 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
113 | "{}/g_loss".format(split): g_loss.detach().mean(),
114 | }
115 | if return_dic:
116 | loss_dic = {}
117 | loss_dic["{}/total_loss".format(split)] = loss.clone().detach().mean()
118 | loss_dic["{}/logvar".format(split)] = self.logvar.detach()
119 | loss_dic['nll_loss'.format(split)] = nll_loss.detach().mean()
120 | loss_dic['rec_loss'.format(split)] = rec_loss.detach().mean()
121 | loss_dic['d_weight'.format(split)] = d_weight.detach()
122 | loss_dic['disc_factor'.format(split)] = torch.tensor(disc_factor)
123 | loss_dic['g_loss'.format(split)] = g_loss.detach().mean()
124 |
125 | if return_dic:
126 | return loss, log, loss_dic
127 | return loss, log
128 |
129 | if optimizer_idx == 1:
130 | # second pass for discriminator update
131 | if cond is None:
132 | logits_real = self.discriminator(inputs.contiguous().detach())
133 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
134 | else:
135 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
136 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
137 |
138 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
139 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
140 |
141 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
142 | "{}/logits_real".format(split): logits_real.detach().mean(),
143 | "{}/logits_fake".format(split): logits_fake.detach().mean()
144 | }
145 |
146 | if return_dic:
147 | loss_dic = {}
148 | loss_dic["{}/disc_loss".format(split)] = d_loss.clone().detach().mean()
149 | loss_dic["{}/logits_real".format(split)] = logits_real.detach().mean()
150 | loss_dic["{}/logits_fake".format(split)] = logits_fake.detach().mean()
151 | return d_loss, log, loss_dic
152 |
153 | return d_loss, log
154 |
155 |
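A minimal sketch of the two-pass calling convention this loss expects (not part of the repository; the 1-layer decoder stands in for a VAE decoder, kl_weight=0 sidesteps the posterior term, and taming-transformers must be installed for the LPIPS/discriminator imports):

import torch
import torch.nn as nn
from models.losses import LPIPSWithDiscriminator

decoder = nn.Conv2d(3, 3, 3, padding=1)                     # stand-in for the last decoder layer
gt = torch.rand(2, 3, 64, 64) * 2 - 1
recon = torch.tanh(decoder(gt + 0.05 * torch.randn_like(gt)))

loss_fn = LPIPSWithDiscriminator(disc_start=1, kl_weight=0.0, disc_weight=0.5)

# optimizer_idx=0: reconstruction + perceptual + adaptively weighted generator loss
g_loss, g_log = loss_fn(gt, recon, optimizer_idx=0, global_step=0,
                        last_layer=decoder.weight, split='train')

# optimizer_idx=1: hinge loss for the discriminator on the same (detached) pair
d_loss, d_log = loss_fn(gt, recon, optimizer_idx=1, global_step=0, split='train')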
--------------------------------------------------------------------------------
/models/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | # def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | # if last_layer is not None:
87 | # nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | # g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | # else:
90 | # nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | # g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | # d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | # d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | # d_weight = d_weight * self.discriminator_weight
96 | # return d_weight
97 |
98 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
99 | # if last_layer is not None:
100 | # nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
101 | # g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
102 | # else:
103 | # nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
104 | # g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
105 |
106 | # d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
107 | # d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
108 | d_weight = 1.0 * self.discriminator_weight
109 | return d_weight
110 |
111 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
112 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
113 |         if codebook_loss is None:
114 | codebook_loss = torch.tensor([0.]).to(inputs.device)
115 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
116 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
117 | if self.perceptual_weight > 0:
118 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
119 | rec_loss = rec_loss + self.perceptual_weight * p_loss
120 | else:
121 | p_loss = torch.tensor([0.0])
122 |
123 | nll_loss = rec_loss
124 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
125 | nll_loss = torch.mean(nll_loss)
126 |
127 | # now the GAN part
128 | if optimizer_idx == 0:
129 | # generator update
130 | if cond is None:
131 | assert not self.disc_conditional
132 | logits_fake = self.discriminator(reconstructions.contiguous())
133 | else:
134 | assert self.disc_conditional
135 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
136 | g_loss = -torch.mean(logits_fake)
137 |
138 | try:
139 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
140 | except RuntimeError:
141 | assert not self.training
142 | d_weight = torch.tensor(0.0)
143 |
144 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
145 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
146 |
147 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
148 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
149 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
150 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
151 | "{}/p_loss".format(split): p_loss.detach().mean(),
152 | "{}/d_weight".format(split): d_weight.detach(),
153 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
154 | "{}/g_loss".format(split): g_loss.detach().mean(),
155 | }
156 | if predicted_indices is not None:
157 | assert self.n_classes is not None
158 | with torch.no_grad():
159 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
160 | log[f"{split}/perplexity"] = perplexity
161 | log[f"{split}/cluster_usage"] = cluster_usage
162 | return loss, log
163 |
164 | if optimizer_idx == 1:
165 | # second pass for discriminator update
166 | if cond is None:
167 | logits_real = self.discriminator(inputs.contiguous().detach())
168 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
169 | else:
170 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
171 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
172 |
173 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
174 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
175 |
176 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
177 | "{}/logits_real".format(split): logits_real.detach().mean(),
178 | "{}/logits_fake".format(split): logits_fake.detach().mean()
179 | }
180 | return d_loss, log
181 |
--------------------------------------------------------------------------------
/models/shared.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Shared architecture blocks."""
10 |
11 | from typing import Callable
12 |
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 |
17 | from ADD.th_utils.ops import bias_act
18 |
19 |
20 | class ResidualBlock(nn.Module):
21 | def __init__(self, fn: Callable):
22 | super().__init__()
23 | self.fn = fn
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return (self.fn(x) + x) / np.sqrt(2)
27 |
28 |
29 | class FullyConnectedLayer(nn.Module):
30 | def __init__(
31 | self,
32 | in_features: int, # Number of input features.
33 | out_features: int, # Number of output features.
34 | bias: bool = True, # Apply additive bias before the activation function?
35 | activation: str = 'linear', # Activation function: 'relu', 'lrelu', etc.
36 | lr_multiplier: float = 1.0, # Learning rate multiplier.
37 | weight_init: float = 1.0, # Initial standard deviation of the weight tensor.
38 | bias_init: float = 0.0, # Initial value for the additive bias.
39 | ):
40 |
41 | super().__init__()
42 | self.in_features = in_features
43 | self.out_features = out_features
44 | self.activation = activation
45 | self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier))
46 | bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
47 | self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
48 | self.weight_gain = lr_multiplier / np.sqrt(in_features)
49 | self.bias_gain = lr_multiplier
50 |
51 | def forward(self, x: torch.Tensor) -> torch.Tensor:
52 | w = self.weight.to(x.dtype) * self.weight_gain
53 | b = self.bias
54 | if b is not None:
55 | b = b.to(x.dtype)
56 | if self.bias_gain != 1:
57 | b = b * self.bias_gain
58 |
59 | if self.activation == 'linear' and b is not None:
60 | x = torch.addmm(b.unsqueeze(0), x, w.t())
61 | else:
62 | x = x.matmul(w.t())
63 | x = bias_act.bias_act(x, b, act=self.activation)
64 | return x
65 |
66 | def extra_repr(self) -> str:
67 | return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
68 |
69 |
70 | class MLP(nn.Module):
71 | def __init__(
72 | self,
73 | features_list: list[int], # Number of features in each layer of the MLP.
74 | activation: str = 'linear', # Activation function: 'relu', 'lrelu', etc.
75 | lr_multiplier: float = 1.0, # Learning rate multiplier.
76 | linear_out: bool = False # Use the 'linear' activation function for the output layer?
77 | ):
78 | super().__init__()
79 | num_layers = len(features_list) - 1
80 | self.num_layers = num_layers
81 | self.out_dim = features_list[-1]
82 |
83 | for idx in range(num_layers):
84 | in_features = features_list[idx]
85 | out_features = features_list[idx + 1]
86 | if linear_out and idx == num_layers-1:
87 | activation = 'linear'
88 | layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
89 | setattr(self, f'fc{idx}', layer)
90 |
91 | def forward(self, x: torch.Tensor) -> torch.Tensor:
92 | ''' if x is sequence of tokens, shift tokens to batch and apply MLP to all'''
93 | shift2batch = (x.ndim == 3)
94 |
95 | if shift2batch:
96 | B, K, C = x.shape
97 | x = x.flatten(0,1)
98 |
99 | for idx in range(self.num_layers):
100 | layer = getattr(self, f'fc{idx}')
101 | x = layer(x)
102 |
103 | if shift2batch:
104 | x = x.reshape(B, K, -1)
105 |
106 | return x
107 |
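A minimal sketch of the token-folding behaviour of the MLP above (not part of the repository, run from the repo root). With a 3-D input the token axis is folded into the batch, every token passes through the same stack of FullyConnectedLayers, and the output is reshaped back; on CPU tensors bias_act should take its pure-PyTorch reference path:

import torch
from models.shared import MLP

mlp = MLP(features_list=[64, 128, 32], activation='lrelu', linear_out=True)

tokens = torch.randn(2, 16, 64)     # [batch, tokens, channels]
out = mlp(tokens)
print(out.shape)                    # torch.Size([2, 16, 32])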
--------------------------------------------------------------------------------
/models/vit_utils.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2021 Intel ISL (Intel Intelligent Systems Lab)
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | # Based on code from https://github.com/isl-org/DPT
24 |
25 | """Flexible configuration and feature extraction of timm VisionTransformers."""
26 |
27 | import types
28 | import math
29 | from typing import Callable
30 |
31 | import torch
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 |
35 |
36 | class AddReadout(nn.Module):
37 | def __init__(self, start_index: bool = 1):
38 | super(AddReadout, self).__init__()
39 | self.start_index = start_index
40 |
41 | def forward(self, x: torch.Tensor) -> torch.Tensor:
42 | if self.start_index == 2:
43 | readout = (x[:, 0] + x[:, 1]) / 2
44 | else:
45 | readout = x[:, 0]
46 | return x[:, self.start_index:] + readout.unsqueeze(1)
47 |
48 |
49 | class Transpose(nn.Module):
50 | def __init__(self, dim0: int, dim1: int):
51 | super(Transpose, self).__init__()
52 | self.dim0 = dim0
53 | self.dim1 = dim1
54 |
55 | def forward(self, x: torch.Tensor) -> torch.Tensor:
56 | x = x.transpose(self.dim0, self.dim1)
57 | return x.contiguous()
58 |
59 |
60 | def forward_vit(pretrained: nn.Module, x: torch.Tensor) -> dict:
61 | _, _, H, W = x.size()
62 | _ = pretrained.model.forward_flex(x)
63 | return {k: pretrained.rearrange(v) for k, v in activations.items()}
64 |
65 |
66 | def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor:
67 | posemb_tok, posemb_grid = (
68 | posemb[:, : self.start_index],
69 | posemb[0, self.start_index :],
70 | )
71 |
72 | gs_old = int(math.sqrt(len(posemb_grid)))
73 |
74 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
75 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)
76 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
77 |
78 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
79 |
80 | return posemb
81 |
82 |
83 | def forward_flex(self, x: torch.Tensor) -> torch.Tensor:
84 | # patch proj and dynamically resize
85 | B, C, H, W = x.size()
86 | x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
87 | pos_embed = self._resize_pos_embed(
88 | self.pos_embed, H // self.patch_size[1], W // self.patch_size[0]
89 | )
90 |
91 | # add cls token
92 | cls_tokens = self.cls_token.expand(
93 | x.size(0), -1, -1
94 | )
95 | x = torch.cat((cls_tokens, x), dim=1)
96 |
97 | # forward pass
98 | x = x + pos_embed
99 | x = self.pos_drop(x)
100 |
101 | for blk in self.blocks:
102 | x = blk(x)
103 |
104 | x = self.norm(x)
105 | return x
106 |
107 |
108 | activations = {}
109 |
110 |
111 | def get_activation(name: str) -> Callable:
112 | def hook(model, input, output):
113 | activations[name] = output
114 | return hook
115 |
116 |
117 | def make_sd_backbone(
118 | model: nn.Module,
119 | hooks: list[int] = [2, 5, 8, 11],
120 | hook_patch: bool = True,
121 |     patch_size: list[int] = [16, 16], start_index: int = 1,
122 | ):
123 | assert len(hooks) == 4
124 |
125 | pretrained = nn.Module()
126 | pretrained.model = model
127 |
128 | # add hooks
129 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('0'))
130 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('1'))
131 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('2'))
132 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('3'))
133 | if hook_patch:
134 | pretrained.model.pos_drop.register_forward_hook(get_activation('4'))
135 |
136 | # configure readout
137 | pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))
138 | pretrained.model.start_index = start_index
139 | pretrained.model.patch_size = patch_size
140 |
141 | # We inject this function into the VisionTransformer instances so that
142 | # we can use it with interpolated position embeddings without modifying the library source.
143 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
144 | pretrained.model._resize_pos_embed = types.MethodType(
145 | _resize_pos_embed, pretrained.model
146 | )
147 |
148 | return pretrained
149 |
150 | def make_vit_backbone(
151 | model: nn.Module,
152 | patch_size: list[int] = [16, 16],
153 | hooks: list[int] = [2, 5, 8, 11],
154 | hook_patch: bool = True,
155 |     start_index: int = 1,
156 | ):
157 | assert len(hooks) == 4
158 |
159 | pretrained = nn.Module()
160 | pretrained.model = model
161 |
162 | # add hooks
163 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('0'))
164 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('1'))
165 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('2'))
166 | pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('3'))
167 | if hook_patch:
168 | pretrained.model.pos_drop.register_forward_hook(get_activation('4'))
169 |
170 | # configure readout
171 | pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))
172 | pretrained.model.start_index = start_index
173 | pretrained.model.patch_size = patch_size
174 |
175 | # We inject this function into the VisionTransformer instances so that
176 | # we can use it with interpolated position embeddings without modifying the library source.
177 | pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
178 | pretrained.model._resize_pos_embed = types.MethodType(
179 | _resize_pos_embed, pretrained.model
180 | )
181 |
182 | return pretrained
183 |
--------------------------------------------------------------------------------
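A minimal usage sketch for make_vit_backbone and forward_vit (illustrative only, not part of the repository). It assumes timm is installed and that the wrapped model follows the standard timm VisionTransformer layout (patch_embed, cls_token, pos_embed, pos_drop, blocks, norm), which is what the injected forward_flex relies on:

    import timm
    import torch

    from models.vit_utils import make_vit_backbone, forward_vit

    # Wrap a plain timm ViT; `hooks` selects which transformer blocks are read out.
    vit = timm.create_model("vit_base_patch16_224", pretrained=False)
    backbone = make_vit_backbone(vit, patch_size=[16, 16], hooks=[2, 5, 8, 11])

    # H and W only need to be multiples of the patch size: forward_flex
    # re-interpolates the position embedding to the new token grid.
    x = torch.randn(1, 3, 256, 320)
    feats = forward_vit(backbone, x)   # keys '0'..'3' (+ '4' for the patch hook)
    for name, f in feats.items():
        print(name, tuple(f.shape))    # (B, C, N) after AddReadout + Transpose

Because the hooks write into the module-level activations dict, the returned features always reflect the most recent forward pass.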
/myutils/__pycache__/devices.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/myutils/__pycache__/devices.cpython-310.pyc
--------------------------------------------------------------------------------
/myutils/__pycache__/wavelet_color_fix.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/myutils/__pycache__/wavelet_color_fix.cpython-310.pyc
--------------------------------------------------------------------------------
/myutils/devices.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import contextlib
3 | from functools import lru_cache
4 |
5 | import torch
6 | #from modules import errors
7 |
8 | if sys.platform == "darwin":
9 | from modules import mac_specific
10 |
11 |
12 | def has_mps() -> bool:
13 | if sys.platform != "darwin":
14 | return False
15 | else:
16 | return mac_specific.has_mps
17 |
18 |
19 | def get_cuda_device_string():
20 | return "cuda"
21 |
22 |
23 | def get_optimal_device_name():
24 | if torch.cuda.is_available():
25 | return get_cuda_device_string()
26 |
27 | if has_mps():
28 | return "mps"
29 |
30 | return "cpu"
31 |
32 |
33 | def get_optimal_device():
34 | return torch.device(get_optimal_device_name())
35 |
36 |
37 | def get_device_for(task):
38 | return get_optimal_device()
39 |
40 |
41 | def torch_gc():
42 |
43 | if torch.cuda.is_available():
44 | with torch.cuda.device(get_cuda_device_string()):
45 | torch.cuda.empty_cache()
46 | torch.cuda.ipc_collect()
47 |
48 | if has_mps():
49 | mac_specific.torch_mps_gc()
50 |
51 |
52 | def enable_tf32():
53 | if torch.cuda.is_available():
54 |
55 | # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
56 | # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
57 | if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
58 | torch.backends.cudnn.benchmark = True
59 |
60 | torch.backends.cuda.matmul.allow_tf32 = True
61 | torch.backends.cudnn.allow_tf32 = True
62 |
63 |
64 | enable_tf32()
65 | #errors.run(enable_tf32, "Enabling TF32")
66 |
67 | cpu = torch.device("cpu")
68 | device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
69 | dtype = torch.float16
70 | dtype_vae = torch.float16
71 | dtype_unet = torch.float16
72 | unet_needs_upcast = False
73 |
74 |
75 | def cond_cast_unet(input):
76 | return input.to(dtype_unet) if unet_needs_upcast else input
77 |
78 |
79 | def cond_cast_float(input):
80 | return input.float() if unet_needs_upcast else input
81 |
82 |
83 | def randn(seed, shape):
84 | torch.manual_seed(seed)
85 | return torch.randn(shape, device=device)
86 |
87 |
88 | def randn_without_seed(shape):
89 | return torch.randn(shape, device=device)
90 |
91 |
92 | def autocast(disable=False):
93 | if disable:
94 | return contextlib.nullcontext()
95 |
96 | return torch.autocast("cuda")
97 |
98 |
99 | def without_autocast(disable=False):
100 | return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
101 |
102 |
103 | class NansException(Exception):
104 | pass
105 |
106 |
107 | def test_for_nans(x, where):
108 | if not torch.all(torch.isnan(x)).item():
109 | return
110 |
111 | if where == "unet":
112 | message = "A tensor with all NaNs was produced in Unet."
113 |
114 | elif where == "vae":
115 | message = "A tensor with all NaNs was produced in VAE."
116 |
117 | else:
118 | message = "A tensor with all NaNs was produced."
119 |
120 | message += " Use --disable-nan-check commandline argument to disable this check."
121 |
122 | raise NansException(message)
123 |
124 |
125 | @lru_cache
126 | def first_time_calculation():
127 | """
128 |     Run any small calculation with PyTorch layers - the first time this is done it allocates about 700MB of memory and
129 |     spends about 2.7 seconds doing that, at least with NVIDIA GPUs.
130 | """
131 |
132 | x = torch.zeros((1, 1)).to(device, dtype)
133 | linear = torch.nn.Linear(1, 1).to(device, dtype)
134 | linear(x)
135 |
136 | x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
137 | conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
138 | conv2d(x)
139 |
--------------------------------------------------------------------------------
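An illustrative sketch of how these helpers are typically combined around a forward pass (not part of the repository; run_model is a stand-in for any network call, and since the module-level device/dtype defaults above assume CUDA, the sketch uses get_optimal_device instead):

    import torch

    from myutils import devices

    def run_model(x):
        # placeholder for a real UNet/VAE call
        return x * 2

    dev = devices.get_optimal_device()            # cuda > mps > cpu
    x = torch.randn(1, 3, 64, 64, device=dev)

    with devices.autocast(disable=(dev.type != "cuda")):
        y = run_model(x)

    devices.test_for_nans(y, "unet")              # raises only if y is *all* NaN
    devices.torch_gc()                            # frees cached CUDA memory between images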
/myutils/img_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | import cv2
4 | import math
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | import imageio
9 |
10 | from einops import rearrange
11 |
12 | def save_videos_grid(videos, path=None, rescale=True, n_rows=4, fps=8, discardN=0):
13 | videos = rearrange(videos, "b c t h w -> t b c h w").cpu()
14 | outputs = []
15 | for x in videos:
16 | x = torchvision.utils.make_grid(x, nrow=n_rows)
17 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
18 | if rescale:
19 | x = (x / 2.0 + 0.5).clamp(0, 1) # -1,1 -> 0,1
20 | x = (x * 255).numpy().astype(np.uint8)
21 | #x = adjust_gamma(x, 0.5)
22 | outputs.append(x)
23 |
24 | outputs = outputs[discardN:]
25 |
26 | if path is not None:
27 | #os.makedirs(os.path.dirname(path), exist_ok=True)
28 | imageio.mimsave(path, outputs, duration=1000/fps, loop=0)
29 |
30 | return outputs
31 |
32 | def convert_image_to_fn(img_type, minsize, image, eps=0.02):
33 | width, height = image.size
34 | if min(width, height) < minsize:
35 | scale = minsize/min(width, height) + eps
36 | image = image.resize((math.ceil(width*scale), math.ceil(height*scale)))
37 |
38 | if image.mode != img_type:
39 | return image.convert(img_type)
40 | return image
--------------------------------------------------------------------------------
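convert_image_to_fn takes the image as its last argument so it can be curried with functools.partial and reused across a dataset. A small, self-contained example (illustrative only):

    from functools import partial

    from PIL import Image

    from myutils.img_util import convert_image_to_fn

    # Fix the target mode and minimum short side once, then reuse the function.
    ensure_rgb_512 = partial(convert_image_to_fn, "RGB", 512)

    img = Image.new("L", (320, 240))   # grayscale, short side below 512
    img = ensure_rgb_512(img)          # upscaled so min(w, h) >= 512, converted to RGB
    print(img.mode, img.size)          # 'RGB', short side now >= 512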
/myutils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import binascii
3 | from safetensors import safe_open
4 |
5 | import torch
6 |
7 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint
8 |
9 | def rand_name(length=8, suffix=''):
10 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
11 | if suffix:
12 | if not suffix.startswith('.'):
13 | suffix = '.' + suffix
14 | name += suffix
15 | return name
16 |
17 | def cycle(dl):
18 | while True:
19 | for data in dl:
20 | yield data
21 |
22 | def exists(x):
23 | return x is not None
24 |
25 | def identity(x):
26 | return x
27 |
28 | def load_dreambooth_lora(unet, vae=None, model_path=None, alpha=1.0, model_base=""):
29 | if model_path is None: return unet
30 |
31 | if model_path.endswith(".ckpt"):
32 | base_state_dict = torch.load(model_path)['state_dict']
33 | elif model_path.endswith(".safetensors"):
34 | state_dict = {}
35 | with safe_open(model_path, framework="pt", device="cpu") as f:
36 | for key in f.keys():
37 | state_dict[key] = f.get_tensor(key)
38 |
39 | is_lora = all("lora" in k for k in state_dict.keys())
40 | if not is_lora:
41 | base_state_dict = state_dict
42 | else:
43 | base_state_dict = {}
44 | with safe_open(model_base, framework="pt", device="cpu") as f:
45 | for key in f.keys():
46 | base_state_dict[key] = f.get_tensor(key)
47 |
48 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, unet.config)
49 | unet_state_dict = unet.state_dict()
50 | for key in converted_unet_checkpoint:
51 | converted_unet_checkpoint[key] = alpha * converted_unet_checkpoint[key] + (1.0-alpha) * unet_state_dict[key]
52 | unet.load_state_dict(converted_unet_checkpoint, strict=False)
53 |
54 | if vae is not None:
55 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, vae.config)
56 | vae.load_state_dict(converted_vae_checkpoint)
57 |
58 | return unet, vae
--------------------------------------------------------------------------------
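A hedged usage sketch for load_dreambooth_lora (illustrative only; the paths are placeholders and the non-LoRA .safetensors branch is the one exercised). The converted checkpoint is blended into the diffusers UNet with weight alpha, and the checkpoint's VAE weights replace the passed VAE; model_base is needed only when the checkpoint contains nothing but LoRA keys:

    from diffusers import AutoencoderKL, UNet2DConditionModel

    from myutils.misc import load_dreambooth_lora

    base = "preset/models/stable-diffusion-2-1-base"        # placeholder path
    unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet")
    vae = AutoencoderKL.from_pretrained(base, subfolder="vae")

    # converted = alpha * checkpoint + (1 - alpha) * current UNet weights
    unet, vae = load_dreambooth_lora(
        unet,
        vae=vae,
        model_path="path/to/finetuned_model.safetensors",   # placeholder path
        alpha=0.8,
    )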
/myutils/wavelet_color_fix.py:
--------------------------------------------------------------------------------
1 | '''
2 | # --------------------------------------------------------------------------------
3 | # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
4 | # --------------------------------------------------------------------------------
5 | '''
6 |
7 | import torch
8 | from PIL import Image
9 | from torch import Tensor
10 | from torch.nn import functional as F
11 |
12 | from torchvision.transforms import ToTensor, ToPILImage
13 |
14 | def adain_color_fix(target: Image, source: Image):
15 | # Convert images to tensors
16 | to_tensor = ToTensor()
17 | target_tensor = to_tensor(target).unsqueeze(0)
18 | source_tensor = to_tensor(source).unsqueeze(0)
19 |
20 | # Apply adaptive instance normalization
21 | result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
22 |
23 | # Convert tensor back to image
24 | to_image = ToPILImage()
25 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
26 |
27 | return result_image
28 |
29 | def wavelet_color_fix(target: Image, source: Image):
30 | # Convert images to tensors
31 | to_tensor = ToTensor()
32 | target_tensor = to_tensor(target).unsqueeze(0)
33 | source_tensor = to_tensor(source).unsqueeze(0)
34 |
35 | # Apply wavelet reconstruction
36 | result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
37 |
38 | # Convert tensor back to image
39 | to_image = ToPILImage()
40 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
41 |
42 | return result_image
43 |
44 | def calc_mean_std(feat: Tensor, eps=1e-5):
45 | """Calculate mean and std for adaptive_instance_normalization.
46 | Args:
47 | feat (Tensor): 4D tensor.
48 | eps (float): A small value added to the variance to avoid
49 | divide-by-zero. Default: 1e-5.
50 | """
51 | size = feat.size()
52 | assert len(size) == 4, 'The input feature should be 4D tensor.'
53 | b, c = size[:2]
54 | feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
55 | feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
56 | feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
57 | return feat_mean, feat_std
58 |
59 | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
60 | """Adaptive instance normalization.
61 |     Adjust the reference features so that they have similar color and
62 |     illumination to those in the degraded features.
63 |     Args:
64 |         content_feat (Tensor): The reference feature.
65 |         style_feat (Tensor): The degraded features.
66 | """
67 | size = content_feat.size()
68 | style_mean, style_std = calc_mean_std(style_feat)
69 | content_mean, content_std = calc_mean_std(content_feat)
70 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
71 | return normalized_feat * style_std.expand(size) + style_mean.expand(size)
72 |
73 | def wavelet_blur(image: Tensor, radius: int):
74 | """
75 | Apply wavelet blur to the input tensor.
76 | """
77 | # input shape: (1, 3, H, W)
78 | # convolution kernel
79 | kernel_vals = [
80 | [0.0625, 0.125, 0.0625],
81 | [0.125, 0.25, 0.125],
82 | [0.0625, 0.125, 0.0625],
83 | ]
84 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
85 | # add channel dimensions to the kernel to make it a 4D tensor
86 | kernel = kernel[None, None]
87 | # repeat the kernel across all input channels
88 | kernel = kernel.repeat(3, 1, 1, 1)
89 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
90 | # apply convolution
91 | output = F.conv2d(image, kernel, groups=3, dilation=radius)
92 | return output
93 |
94 | def wavelet_decomposition(image: Tensor, levels=5):
95 | """
96 | Apply wavelet decomposition to the input tensor.
97 |     This function returns the accumulated high-frequency component and the final low-frequency one.
98 | """
99 | high_freq = torch.zeros_like(image)
100 | for i in range(levels):
101 | radius = 2 ** i
102 | low_freq = wavelet_blur(image, radius)
103 | high_freq += (image - low_freq)
104 | image = low_freq
105 |
106 | return high_freq, low_freq
107 |
108 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
109 | """
110 | Apply wavelet decomposition, so that the content will have the same color as the style.
111 | """
112 | # calculate the wavelet decomposition of the content feature
113 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
114 | del content_low_freq
115 | # calculate the wavelet decomposition of the style feature
116 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
117 | del style_high_freq
118 |     # reconstruct the content feature with the style's low frequency
119 | return content_high_freq + style_low_freq
120 |
--------------------------------------------------------------------------------
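The decomposition above is an exact split: the accumulated high-frequency residual plus the final low-frequency image reconstructs the input up to floating-point error, so wavelet_reconstruction keeps the target's detail and takes only the source's low-frequency color. A short sanity-check sketch (illustrative only):

    import torch

    from myutils.wavelet_color_fix import wavelet_decomposition, wavelet_reconstruction

    img = torch.rand(1, 3, 64, 64)
    high, low = wavelet_decomposition(img)
    print(torch.allclose(high + low, img, atol=1e-5))   # True: the split is lossless

    # Keep img's structure and detail, adopt the low-frequency color of `style`.
    style = torch.rand(1, 3, 64, 64)
    fixed = wavelet_reconstruction(img, style)
    print(fixed.shape)                                   # torch.Size([1, 3, 64, 64])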
/pipelines/__pycache__/pipeline_ccsr.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/pipelines/__pycache__/pipeline_ccsr.cpython-310.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.21.0
2 | torch==2.0.1
3 | pytorch_lightning
4 | accelerate==1.2.0
5 | transformers==4.25.0
6 | xformers==0.0.22
7 | loralib
8 | fairscale==0.4.13
9 | basicsr==1.4.2
10 | timm==0.9.5
11 | pydantic==1.10.11
12 | huggingface_hub==0.25.2
13 | opencv-python-headless
14 | lpips
15 |
--------------------------------------------------------------------------------
/scripts/get_path.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def write_png_paths(folder_path, txt_path):
4 | with open(txt_path, 'w') as f:
5 | for root, dirs, files in os.walk(folder_path):
6 | for file in files:
7 | if file.endswith('.png'):
8 | f.write(os.path.join(root, file) + '\n')
9 |
10 | # Example usage:
11 | folder_path = ''
12 | txt_path = '/gt_path.txt'
13 | write_png_paths(folder_path, txt_path)
--------------------------------------------------------------------------------
/scripts/test/test_ccsr_multistep.sh:
--------------------------------------------------------------------------------
1 | python test_ccsr_tile.py \
2 | --pretrained_model_path preset/models/stable-diffusion-2-1-base \
3 | --controlnet_model_path preset/models \
4 | --vae_model_path preset/models \
5 | --baseline_name ccsr-v2 \
6 | --image_path preset/test_datasets \
7 | --output_dir experiments/test \
8 | --sample_method ddpm \
9 | --num_inference_steps 6 \
10 | --t_max 0.6667 \
11 | --t_min 0.5 \
12 | --start_point lr \
13 | --start_steps 999 \
14 | --process_size 512 \
15 | --guidance_scale 4.5 \
16 | --sample_times 1 \
17 | --use_vae_encode_condition \
18 | --upscale 4
--------------------------------------------------------------------------------
/scripts/test/test_ccsr_onestep.sh:
--------------------------------------------------------------------------------
1 |
2 | python test_ccsr_tile.py \
3 | --pretrained_model_path preset/models/stable-diffusion-2-1-base \
4 | --controlnet_model_path preset/models \
5 | --vae_model_path preset/models \
6 | --baseline_name ccsr-v2 \
7 | --image_path preset/test_datasets \
8 | --output_dir experiments/test \
9 | --sample_method ddpm \
10 | --num_inference_steps 1 \
11 | --t_min 0.0 \
12 | --start_point lr \
13 | --start_steps 999 \
14 | --process_size 512 \
15 | --guidance_scale 1.0 \
16 | --sample_times 1 \
17 | --use_vae_encode_condition \
18 | --upscale 4
--------------------------------------------------------------------------------
/scripts/test/test_ccsr_tile.sh:
--------------------------------------------------------------------------------
1 | python test_ccsr_tile.py \
2 | --pretrained_model_path preset/models/stable-diffusion-2-1-base \
3 | --controlnet_model_path preset/models \
4 | --vae_model_path preset/models \
5 | --baseline_name ccsr-v2 \
6 | --image_path preset/test_datasets \
7 | --output_dir experiments/test \
8 | --sample_method ddpm \
9 | --num_inference_steps 6 \
10 | --t_max 0.6667 \
11 | --t_min 0.5 \
12 | --start_point lr \
13 | --start_steps 999 \
14 | --process_size 512 \
15 | --guidance_scale 4.5 \
16 | --sample_times 1 \
17 | --use_vae_encode_condition \
18 | --upscale 4 \
19 | --tile_diffusion \
20 | --tile_diffusion_size 512 \
21 | --tile_diffusion_stride 256 \
22 | --tile_vae \
23 | --vae_decoder_tile_size 224 \
24 | --vae_encoder_tile_size 1024 \
--------------------------------------------------------------------------------
/scripts/train/train_ccsr_stage1.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_ccsr_stage1.py \
2 | --pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
3 | --controlnet_model_name_or_path='preset/models/pretrained_controlnet' \
4 | --enable_xformers_memory_efficient_attention \
5 | --output_dir="./experiments/ccsrv2_stage1" \
6 | --mixed_precision="fp16" \
7 | --resolution=512 \
8 | --learning_rate=5e-5 \
9 | --train_batch_size=4 \
10 | --gradient_accumulation_steps=6 \
11 | --dataloader_num_workers=0 \
12 | --checkpointing_steps=500 \
13 | --t_max=0.6667 \
14 | --max_train_steps=20000 \
15 | --dataset_root_folders 'preset/gt_path.txt'
--------------------------------------------------------------------------------
/scripts/train/train_ccsr_stage2.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_ccsr_stage2.py \
2 | --pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
3 | --controlnet_model_name_or_path='preset/models/model_stage1' \
4 | --enable_xformers_memory_efficient_attention \
5 | --output_dir="./experiments/ccsrv2_stage2" \
6 | --mixed_precision="fp16" \
7 | --resolution=512 \
8 | --learning_rate=5e-6 \
9 | --train_batch_size=2 \
10 | --gradient_accumulation_steps=8 \
11 | --checkpointing_steps=500 \
12 | --is_start_lr=True \
13 | --t_max=0.6667 \
14 | --num_inference_steps=1 \
15 | --is_module \
16 | --lambda_l2=1.0 \
17 | --lambda_lpips=1.0 \
18 | --lambda_disc=0.05 \
19 | --lambda_disc_train=0.5 \
20 | --begin_disc=100 \
21 | --max_train_steps=2000 \
22 | --dataset_root_folders 'preset/gt_path.txt'
--------------------------------------------------------------------------------
/scripts/train/train_controlnet.sh:
--------------------------------------------------------------------------------
1 |
2 | CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_controlnet.py \
3 | --pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
4 | --controlnet_model_name_or_path='' \
5 | --enable_xformers_memory_efficient_attention \
6 | --output_dir="./experiments/pretrained_controlnet" \
7 | --mixed_precision="fp16" \
8 | --resolution=512 \
9 | --learning_rate=5e-5 \
10 | --train_batch_size=4 \
11 | --gradient_accumulation_steps=6 \
12 | --dataloader_num_workers=0 \
13 | --checkpointing_steps=5000 \
14 | --max_train_steps=40000 \
15 | --dataset_root_folders 'preset/gt_path.txt'
--------------------------------------------------------------------------------
/test_ccsr_tile.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import math
4 | import time
5 | import argparse
6 |
7 | import numpy as np
8 | from PIL import Image
9 | import safetensors.torch
10 |
11 | import torch
12 | from torchvision import transforms
13 | import torchvision.transforms.functional as F
14 |
15 | from accelerate import Accelerator
16 | from accelerate.utils import set_seed
17 | from diffusers import (
18 | AutoencoderKL,
19 | UniPCMultistepScheduler,
20 | DPMSolverMultistepScheduler,
21 | DDPMScheduler,
22 | UNet2DConditionModel,
23 | )
24 |
25 | from diffusers.utils import check_min_version
26 | from diffusers.utils.import_utils import is_xformers_available
27 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
28 |
29 | from pipelines.pipeline_ccsr import StableDiffusionControlNetPipeline
30 | from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
31 | from models.controlnet import ControlNetModel
32 |
33 |
34 |
35 | def load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention):
36 |
37 | scheduler_mapping = {
38 | 'unipcmultistep': UniPCMultistepScheduler,
39 | 'ddpm': DDPMScheduler,
40 | 'dpmmultistep': DPMSolverMultistepScheduler,
41 | }
42 |
43 | try:
44 | scheduler_cls = scheduler_mapping[args.sample_method]
45 | except KeyError:
46 | raise ValueError(f"Invalid sample_method: {args.sample_method}")
47 |
48 | scheduler = scheduler_cls.from_pretrained(args.pretrained_model_path, subfolder="scheduler")
49 |
50 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
51 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
52 | feature_extractor = CLIPImageProcessor.from_pretrained(os.path.join(args.pretrained_model_path, "feature_extractor"))
53 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_path, subfolder="unet")
54 | controlnet = ControlNetModel.from_pretrained(args.controlnet_model_path, subfolder="controlnet")
55 |
56 | vae_path = args.vae_model_path if args.vae_model_path else args.pretrained_model_path
57 | vae = AutoencoderKL.from_pretrained(vae_path, subfolder="vae")
58 |
59 | # Freeze models
60 | for model in [vae, text_encoder, unet, controlnet]:
61 | model.requires_grad_(False)
62 |
63 | # Enable xformers if available
64 | if enable_xformers_memory_efficient_attention:
65 | if is_xformers_available():
66 | unet.enable_xformers_memory_efficient_attention()
67 | controlnet.enable_xformers_memory_efficient_attention()
68 | else:
69 | raise ValueError("xformers is not available. Ensure it is installed correctly.")
70 |
71 | # Initialize pipeline
72 | validation_pipeline = StableDiffusionControlNetPipeline(
73 | vae=vae,
74 | text_encoder=text_encoder,
75 | tokenizer=tokenizer,
76 | feature_extractor=feature_extractor,
77 | unet=unet,
78 | controlnet=controlnet,
79 | scheduler=scheduler,
80 | safety_checker=None,
81 | requires_safety_checker=False,
82 | )
83 |
84 | if args.tile_vae:
85 | validation_pipeline._init_tiled_vae(
86 | encoder_tile_size=args.vae_encoder_tile_size,
87 | decoder_tile_size=args.vae_decoder_tile_size
88 | )
89 |
90 | # Set weight dtype based on mixed precision
91 | dtype_mapping = {
92 | "fp16": torch.float16,
93 | "bf16": torch.bfloat16,
94 | }
95 | weight_dtype = dtype_mapping.get(accelerator.mixed_precision, torch.float32)
96 |
97 | # Move models to accelerator device with appropriate dtype
98 | for model in [text_encoder, vae, unet, controlnet]:
99 | model.to(accelerator.device, dtype=weight_dtype)
100 |
101 | return validation_pipeline
102 |
103 | def main(args, enable_xformers_memory_efficient_attention=True,):
104 |
105 | detailed_output_dir = os.path.join(
106 | args.output_dir,
107 | f"sr_{args.baseline_name}_{args.sample_method}_{str(args.num_inference_steps).zfill(3)}steps_{args.start_point}{args.start_steps}_size{args.process_size}_cfg{args.guidance_scale}"
108 | )
109 |
110 | accelerator = Accelerator(
111 | mixed_precision=args.mixed_precision,
112 | )
113 |
114 | # If passed along, set the training seed now.
115 | if args.seed is not None:
116 | set_seed(args.seed)
117 |
118 | # Handle the output folder creation
119 | # We need to initialize the trackers we use, and also store our configuration.
120 |     # The trackers initialize automatically on the main process.
121 | if accelerator.is_main_process:
122 | os.makedirs(detailed_output_dir, exist_ok=True)
123 | accelerator.init_trackers("Controlnet")
124 |
125 | pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention)
126 |
127 | if accelerator.is_main_process:
128 | generator = torch.Generator(device=accelerator.device)
129 | if args.seed is not None:
130 | generator.manual_seed(args.seed)
131 |
132 | image_paths = sorted(glob.glob(os.path.join(args.image_path, "*.*"))) if os.path.isdir(args.image_path) else [args.image_path]
133 |
134 | time_records = []
135 | for image_path in image_paths:
136 | validation_image = Image.open(image_path).convert("RGB")
137 | negative_prompt = args.negative_prompt
138 | validation_prompt = args.added_prompt
139 |
140 | ori_width, ori_height = validation_image.size
141 | resize_flag = False
142 | rscale = args.upscale
143 | if ori_width < args.process_size//rscale or ori_height < args.process_size//rscale:
144 | scale = (args.process_size//rscale)/min(ori_width, ori_height)
145 | tmp_image = validation_image.resize((round(scale*ori_width), round(scale*ori_height)))
146 | validation_image = tmp_image
147 | resize_flag = True
148 |
149 |
150 | validation_image = validation_image.resize((validation_image.size[0]*rscale, validation_image.size[1]*rscale))
151 | validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
152 | width, height = validation_image.size
153 | resize_flag = True #
154 |
155 | for sample_idx in range(args.sample_times):
156 | os.makedirs(f'{detailed_output_dir}/sample{str(sample_idx).zfill(2)}/', exist_ok=True)
157 |
158 | for sample_idx in range(args.sample_times):
159 |
160 | inference_time, image = pipeline(
161 | args.t_max,
162 | args.t_min,
163 | args.tile_diffusion,
164 | args.tile_diffusion_size,
165 | args.tile_diffusion_stride,
166 | args.added_prompt,
167 | validation_image,
168 | num_inference_steps=args.num_inference_steps,
169 | generator=generator,
170 | height=height,
171 | width=width,
172 | guidance_scale=args.guidance_scale,
173 | negative_prompt=args.negative_prompt,
174 | conditioning_scale=args.conditioning_scale,
175 | start_steps=args.start_steps,
176 | start_point=args.start_point,
177 | use_vae_encode_condition=args.use_vae_encode_condition,
178 | )
179 | image = image.images[0]
180 |
181 | print(f"Inference time: {inference_time:.4f} seconds")
182 | time_records.append(inference_time)
183 |
184 | # Apply color fixing if specified
185 | if args.align_method != 'nofix':
186 | fix_func = wavelet_color_fix if args.align_method == 'wavelet' else adain_color_fix
187 | image = fix_func(image, validation_image)
188 |
189 | if resize_flag:
190 | image = image.resize((ori_width*rscale, ori_height*rscale))
191 |
192 | image_tensor = torch.clamp(F.to_tensor(image), 0, 1)
193 | final_image = transforms.ToPILImage()(image_tensor)
194 | base_name = os.path.splitext(os.path.basename(image_path))[0]
195 | save_path = os.path.join(detailed_output_dir, f"sample{str(sample_idx).zfill(2)}", f"{base_name}.png")
196 |             final_image.save(save_path)
197 |
198 | # Calculate the average inference time, excluding the first few for stabilization
199 | if len(time_records) > 3:
200 | average_time = np.mean(time_records[3:])
201 | else:
202 | average_time = np.mean(time_records)
203 | if accelerator.is_main_process:
204 | print(f"Average inference time: {average_time:.4f} seconds")
205 |
206 |
207 | # Save the run settings to a file
208 | settings_path = os.path.join(detailed_output_dir, "settings.txt")
209 | with open(settings_path, 'w') as f:
210 | f.write("------------------ start ------------------\n")
211 | for key, value in vars(args).items():
212 | f.write(f"{key} : {value}\n")
213 | f.write("------------------- end -------------------\n")
214 |
215 |
216 | if __name__ == "__main__":
217 | parser = argparse.ArgumentParser(description="Stable Diffusion ControlNet Pipeline for Super-Resolution")
218 | parser.add_argument("--controlnet_model_path", type=str, default="", help="Path to ControlNet model")
219 | parser.add_argument("--pretrained_model_path", type=str, default="", help="Path to pretrained model")
220 | parser.add_argument("--vae_model_path", type=str, default="", help="Path to VAE model")
221 | parser.add_argument("--added_prompt", type=str, default="clean, high-resolution, 8k", help="Additional prompt for generation")
222 | parser.add_argument("--negative_prompt", type=str, default="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed", help="Negative prompt to avoid certain features")
223 | parser.add_argument("--image_path", type=str, default="", help="Path to input image or directory")
224 | parser.add_argument("--output_dir", type=str, default="", help="Directory to save outputs")
225 | parser.add_argument("--mixed_precision", type=str, choices=["no", "fp16", "bf16"], default="fp16", help="Mixed precision mode")
226 | parser.add_argument("--guidance_scale", type=float, default=1.0, help="Guidance scale for generation")
227 | parser.add_argument("--conditioning_scale", type=float, default=1.0, help="Conditioning scale")
228 | parser.add_argument("--num_inference_steps", type=int, default=1, help="Number of inference steps(not the final inference time)")
229 | # final_inference_time = num_inference_steps * (t_max - t_min) + 1
230 | parser.add_argument("--t_max", type=float, default=0.6666, help="Maximum timestep")
231 | parser.add_argument("--t_min", type=float, default=0.0, help="Minimum timestep")
232 | parser.add_argument("--process_size", type=int, default=512, help="Processing size of the image")
233 | parser.add_argument("--upscale", type=int, default=1, help="Upscaling factor")
234 | parser.add_argument("--seed", type=int, default=None, help="Random seed")
235 | parser.add_argument("--sample_times", type=int, default=5, help="Number of samples to generate per image")
236 | parser.add_argument("--sample_method", type=str, choices=['unipcmultistep', 'ddpm', 'dpmmultistep'], default='ddpm', help="Sampling method")
237 | parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain', help="Alignment method for color fixing")
238 | parser.add_argument("--start_steps", type=int, default=999, help="Starting steps")
239 | parser.add_argument("--start_point", type=str, choices=['lr', 'noise'], default='lr', help="Starting point for generation")
240 | parser.add_argument("--baseline_name", type=str, default='ccsr-v2', help="Baseline name for output naming")
241 | parser.add_argument("--use_vae_encode_condition", action='store_true', help="Use VAE encoding LQ condition")
242 |
243 | # Tiling settings for high-resolution SR
244 | parser.add_argument("--tile_diffusion", action="store_true", help="Optionally! Enable tile-based diffusion")
245 | parser.add_argument("--tile_diffusion_size", type=int, default=512, help="Tile size for diffusion")
246 | parser.add_argument("--tile_diffusion_stride", type=int, default=256, help="Stride size for diffusion tiles")
247 | parser.add_argument("--tile_vae", action="store_true", help="Optionally! Enable tiling for VAE")
248 | parser.add_argument("--vae_decoder_tile_size", type=int, default=224, help="Tile size for VAE decoder")
249 | parser.add_argument("--vae_encoder_tile_size", type=int, default=1024, help="Tile size for VAE encoder")
250 |
251 | args = parser.parse_args()
252 | main(args)
--------------------------------------------------------------------------------
/utils/__pycache__/vaehook.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csslc/CCSR/9ccb264e81b717770f4c3e5b48a1155a6a5dbf3c/utils/__pycache__/vaehook.cpython-310.pyc
--------------------------------------------------------------------------------
/utils/devices.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import contextlib
3 | from functools import lru_cache
4 |
5 | import torch
6 | #from modules import errors
7 |
8 | if sys.platform == "darwin":
9 | from modules import mac_specific
10 |
11 |
12 | def has_mps() -> bool:
13 | if sys.platform != "darwin":
14 | return False
15 | else:
16 | return mac_specific.has_mps
17 |
18 |
19 | def get_cuda_device_string():
20 | return "cuda"
21 |
22 |
23 | def get_optimal_device_name():
24 | if torch.cuda.is_available():
25 | return get_cuda_device_string()
26 |
27 | if has_mps():
28 | return "mps"
29 |
30 | return "cpu"
31 |
32 |
33 | def get_optimal_device():
34 | return torch.device(get_optimal_device_name())
35 |
36 |
37 | def get_device_for(task):
38 | return get_optimal_device()
39 |
40 |
41 | def torch_gc():
42 |
43 | if torch.cuda.is_available():
44 | with torch.cuda.device(get_cuda_device_string()):
45 | torch.cuda.empty_cache()
46 | torch.cuda.ipc_collect()
47 |
48 | if has_mps():
49 | mac_specific.torch_mps_gc()
50 |
51 |
52 | def enable_tf32():
53 | if torch.cuda.is_available():
54 |
55 | # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
56 | # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
57 | if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
58 | torch.backends.cudnn.benchmark = True
59 |
60 | torch.backends.cuda.matmul.allow_tf32 = True
61 | torch.backends.cudnn.allow_tf32 = True
62 |
63 |
64 | enable_tf32()
65 | #errors.run(enable_tf32, "Enabling TF32")
66 |
67 | cpu = torch.device("cpu")
68 | device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda")
69 | dtype = torch.float16
70 | dtype_vae = torch.float16
71 | dtype_unet = torch.float16
72 | unet_needs_upcast = False
73 |
74 |
75 | def cond_cast_unet(input):
76 | return input.to(dtype_unet) if unet_needs_upcast else input
77 |
78 |
79 | def cond_cast_float(input):
80 | return input.float() if unet_needs_upcast else input
81 |
82 |
83 | def randn(seed, shape):
84 | torch.manual_seed(seed)
85 | return torch.randn(shape, device=device)
86 |
87 |
88 | def randn_without_seed(shape):
89 | return torch.randn(shape, device=device)
90 |
91 |
92 | def autocast(disable=False):
93 | if disable:
94 | return contextlib.nullcontext()
95 |
96 | return torch.autocast("cuda")
97 |
98 |
99 | def without_autocast(disable=False):
100 | return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
101 |
102 |
103 | class NansException(Exception):
104 | pass
105 |
106 |
107 | def test_for_nans(x, where):
108 | if not torch.all(torch.isnan(x)).item():
109 | return
110 |
111 | if where == "unet":
112 | message = "A tensor with all NaNs was produced in Unet."
113 |
114 | elif where == "vae":
115 | message = "A tensor with all NaNs was produced in VAE."
116 |
117 | else:
118 | message = "A tensor with all NaNs was produced."
119 |
120 | message += " Use --disable-nan-check commandline argument to disable this check."
121 |
122 | raise NansException(message)
123 |
124 |
125 | @lru_cache
126 | def first_time_calculation():
127 | """
128 |     Run any small calculation with PyTorch layers - the first time this is done it allocates about 700MB of memory and
129 |     spends about 2.7 seconds doing that, at least with NVIDIA GPUs.
130 | """
131 |
132 | x = torch.zeros((1, 1)).to(device, dtype)
133 | linear = torch.nn.Linear(1, 1).to(device, dtype)
134 | linear(x)
135 |
136 | x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
137 | conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
138 | conv2d(x)
139 |
--------------------------------------------------------------------------------
/utils/img_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import PIL
3 | import cv2
4 | import math
5 | import numpy as np
6 | import torch
7 | import torchvision
8 | import imageio
9 |
10 | from einops import rearrange
11 |
12 | def save_videos_grid(videos, path=None, rescale=True, n_rows=4, fps=8, discardN=0):
13 | videos = rearrange(videos, "b c t h w -> t b c h w").cpu()
14 | outputs = []
15 | for x in videos:
16 | x = torchvision.utils.make_grid(x, nrow=n_rows)
17 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
18 | if rescale:
19 | x = (x / 2.0 + 0.5).clamp(0, 1) # -1,1 -> 0,1
20 | x = (x * 255).numpy().astype(np.uint8)
21 | #x = adjust_gamma(x, 0.5)
22 | outputs.append(x)
23 |
24 | outputs = outputs[discardN:]
25 |
26 | if path is not None:
27 | #os.makedirs(os.path.dirname(path), exist_ok=True)
28 | imageio.mimsave(path, outputs, duration=1000/fps, loop=0)
29 |
30 | return outputs
31 |
32 | def convert_image_to_fn(img_type, minsize, image, eps=0.02):
33 | width, height = image.size
34 | if min(width, height) < minsize:
35 | scale = minsize/min(width, height) + eps
36 | image = image.resize((math.ceil(width*scale), math.ceil(height*scale)))
37 |
38 | if image.mode != img_type:
39 | return image.convert(img_type)
40 | return image
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import binascii
3 | from safetensors import safe_open
4 |
5 | import torch
6 |
7 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint
8 |
9 | def rand_name(length=8, suffix=''):
10 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
11 | if suffix:
12 | if not suffix.startswith('.'):
13 | suffix = '.' + suffix
14 | name += suffix
15 | return name
16 |
17 | def cycle(dl):
18 | while True:
19 | for data in dl:
20 | yield data
21 |
22 | def exists(x):
23 | return x is not None
24 |
25 | def identity(x):
26 | return x
27 |
28 | def load_dreambooth_lora(unet, vae=None, model_path=None, alpha=1.0, model_base=""):
29 | if model_path is None: return unet
30 |
31 | if model_path.endswith(".ckpt"):
32 | base_state_dict = torch.load(model_path)['state_dict']
33 | elif model_path.endswith(".safetensors"):
34 | state_dict = {}
35 | with safe_open(model_path, framework="pt", device="cpu") as f:
36 | for key in f.keys():
37 | state_dict[key] = f.get_tensor(key)
38 |
39 | is_lora = all("lora" in k for k in state_dict.keys())
40 | if not is_lora:
41 | base_state_dict = state_dict
42 | else:
43 | base_state_dict = {}
44 | with safe_open(model_base, framework="pt", device="cpu") as f:
45 | for key in f.keys():
46 | base_state_dict[key] = f.get_tensor(key)
47 |
48 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, unet.config)
49 | unet_state_dict = unet.state_dict()
50 | for key in converted_unet_checkpoint:
51 | converted_unet_checkpoint[key] = alpha * converted_unet_checkpoint[key] + (1.0-alpha) * unet_state_dict[key]
52 | unet.load_state_dict(converted_unet_checkpoint, strict=False)
53 |
54 | if vae is not None:
55 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, vae.config)
56 | vae.load_state_dict(converted_vae_checkpoint)
57 |
58 | return unet, vae
--------------------------------------------------------------------------------
/utils/wavelet_color_fix.py:
--------------------------------------------------------------------------------
1 | '''
2 | # --------------------------------------------------------------------------------
3 | # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
4 | # --------------------------------------------------------------------------------
5 | '''
6 |
7 | import torch
8 | from PIL import Image
9 | from torch import Tensor
10 | from torch.nn import functional as F
11 |
12 | from torchvision.transforms import ToTensor, ToPILImage
13 |
14 | def adain_color_fix(target: Image, source: Image):
15 | # Convert images to tensors
16 | to_tensor = ToTensor()
17 | target_tensor = to_tensor(target).unsqueeze(0)
18 | source_tensor = to_tensor(source).unsqueeze(0)
19 |
20 | # Apply adaptive instance normalization
21 | result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
22 |
23 | # Convert tensor back to image
24 | to_image = ToPILImage()
25 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
26 |
27 | return result_image
28 |
29 | def wavelet_color_fix(target: Image, source: Image):
30 | # Convert images to tensors
31 | to_tensor = ToTensor()
32 | target_tensor = to_tensor(target).unsqueeze(0)
33 | source_tensor = to_tensor(source).unsqueeze(0)
34 |
35 | # Apply wavelet reconstruction
36 | result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
37 |
38 | # Convert tensor back to image
39 | to_image = ToPILImage()
40 | result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
41 |
42 | return result_image
43 |
44 | def calc_mean_std(feat: Tensor, eps=1e-5):
45 | """Calculate mean and std for adaptive_instance_normalization.
46 | Args:
47 | feat (Tensor): 4D tensor.
48 | eps (float): A small value added to the variance to avoid
49 | divide-by-zero. Default: 1e-5.
50 | """
51 | size = feat.size()
52 | assert len(size) == 4, 'The input feature should be 4D tensor.'
53 | b, c = size[:2]
54 | feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
55 | feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
56 | feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
57 | return feat_mean, feat_std
58 |
59 | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
60 | """Adaptive instance normalization.
61 |     Adjust the reference features so that they have similar color and
62 |     illumination to those in the degraded features.
63 |     Args:
64 |         content_feat (Tensor): The reference feature.
65 |         style_feat (Tensor): The degraded features.
66 | """
67 | size = content_feat.size()
68 | style_mean, style_std = calc_mean_std(style_feat)
69 | content_mean, content_std = calc_mean_std(content_feat)
70 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
71 | return normalized_feat * style_std.expand(size) + style_mean.expand(size)
72 |
73 | def wavelet_blur(image: Tensor, radius: int):
74 | """
75 | Apply wavelet blur to the input tensor.
76 | """
77 | # input shape: (1, 3, H, W)
78 | # convolution kernel
79 | kernel_vals = [
80 | [0.0625, 0.125, 0.0625],
81 | [0.125, 0.25, 0.125],
82 | [0.0625, 0.125, 0.0625],
83 | ]
84 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
85 | # add channel dimensions to the kernel to make it a 4D tensor
86 | kernel = kernel[None, None]
87 | # repeat the kernel across all input channels
88 | kernel = kernel.repeat(3, 1, 1, 1)
89 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
90 | # apply convolution
91 | output = F.conv2d(image, kernel, groups=3, dilation=radius)
92 | return output
93 |
94 | def wavelet_decomposition(image: Tensor, levels=5):
95 | """
96 | Apply wavelet decomposition to the input tensor.
97 |     This function returns the accumulated high-frequency component and the final low-frequency one.
98 | """
99 | high_freq = torch.zeros_like(image)
100 | for i in range(levels):
101 | radius = 2 ** i
102 | low_freq = wavelet_blur(image, radius)
103 | high_freq += (image - low_freq)
104 | image = low_freq
105 |
106 | return high_freq, low_freq
107 |
108 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
109 | """
110 | Apply wavelet decomposition, so that the content will have the same color as the style.
111 | """
112 | # calculate the wavelet decomposition of the content feature
113 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
114 | del content_low_freq
115 | # calculate the wavelet decomposition of the style feature
116 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
117 | del style_high_freq
118 |     # reconstruct the content feature with the style's low frequency
119 | return content_high_freq + style_low_freq
120 |
--------------------------------------------------------------------------------