DiffCom is Robust to Unexpected Transmission Degradations
21 |
25 |
A visual comparison illustrating the impact of several unexpected
26 | transmission degradations: ① unseen channel fading, ② PAPR reduction, ③
27 | inter-symbol interference (ISI) caused by removed CP symbols, and ④ very low CSNR (0 dB).
27 | End-to-end visual communication systems typically optimize a trade-off between channel
28 | bandwidth
29 | costs and signal-level distortion metrics.
30 | However, under challenging physical conditions, such a discriminative
31 | communication paradigm often results in unrealistic reconstructions with perceptible
32 | blurring and aliasing artifacts, despite the inclusion of perceptual or adversarial losses
33 | during training. This issue primarily stems from the receiver's limited knowledge about the
34 | underlying data
35 | manifold and the use of deterministic decoding mechanisms.
36 |
37 |
38 | We propose DiffCom, a novel end-to-end generative communication paradigm that utilizes off-the-shelf
40 | generative priors from diffusion models for decoding, thereby improving perceptual
41 | quality
42 | without relying heavily on increased bandwidth or high received-signal quality.
43 | Unlike traditional systems that rely on deterministic decoders optimized solely for
44 | distortion
45 | metrics, our DiffCom leverages the raw channel-received signal as a fine-grained condition to guide stochastic posterior
46 | sampling. Through a novel confirming constraint, our approach ensures that reconstructions
47 | remain on the manifold of real data,
48 | enhancing the robustness and reliability of the generated
49 | outcomes.
50 | Furthermore, DiffCom incorporates a blind posterior
51 | sampling
52 | technique to address
53 | scenarios with unknown forward transmission characteristics.
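As a rough sketch of this idea (illustrative only; the function names below are placeholders,
and the actual update lives in conditioning_method/diffcom.py), each reverse-diffusion step
adds a measurement-consistency gradient to the ordinary denoising update:

    # x_t: current diffusion state, y: raw channel-received signal
    # encode / channel: differentiable JSCC encoder and channel model; denoise: one DDPM step
    x_t = x_t.detach().requires_grad_(True)
    x_prev, x0_hat = denoise(x_t, t)                            # plain reverse step and clean estimate
    data_fit = torch.linalg.norm(y - channel(encode(x0_hat)))   # re-transmit the estimate, compare to y
    grad = torch.autograd.grad(data_fit, x_t)[0]
    x_t = (x_prev - lr(t) * grad).detach()                      # guided posterior-sampling update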
54 |
55 |
56 | Experimental results demonstrate that:
57 |
58 |
DiffCom achieves SOTA transmission performance in
59 | terms of
60 | multiple perceptual quality metrics, including LPIPS, DISTS, and FID.
61 |
62 |
DiffCom significantly enhances the robustness of
63 | current
64 | methods against various transmission-related degradations, including mismatched SNR,
65 | unseen
66 | fading, blind channel estimation, PAPR reduction, and inter-symbol interference.
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 | )
76 |
77 | const Section1 = () => {
78 | return (
79 |
80 |
81 |
82 |
83 |
84 | );
85 | }
86 |
87 | export default Section1;
88 |
--------------------------------------------------------------------------------
/website/src/App.css:
--------------------------------------------------------------------------------
1 | .App {
2 | text-align: center;
3 | }
4 |
5 | .App-logo {
6 | height: 40vmin;
7 | pointer-events: none;
8 | }
9 |
10 | @media (prefers-reduced-motion: no-preference) {
11 | .App-logo {
12 | animation: App-logo-spin infinite 20s linear;
13 | }
14 | }
15 |
16 | .App-header {
17 | background-color: #282c34;
18 | min-height: 100vh;
19 | display: flex;
20 | flex-direction: column;
21 | align-items: center;
22 | justify-content: center;
23 | font-size: calc(10px + 2vmin);
24 | color: white;
25 | }
26 |
27 | .App-link {
28 | color: #61dafb;
29 | }
30 |
31 | @keyframes App-logo-spin {
32 | from {
33 | transform: rotate(0deg);
34 | }
35 | to {
36 | transform: rotate(360deg);
37 | }
38 | }
39 |
40 | html {
41 | background-color: #fff;
42 | font-size: 16px;
43 | -moz-osx-font-smoothing: grayscale;
44 | -webkit-font-smoothing: antialiased;
45 | min-width: 300px;
46 | overflow-x: hidden;
47 | overflow-y: scroll;
48 | text-rendering: optimizeLegibility;
49 | -webkit-text-size-adjust: 100%;
50 | -moz-text-size-adjust: 100%;
51 | -ms-text-size-adjust: 100%;
52 | text-size-adjust: 100%;
53 | }
54 |
55 | a {
56 | color: #3273dc;
57 | cursor: pointer;
58 | text-decoration: none;
59 | }
60 |
61 | p {
62 | display: block;
63 | margin-block-start: 1em;
64 | margin-block-end: 1em;
65 | margin-inline-start: 0px;
66 | margin-inline-end: 0px;
67 | }
68 |
69 | img {
70 | height: auto;
71 | max-width: 100%;
72 | }
73 |
74 | body {
75 | color: #4a4a4a;
76 | font-size: 1em;
77 | font-weight: 400;
78 | line-height: 1.5;
79 | }
80 |
81 | pre {
82 | -webkit-overflow-scrolling: touch;
83 | background-color: #f5f5f5;
84 | color: #4a4a4a;
85 | font-size: .875em;
86 | overflow-x: auto;
87 | padding: 1.25rem 1.5rem;
88 | white-space: pre;
89 | word-wrap: normal;
90 | }
91 |
92 | code, pre {
93 | -moz-osx-font-smoothing: auto;
94 | -webkit-font-smoothing: auto;
95 | font-family: monospace;
96 | }
97 |
98 | .footer {
99 | background-color: #fafafa;
100 | padding: 3rem 1.5rem 6rem;
101 | }
102 |
103 | .title {
104 | color: #363636;
105 | font-size: 2rem;
106 | font-weight: 600;
107 | line-height: 1.125;
108 | }
109 |
110 | .subtitle {
111 | color: #4a4a4a;
112 | font-size: 1.25rem;
113 | font-weight: 400;
114 | line-height: 1.25;
115 | }
116 |
117 | .subtitle, .title {
118 | word-break: break-word;
119 | }
120 |
121 | .title.is-1{
122 | font-size: 3rem;
123 | }
124 |
125 | .title.is-2 {
126 | margin-top: -0.5rem;
127 | font-size: 2.0rem;
128 | }
129 |
130 | .title.is-3 {
131 | margin-top: -0.5rem;
132 | font-size: 1.7rem;
133 | }
134 |
135 | .is-size-5 {
136 | font-size: 1.25rem!important;
137 | }
138 |
139 | .publication-title {
140 | font-family: 'Google Sans', sans-serif;
141 | }
142 |
143 | .publication-authors {
144 | font-family: 'Google Sans', sans-serif;
145 | }
146 |
147 | .publication-authors a {
148 | color: hsl(204, 86%, 53%) !important;
149 | }
150 |
151 | .author-block {
152 | display: inline-block;
153 | }
154 |
155 | .author-block-small {
156 | display: inline-block;
157 | font-size: 1rem;
158 | }
159 |
160 | @media screen and (min-width: 1024px){
161 | .container {
162 | max-width: 960px;
163 | }}
164 |
165 | .hero {
166 | align-items: stretch;
167 | display: flex;
168 | flex-direction: column;
169 | justify-content: space-between;
170 | }
171 |
172 | .hero-body {
173 | flex-grow: 1;
174 | flex-shrink: 0;
175 | padding: 1.5rem 1.5rem;
176 | }
177 |
178 | .hero.is-light {
179 | background-color: #f5f5f5;
180 | color: rgba(0,0,0,.7);
181 | }
182 |
183 | .container {
184 | flex-grow: 1;
185 | margin: 0 auto;
186 | position: relative;
187 | width: auto;
188 | }
189 | .section {
190 | padding: 1.5rem 1.5rem;
191 | }
192 |
193 | .column{
194 | display: block;
195 | flex-basis: 0;
196 | flex-grow: 1;
197 | flex-shrink: 1;
198 | padding: 0.75rem;
199 | }
200 |
201 | @media screen and (min-width: 769px), print{
202 | .column.is-four-fifths, .column.is-four-fifths-tablet {
203 | flex: none;
204 | width: 80%;
205 | }}
206 |
207 | .columns {
208 | margin-left: -.75rem;
209 | margin-right: -.75rem;
210 | margin-top: -.75rem;
211 | }
212 |
213 | @media screen and (min-width: 769px), print{
214 | .columns:not(.is-desktop) {
215 | display: flex;
216 | }}
217 |
218 | .columns.is-centered {
219 | justify-content: center;
220 | }
221 | .columns:last-child {
222 | margin-bottom: -.75rem;
223 | }
224 |
225 | .has-text-centered {
226 | text-align: center!important;
227 | }
228 |
229 | .has-text-justified {
230 | text-align: justify!important;
231 | }
232 |
233 |
234 | .task-description {
235 | text-align: center;
236 | padding: 2rem 0;
237 | }
238 |
239 | .result-display {
240 | padding: 0.3rem 0;
241 | }
242 |
243 | /* button style */
244 | .button {
245 | background-color: #fff;
246 | border-color: #dbdbdb;
247 | border-width: 1px;
248 | color: #363636;
249 | cursor: pointer;
250 | justify-content: center;
251 | padding-bottom: calc(.5em - 1px);
252 | padding-left: 1em;
253 | padding-right: 1em;
254 | padding-top: calc(.5em - 1px);
255 | text-align: center;
256 | white-space: nowrap;
257 | }
258 |
259 |
260 | .task-btns {
261 | display: flex;
262 | justify-content: center;
263 | align-items: center;
264 | padding: 1.0rem 0;
265 | }
266 |
267 | .generate-progress {
268 | justify-content: center;
269 | }
--------------------------------------------------------------------------------
/guided_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch as th
3 |
4 | from .gaussian_diffusion import GaussianDiffusion
5 |
6 |
7 | def space_timesteps(num_timesteps, section_counts):
8 | """
9 | Create a list of timesteps to use from an original diffusion process,
10 | given the number of timesteps we want to take from equally-sized portions
11 | of the original process.
12 |
13 | For example, if there's 300 timesteps and the section counts are [10,15,20]
14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
15 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
16 |
17 | If the stride is a string starting with "ddim", then the fixed striding
18 | from the DDIM paper is used, and only one section is allowed.
19 |
20 | :param num_timesteps: the number of diffusion steps in the original
21 | process to divide up.
22 | :param section_counts: either a list of numbers, or a string containing
23 | comma-separated numbers, indicating the step count
24 | per section. As a special case, use "ddimN" where N
25 | is a number of steps to use the striding from the
26 | DDIM paper.
27 | :return: a set of diffusion steps from the original process to use.
28 | """
29 | if isinstance(section_counts, str):
30 | if section_counts.startswith("ddim"):
31 | desired_count = int(section_counts[len("ddim") :])
32 | for i in range(1, num_timesteps):
33 | if len(range(0, num_timesteps, i)) == desired_count:
34 | return set(range(0, num_timesteps, i))
35 | raise ValueError(
36 |                 f"cannot create exactly {desired_count} steps with an integer stride"
37 | )
38 | section_counts = [int(x) for x in section_counts.split(",")]
39 | size_per = num_timesteps // len(section_counts)
40 | extra = num_timesteps % len(section_counts)
41 | start_idx = 0
42 | all_steps = []
43 | for i, section_count in enumerate(section_counts):
44 | size = size_per + (1 if i < extra else 0)
45 | if size < section_count:
46 | raise ValueError(
47 | f"cannot divide section of {size} steps into {section_count}"
48 | )
49 | if section_count <= 1:
50 | frac_stride = 1
51 | else:
52 | frac_stride = (size - 1) / (section_count - 1)
53 | cur_idx = 0.0
54 | taken_steps = []
55 | for _ in range(section_count):
56 | taken_steps.append(start_idx + round(cur_idx))
57 | cur_idx += frac_stride
58 | all_steps += taken_steps
59 | start_idx += size
60 | return set(all_steps)
61 |
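# Quick sanity checks (added for clarity; not part of the original file):
#   space_timesteps(100, [10])      -> {0, 11, 22, 33, 44, 55, 66, 77, 88, 99}
#   space_timesteps(1000, "ddim25") -> set(range(0, 1000, 40))   # 25 evenly strided steps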
62 |
63 | class SpacedDiffusion(GaussianDiffusion):
64 | """
65 | A diffusion process which can skip steps in a base diffusion process.
66 |
67 | :param use_timesteps: a collection (sequence or set) of timesteps from the
68 | original diffusion process to retain.
69 | :param kwargs: the kwargs to create the base diffusion process.
70 | """
71 |
72 | def __init__(self, use_timesteps, **kwargs):
73 | self.use_timesteps = set(use_timesteps)
74 | self.timestep_map = []
75 | self.original_num_steps = len(kwargs["betas"])
76 |
77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
78 | last_alpha_cumprod = 1.0
79 | new_betas = []
80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
81 | if i in self.use_timesteps:
82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
83 | last_alpha_cumprod = alpha_cumprod
84 | self.timestep_map.append(i)
85 | kwargs["betas"] = np.array(new_betas)
86 | super().__init__(**kwargs)
87 |
88 | def p_mean_variance(
89 | self, model, *args, **kwargs
90 | ): # pylint: disable=signature-differs
91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
92 |
93 | def training_losses(
94 | self, model, *args, **kwargs
95 | ): # pylint: disable=signature-differs
96 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
97 |
98 | def condition_mean(self, cond_fn, *args, **kwargs):
99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
100 |
101 | def condition_score(self, cond_fn, *args, **kwargs):
102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
103 |
104 | def _wrap_model(self, model):
105 | if isinstance(model, _WrappedModel):
106 | return model
107 | return _WrappedModel(
108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
109 | )
110 |
111 | def _scale_timesteps(self, t):
112 | # Scaling is done by the wrapped model.
113 | return t
114 |
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | if self.rescale_timesteps:
127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128 | return self.model(x, new_ts, **kwargs)
129 |
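# Typical wiring (a sketch, not part of the original file): the respaced process is built
# from the same betas the base model was trained with, e.g.
#   use_timesteps = space_timesteps(len(betas), "ddim25")
#   diffusion = SpacedDiffusion(use_timesteps=use_timesteps, betas=betas, ...)
# _WrappedModel then maps each respaced index back to the original timestep (and optionally
# rescales it to the original step range) before calling the underlying model.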
--------------------------------------------------------------------------------
/guided_diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
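# Shape check (added for clarity): for four timesteps and dim=128,
#   timestep_embedding(th.tensor([0, 10, 100, 999]), 128).shape == (4, 128)
# with the first 64 columns holding cosines and the last 64 holding sines.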
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
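# Typical use (a sketch; mirrors how guided-diffusion blocks call it):
#   y = checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)
# When the flag is True, activations are recomputed during backward instead of stored.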
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
/_pdjscc/loss_utils/perceptual_similarity/dists_loss/DISTS_pt.py:
--------------------------------------------------------------------------------
1 | # This is a PyTorch implementation of the DISTS metric.
2 | # Requirements: python >= 3.6, pytorch >= 1.0
3 |
4 | import numpy as np
5 | import os, sys
6 | import torch
7 | from torchvision import models, transforms
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class L2pooling(nn.Module):
13 | def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
14 | super(L2pooling, self).__init__()
15 | self.padding = (filter_size - 2) // 2
16 | self.stride = stride
17 | self.channels = channels
18 | a = np.hanning(filter_size)[1:-1]
19 | g = torch.Tensor(a[:, None] * a[None, :])
20 | g = g / torch.sum(g)
21 | self.register_buffer('filter', g[None, None, :, :].repeat((self.channels, 1, 1, 1)))
22 |
23 | def forward(self, input):
24 | input = input ** 2
25 | out = F.conv2d(input, self.filter, stride=self.stride, padding=self.padding, groups=input.shape[1])
26 | return (out + 1e-12).sqrt()
27 |
28 |
29 | class DISTS(torch.nn.Module):
30 | def __init__(self, load_weights=True):
31 | super(DISTS, self).__init__()
32 | vgg_pretrained_features = models.vgg16(pretrained=True).features
33 | self.stage1 = torch.nn.Sequential()
34 | self.stage2 = torch.nn.Sequential()
35 | self.stage3 = torch.nn.Sequential()
36 | self.stage4 = torch.nn.Sequential()
37 | self.stage5 = torch.nn.Sequential()
38 | for x in range(0, 4):
39 | self.stage1.add_module(str(x), vgg_pretrained_features[x])
40 | self.stage2.add_module(str(4), L2pooling(channels=64))
41 | for x in range(5, 9):
42 | self.stage2.add_module(str(x), vgg_pretrained_features[x])
43 | self.stage3.add_module(str(9), L2pooling(channels=128))
44 | for x in range(10, 16):
45 | self.stage3.add_module(str(x), vgg_pretrained_features[x])
46 | self.stage4.add_module(str(16), L2pooling(channels=256))
47 | for x in range(17, 23):
48 | self.stage4.add_module(str(x), vgg_pretrained_features[x])
49 | self.stage5.add_module(str(23), L2pooling(channels=512))
50 | for x in range(24, 30):
51 | self.stage5.add_module(str(x), vgg_pretrained_features[x])
52 |
53 | for param in self.parameters():
54 | param.requires_grad = False
55 |
56 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1))
57 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1))
58 |
59 | self.chns = [3, 64, 128, 256, 512, 512]
60 | self.register_parameter("alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
61 | self.register_parameter("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
62 | self.alpha.data.normal_(0.1, 0.01)
63 | self.beta.data.normal_(0.1, 0.01)
64 | if load_weights:
65 | weights = torch.load(os.path.join(sys.prefix, '/media/D/wangjun/DEEPJSCC/DISTS/weights.pt'))
66 | self.alpha.data = weights['alpha']
67 | self.beta.data = weights['beta']
68 |
69 | def forward_once(self, x):
70 | h = (x - self.mean) / self.std
71 | h = self.stage1(h)
72 | h_relu1_2 = h
73 | h = self.stage2(h)
74 | h_relu2_2 = h
75 | h = self.stage3(h)
76 | h_relu3_3 = h
77 | h = self.stage4(h)
78 | h_relu4_3 = h
79 | h = self.stage5(h)
80 | h_relu5_3 = h
81 | return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
82 |
83 | def forward(self, x, y, require_grad=False, batch_average=False):
84 | if require_grad:
85 | feats0 = self.forward_once(x)
86 | feats1 = self.forward_once(y)
87 | else:
88 | with torch.no_grad():
89 | feats0 = self.forward_once(x)
90 | feats1 = self.forward_once(y)
91 | dist1 = 0
92 | dist2 = 0
93 | c1 = 1e-6
94 | c2 = 1e-6
95 | w_sum = self.alpha.sum() + self.beta.sum()
96 | alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
97 | beta = torch.split(self.beta / w_sum, self.chns, dim=1)
98 | for k in range(len(self.chns)):
99 | x_mean = feats0[k].mean([2, 3], keepdim=True)
100 | y_mean = feats1[k].mean([2, 3], keepdim=True)
101 | S1 = (2 * x_mean * y_mean + c1) / (x_mean ** 2 + y_mean ** 2 + c1)
102 | dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)
103 |
104 | x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
105 | y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
106 | xy_cov = (feats0[k] * feats1[k]).mean([2, 3], keepdim=True) - x_mean * y_mean
107 | S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
108 | dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)
109 |
110 | score = 1 - (dist1 + dist2).squeeze()
111 | if batch_average:
112 | return score.mean()
113 | else:
114 | return score
115 |
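# Note on DISTS.forward (added for clarity): S1 compares channel-wise means (the texture term)
# and S2 compares variances/covariance (the structure term) across the VGG stages; the returned
# score is 1 - (dist1 + dist2), so lower values indicate higher perceptual similarity.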
116 |
117 | def prepare_image(image, resize=True):
118 | if resize and min(image.size) > 256:
119 | image = transforms.functional.resize(image, 256)
120 | image = transforms.ToTensor()(image)
121 | return image.unsqueeze(0)
122 |
123 |
124 | if __name__ == '__main__':
125 | from PIL import Image
126 | import argparse
127 |
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--ref', type=str, default='/media/D/wangjun/kodak1/kodim01.png')
130 | parser.add_argument('--dist', type=str, default='/media/D/wangjun/kodak1/test.png')
131 | args = parser.parse_args()
132 |
133 | ref = prepare_image(Image.open(args.ref).convert("RGB"))
134 | dist = prepare_image(Image.open(args.dist).convert("RGB"))
135 | assert ref.shape == dist.shape
136 |
137 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
138 | model = DISTS().to(device)
139 | ref = ref.to(device)
140 | dist = dist.to(device)
141 | score = model(ref, dist)
142 | print(score.item())
--------------------------------------------------------------------------------
/_pdjscc/net/network.py:
--------------------------------------------------------------------------------
1 | from .encoder import *
2 | from .decoder import *
3 | from .discriminator import *
4 | from CommonModules.loss.distortion import Distortion
5 | from .channel import Channel
6 | from random import choice
7 | from CommonModules.loss.distortion import MS_SSIM
8 | from CommonModules.perceptual_similarity.perceptual_loss import PerceptualLoss
9 | from CommonModules.loss import gan_loss
10 | from collections import namedtuple
11 | from functools import partial
12 |
13 |
14 | def pad_factor(input_image, spatial_dims, factor):
15 | """Pad `input_image` (N,C,H,W) such that H and W are divisible by `factor`."""
16 |
17 |     if isinstance(factor, int):
18 | factor_H = factor
19 | factor_W = factor_H
20 | else:
21 | factor_H, factor_W = factor
22 |
23 | H, W = spatial_dims[0], spatial_dims[1]
24 | pad_H = (factor_H - (H % factor_H)) % factor_H
25 | pad_W = (factor_W - (W % factor_W)) % factor_W
26 | return F.pad(input_image, pad=(0, pad_W, 0, pad_H), mode='reflect')
27 |
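# Example (added for clarity): with four downsampling stages (factor = 2**4 = 16), a
# 250x250 input is reflection-padded on the bottom/right to 256x256:
#   x = torch.randn(1, 3, 250, 250)
#   pad_factor(x, x.size()[2:], 16).shape == (1, 3, 256, 256)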
28 |
29 | class ADJSCC(nn.Module):
30 | def __init__(self, config):
31 | super(ADJSCC, self).__init__()
32 | if config.logger:
33 | config.logger.info("【Network】: Built Distributed JSCC model, C={}, k/n={}".format(config.C, config.kdivn))
34 |
35 | self.config = config
36 | self.Encoder = Encoder(config)
37 | self.Decoder = Decoder(config)
38 | if config.use_discriminator:
39 | self.Discriminator = Discriminator(image_dims = config.image_dims, C=config.C)
40 | self.use_discriminator = config.use_discriminator
41 | self.channel = Channel(config)
42 | self.pass_channel = config.pass_channel
43 | self.MS_SSIM = MS_SSIM(data_range=1., levels=4, channel=3).cuda()
44 | self._lpips = PerceptualLoss(model='net-lin', net='alex', use_gpu=torch.cuda.is_available(), gpu_ids=[torch.device("cuda:0")])
45 | self.gan_loss = partial(gan_loss.gan_loss, config.gan_loss_type)
46 | self.distortion_loss = Distortion(config)
47 | def feature_pass_channel(self, feature):
48 | noisy_feature = self.channel(feature)
49 | return noisy_feature
50 | def discriminator_forward(self, reconstruction, input_image, latents_quantized, train_generator):
51 | """ Train on gen/real batches simultaneously. """
52 | x_gen = reconstruction
53 | x_real = input_image
54 | Disc_out = namedtuple("disc_out",
55 | ["D_real", "D_gen", "D_real_logits", "D_gen_logits"])
56 |
57 | # Alternate between training discriminator and compression models
58 | if train_generator is False:
59 | x_gen = x_gen.detach()
60 |
61 | D_in = torch.cat([x_real, x_gen], dim=0)
62 |
63 | latents = latents_quantized.detach()
64 | latents = torch.repeat_interleave(latents, 2, dim=0)
65 |
66 | D_out, D_out_logits = self.Discriminator(D_in, latents)
67 | D_out = torch.squeeze(D_out)
68 | D_out_logits = torch.squeeze(D_out_logits)
69 |
70 | D_real, D_gen = torch.chunk(D_out, 2, dim=0)
71 | D_real_logits, D_gen_logits = torch.chunk(D_out_logits, 2, dim=0)
72 |
73 | return Disc_out(D_real, D_gen, D_real_logits, D_gen_logits)
74 |
75 | def GAN_loss(self,reconstruction, input_image,latents_quantized,train_generator=False):
76 | """
77 | train_generator: Flag to send gradients to generator
78 | """
79 | disc_out = self.discriminator_forward(reconstruction, input_image,latents_quantized,train_generator)
80 | D_loss = self.gan_loss(disc_out, mode='discriminator_loss')
81 | G_loss = self.gan_loss(disc_out, mode='generator_loss')
82 | D_gen = torch.mean(disc_out.D_gen).item()
83 | D_real = torch.mean(disc_out.D_real).item()
84 |
85 |
86 | return D_loss, G_loss,D_gen,D_real
87 |
88 |
89 |
90 | def forward(self, input_sequence,train_generator = True,given_SNR=None):
91 | B, C, H, W = input_sequence.shape
92 | if self.training == False:
93 | n_encoder_downsamples = self.Encoder.n_downsampling_layers
94 | factor = 2 ** n_encoder_downsamples
95 | x = pad_factor(input_sequence, input_sequence.size()[2:], factor)
96 | else:
97 | x = input_sequence
98 |
99 | if given_SNR is not None:
100 | self.channel.chan_param = given_SNR
101 | else:
102 | random_SNR = choice(self.config.multiple_snr)
103 | self.channel.chan_param = random_SNR
104 |
105 | SNR = torch.ones([B, 1]).to(x.device) * self.channel.chan_param
106 | feature = self.Encoder(x, SNR)
107 | if self.pass_channel:
108 | noisy_feature = self.feature_pass_channel(feature)
109 | else:
110 | noisy_feature = feature
111 | x_hat = self.Decoder(noisy_feature, SNR)
112 | if self.training == False:
113 | x_hat = x_hat[:, :, :H, :W]
114 | mse_loss = self.distortion_loss(input_sequence, x_hat)
115 | lpips_loss = self._lpips(input_sequence, x_hat, normalize=True).mean()
116 | ms_ssim_loss = self.MS_SSIM(input_sequence, x_hat).mean()
117 | return ms_ssim_loss,mse_loss, lpips_loss, x_hat
118 | else:
119 | mse_loss = self.distortion_loss(input_sequence, x_hat)
120 | lpips_loss = self._lpips(input_sequence,x_hat,normalize=True).mean()
121 | ms_ssim_loss = self.MS_SSIM(input_sequence, x_hat).mean()
122 | if self.use_discriminator:
123 | D_loss, G_loss,D_gen,D_real = self.GAN_loss(x_hat,x,feature,train_generator)
124 | return ms_ssim_loss,mse_loss,lpips_loss,x_hat,D_loss, G_loss,D_gen,D_real
125 | else:
126 | return ms_ssim_loss, mse_loss, lpips_loss, x_hat
127 |
128 |
129 | if __name__ == '__main__':
130 | import torch
131 | import torch.nn.functional as F
132 | from ADJSCC.config import config
133 |
134 | input_Tensor = torch.ones([2, 3, 256, 256]).cuda()
135 | model = ADJSCC(config).cuda()
136 |     ms_ssim_loss, mse_loss, lpips_loss, recon_image = model(input_Tensor)[:4]
137 | print(recon_image.shape)
138 |
--------------------------------------------------------------------------------
/_pdjscc/loss_utils/perceptual_similarity/dists_loss/DISTS/DISTS_pt.py:
--------------------------------------------------------------------------------
1 | # This is a pytoch implementation of DISTS metric.
2 | # Requirements: python >= 3.6, pytorch >= 1.0
3 |
4 | import numpy as np
5 | import os, sys
6 | import torch
7 | from torchvision import models, transforms
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import os
11 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
12 |
13 | class L2pooling(nn.Module):
14 | def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0):
15 | super(L2pooling, self).__init__()
16 | self.padding = (filter_size - 2) // 2
17 | self.stride = stride
18 | self.channels = channels
19 | a = np.hanning(filter_size)[1:-1]
20 | g = torch.Tensor(a[:, None] * a[None, :])
21 | g = g / torch.sum(g)
22 | self.register_buffer('filter', g[None, None, :, :].repeat((self.channels, 1, 1, 1)))
23 |
24 | def forward(self, input):
25 | input = input ** 2
26 | out = F.conv2d(input, self.filter, stride=self.stride, padding=self.padding, groups=input.shape[1])
27 | return (out + 1e-12).sqrt()
28 |
29 |
30 | class DISTS(torch.nn.Module):
31 | def __init__(self, load_weights=True):
32 | super(DISTS, self).__init__()
33 | vgg_pretrained_features = models.vgg16(pretrained=True).features
34 | self.stage1 = torch.nn.Sequential()
35 | self.stage2 = torch.nn.Sequential()
36 | self.stage3 = torch.nn.Sequential()
37 | self.stage4 = torch.nn.Sequential()
38 | self.stage5 = torch.nn.Sequential()
39 | for x in range(0, 4):
40 | self.stage1.add_module(str(x), vgg_pretrained_features[x])
41 | self.stage2.add_module(str(4), L2pooling(channels=64))
42 | for x in range(5, 9):
43 | self.stage2.add_module(str(x), vgg_pretrained_features[x])
44 | self.stage3.add_module(str(9), L2pooling(channels=128))
45 | for x in range(10, 16):
46 | self.stage3.add_module(str(x), vgg_pretrained_features[x])
47 | self.stage4.add_module(str(16), L2pooling(channels=256))
48 | for x in range(17, 23):
49 | self.stage4.add_module(str(x), vgg_pretrained_features[x])
50 | self.stage5.add_module(str(23), L2pooling(channels=512))
51 | for x in range(24, 30):
52 | self.stage5.add_module(str(x), vgg_pretrained_features[x])
53 |
54 | for param in self.parameters():
55 | param.requires_grad = False
56 |
57 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1))
58 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1))
59 |
60 | self.chns = [3, 64, 128, 256, 512, 512]
61 | self.register_parameter("alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
62 | self.register_parameter("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
63 | self.alpha.data.normal_(0.1, 0.01)
64 | self.beta.data.normal_(0.1, 0.01)
65 | if load_weights:
66 | weights = torch.load(os.path.join(sys.prefix,'/media/D/wangjun/DEEPJSCC/DISTS/weights.pt'))
67 | self.alpha.data = weights['alpha']
68 | self.beta.data = weights['beta']
69 |
70 | def forward_once(self, x):
71 | h = (x - self.mean) / self.std
72 | h = self.stage1(h)
73 | h_relu1_2 = h
74 | h = self.stage2(h)
75 | h_relu2_2 = h
76 | h = self.stage3(h)
77 | h_relu3_3 = h
78 | h = self.stage4(h)
79 | h_relu4_3 = h
80 | h = self.stage5(h)
81 | h_relu5_3 = h
82 | return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]
83 |
84 | def forward(self, x, y, require_grad=False, batch_average=False):
85 | if require_grad:
86 | feats0 = self.forward_once(x)
87 | feats1 = self.forward_once(y)
88 | else:
89 | with torch.no_grad():
90 | feats0 = self.forward_once(x)
91 | feats1 = self.forward_once(y)
92 | dist1 = 0
93 | dist2 = 0
94 | c1 = 1e-6
95 | c2 = 1e-6
96 | w_sum = self.alpha.sum() + self.beta.sum()
97 | alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
98 | beta = torch.split(self.beta / w_sum, self.chns, dim=1)
99 | for k in range(len(self.chns)):
100 | x_mean = feats0[k].mean([2, 3], keepdim=True)
101 | y_mean = feats1[k].mean([2, 3], keepdim=True)
102 | S1 = (2 * x_mean * y_mean + c1) / (x_mean ** 2 + y_mean ** 2 + c1)
103 | dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True)
104 |
105 | x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
106 | y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
107 | xy_cov = (feats0[k] * feats1[k]).mean([2, 3], keepdim=True) - x_mean * y_mean
108 | S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
109 | dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)
110 |
111 | score = 1 - (dist1 + dist2).squeeze()
112 | if batch_average:
113 | return score.mean()
114 | else:
115 | return score
116 |
117 |
118 | def prepare_image(image, resize=True):
119 | if resize and min(image.size) > 256:
120 | image = transforms.functional.resize(image, 256)
121 | image = transforms.ToTensor()(image)
122 | return image.unsqueeze(0)
123 |
124 |
125 | if __name__ == '__main__':
126 | from PIL import Image
127 | import argparse
128 |
129 | parser = argparse.ArgumentParser()
130 | parser.add_argument('--ref', type=str, default='/media/D/wangjun/kodak1/kodim01.png')
131 | parser.add_argument('--dist', type=str, default='/media/D/wangjun/kodak1/test.png')
132 | args = parser.parse_args()
133 |
134 | ref = prepare_image(Image.open(args.ref).convert("RGB"))
135 | dist = prepare_image(Image.open(args.dist).convert("RGB"))
136 | assert ref.shape == dist.shape
137 |
138 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
139 | model = DISTS().to(device)
140 | ref = ref.to(device)
141 | dist = dist.to(device)
142 | score = model(ref, dist)
143 | print(score.item())
144 | # score: 0.3347
--------------------------------------------------------------------------------
/_pdjscc/loss_utils/perceptual_similarity/perceptual_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | # from skimage.measure import compare_ssim
7 | import torch
8 |
9 | from . import dist_model
10 |
11 |
12 | class PerceptualLoss(torch.nn.Module):
13 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0],
14 | version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric)
15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
16 | super(PerceptualLoss, self).__init__()
17 | print('Setting up Perceptual loss...')
18 | self.use_gpu = use_gpu
19 | self.spatial = spatial
20 | self.gpu_ids = gpu_ids
21 | self.model = dist_model.DistModel()
22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial,
23 | gpu_ids=gpu_ids, version=version)
24 | print('...[%s] initialized' % self.model.name())
25 | print('...Done')
26 |
27 | def forward(self, pred, target, normalize=False):
28 | """
29 | Pred and target are Variables.
30 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
31 | If normalize is False, assumes the images are already between [-1,+1]
32 |
33 | Inputs pred and target are Nx3xHxW
34 | Output pytorch Variable N long
35 | """
36 |
37 | if normalize:
38 | target = 2 * target - 1
39 | pred = 2 * pred - 1
40 |
41 | return self.model.forward(target, pred)
42 |
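# Typical call (a sketch; assumes the LPIPS weights expected by dist_model are available):
#   lpips = PerceptualLoss(model='net-lin', net='alex', use_gpu=torch.cuda.is_available())
#   d = lpips(pred, target, normalize=True)   # pred/target are Nx3xHxW tensors in [0, 1]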
43 |
44 | def normalize_tensor(in_feat, eps=1e-10):
45 | l2_norm = torch.sum(in_feat ** 2, dim=1, keepdim=True)
46 | norm_factor = torch.sqrt(l2_norm + eps)
47 | # return in_feat/(norm_factor+eps)
48 | return in_feat / (norm_factor)
49 |
50 |
51 | def l2(p0, p1, range=255.):
52 | return .5 * np.mean((p0 / range - p1 / range) ** 2)
53 |
54 |
55 | def psnr(p0, p1, peak=255.):
56 | return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2))
57 |
58 |
59 | # def dssim(p0, p1, range=255.):
60 | # return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
61 |
62 | def rgb2lab(in_img, mean_cent=False):
63 | from skimage import color
64 | img_lab = color.rgb2lab(in_img)
65 | if (mean_cent):
66 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50
67 | return img_lab
68 |
69 |
70 | def tensor2np(tensor_obj):
71 | # change dimension of a tensor object into a numpy array
72 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
73 |
74 |
75 | def np2tensor(np_obj):
76 | # change dimenion of np array into tensor array
77 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
78 |
79 |
80 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
81 | # image tensor to lab tensor
82 | from skimage import color
83 |
84 | img = tensor2im(image_tensor)
85 | img_lab = color.rgb2lab(img)
86 | if (mc_only):
87 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50
88 | if (to_norm and not mc_only):
89 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50
90 | img_lab = img_lab / 100.
91 |
92 | return np2tensor(img_lab)
93 |
94 |
95 | def tensorlab2tensor(lab_tensor, return_inbnd=False):
96 | from skimage import color
97 | import warnings
98 | warnings.filterwarnings("ignore")
99 |
100 | lab = tensor2np(lab_tensor) * 100.
101 | lab[:, :, 0] = lab[:, :, 0] + 50
102 |
103 | rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
104 | if (return_inbnd):
105 | # convert back to lab, see if we match
106 | lab_back = color.rgb2lab(rgb_back.astype('uint8'))
107 | mask = 1. * np.isclose(lab_back, lab, atol=2.)
108 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
109 | return (im2tensor(rgb_back), mask)
110 | else:
111 | return im2tensor(rgb_back)
112 |
113 |
114 | def rgb2lab(input):
115 | from skimage import color
116 | return color.rgb2lab(input / 255.)
117 |
118 |
119 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
120 | image_numpy = image_tensor[0].cpu().float().numpy()
121 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
122 | return image_numpy.astype(imtype)
123 |
124 |
125 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
126 | return torch.Tensor((image / factor - cent)
127 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
128 |
129 |
130 | def tensor2vec(vector_tensor):
131 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
132 |
133 |
134 | def voc_ap(rec, prec, use_07_metric=False):
135 | """ ap = voc_ap(rec, prec, [use_07_metric])
136 | Compute VOC AP given precision and recall.
137 | If use_07_metric is true, uses the
138 | VOC 07 11 point method (default:False).
139 | """
140 | if use_07_metric:
141 | # 11 point metric
142 | ap = 0.
143 | for t in np.arange(0., 1.1, 0.1):
144 | if np.sum(rec >= t) == 0:
145 | p = 0
146 | else:
147 | p = np.max(prec[rec >= t])
148 | ap = ap + p / 11.
149 | else:
150 | # correct AP calculation
151 | # first append sentinel values at the end
152 | mrec = np.concatenate(([0.], rec, [1.]))
153 | mpre = np.concatenate(([0.], prec, [0.]))
154 |
155 | # compute the precision envelope
156 | for i in range(mpre.size - 1, 0, -1):
157 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
158 |
159 | # to calculate area under PR curve, look for points
160 | # where X axis (recall) changes value
161 | i = np.where(mrec[1:] != mrec[:-1])[0]
162 |
163 | # and sum (\Delta recall) * prec
164 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
165 | return ap
166 |
167 |
--------------------------------------------------------------------------------
/_pdjscc/loss_utils/perceptual_similarity/pretrained_networks.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torchvision import models as tv
4 |
5 | class squeezenet(torch.nn.Module):
6 | def __init__(self, requires_grad=False, pretrained=True):
7 | super(squeezenet, self).__init__()
8 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
9 | self.slice1 = torch.nn.Sequential()
10 | self.slice2 = torch.nn.Sequential()
11 | self.slice3 = torch.nn.Sequential()
12 | self.slice4 = torch.nn.Sequential()
13 | self.slice5 = torch.nn.Sequential()
14 | self.slice6 = torch.nn.Sequential()
15 | self.slice7 = torch.nn.Sequential()
16 | self.N_slices = 7
17 | for x in range(2):
18 | self.slice1.add_module(str(x), pretrained_features[x])
19 | for x in range(2,5):
20 | self.slice2.add_module(str(x), pretrained_features[x])
21 | for x in range(5, 8):
22 | self.slice3.add_module(str(x), pretrained_features[x])
23 | for x in range(8, 10):
24 | self.slice4.add_module(str(x), pretrained_features[x])
25 | for x in range(10, 11):
26 | self.slice5.add_module(str(x), pretrained_features[x])
27 | for x in range(11, 12):
28 | self.slice6.add_module(str(x), pretrained_features[x])
29 | for x in range(12, 13):
30 | self.slice7.add_module(str(x), pretrained_features[x])
31 | if not requires_grad:
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 | def forward(self, X):
36 | h = self.slice1(X)
37 | h_relu1 = h
38 | h = self.slice2(h)
39 | h_relu2 = h
40 | h = self.slice3(h)
41 | h_relu3 = h
42 | h = self.slice4(h)
43 | h_relu4 = h
44 | h = self.slice5(h)
45 | h_relu5 = h
46 | h = self.slice6(h)
47 | h_relu6 = h
48 | h = self.slice7(h)
49 | h_relu7 = h
50 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
51 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
52 |
53 | return out
54 |
55 |
56 | class alexnet(torch.nn.Module):
57 | def __init__(self, requires_grad=False, pretrained=True):
58 | super(alexnet, self).__init__()
59 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
60 | self.slice1 = torch.nn.Sequential()
61 | self.slice2 = torch.nn.Sequential()
62 | self.slice3 = torch.nn.Sequential()
63 | self.slice4 = torch.nn.Sequential()
64 | self.slice5 = torch.nn.Sequential()
65 | self.N_slices = 5
66 | for x in range(2):
67 | self.slice1.add_module(str(x), alexnet_pretrained_features[x])
68 | for x in range(2, 5):
69 | self.slice2.add_module(str(x), alexnet_pretrained_features[x])
70 | for x in range(5, 8):
71 | self.slice3.add_module(str(x), alexnet_pretrained_features[x])
72 | for x in range(8, 10):
73 | self.slice4.add_module(str(x), alexnet_pretrained_features[x])
74 | for x in range(10, 12):
75 | self.slice5.add_module(str(x), alexnet_pretrained_features[x])
76 | if not requires_grad:
77 | for param in self.parameters():
78 | param.requires_grad = False
79 |
80 | def forward(self, X):
81 | h = self.slice1(X)
82 | h_relu1 = h
83 | h = self.slice2(h)
84 | h_relu2 = h
85 | h = self.slice3(h)
86 | h_relu3 = h
87 | h = self.slice4(h)
88 | h_relu4 = h
89 | h = self.slice5(h)
90 | h_relu5 = h
91 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
92 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
93 |
94 | return out
95 |
96 | class vgg16(torch.nn.Module):
97 | def __init__(self, requires_grad=False, pretrained=True):
98 | super(vgg16, self).__init__()
99 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
100 | self.slice1 = torch.nn.Sequential()
101 | self.slice2 = torch.nn.Sequential()
102 | self.slice3 = torch.nn.Sequential()
103 | self.slice4 = torch.nn.Sequential()
104 | self.slice5 = torch.nn.Sequential()
105 | self.N_slices = 5
106 | for x in range(4):
107 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
108 | for x in range(4, 9):
109 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
110 | for x in range(9, 16):
111 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
112 | for x in range(16, 23):
113 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
114 | for x in range(23, 30):
115 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
116 | if not requires_grad:
117 | for param in self.parameters():
118 | param.requires_grad = False
119 |
120 | def forward(self, X):
121 | h = self.slice1(X)
122 | h_relu1_2 = h
123 | h = self.slice2(h)
124 | h_relu2_2 = h
125 | h = self.slice3(h)
126 | h_relu3_3 = h
127 | h = self.slice4(h)
128 | h_relu4_3 = h
129 | h = self.slice5(h)
130 | h_relu5_3 = h
131 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
132 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
133 |
134 | return out
135 |
136 |
137 |
138 | class resnet(torch.nn.Module):
139 | def __init__(self, requires_grad=False, pretrained=True, num=18):
140 | super(resnet, self).__init__()
141 | if(num==18):
142 | self.net = tv.resnet18(pretrained=pretrained)
143 | elif(num==34):
144 | self.net = tv.resnet34(pretrained=pretrained)
145 | elif(num==50):
146 | self.net = tv.resnet50(pretrained=pretrained)
147 | elif(num==101):
148 | self.net = tv.resnet101(pretrained=pretrained)
149 | elif(num==152):
150 | self.net = tv.resnet152(pretrained=pretrained)
151 | self.N_slices = 5
152 |
153 | self.conv1 = self.net.conv1
154 | self.bn1 = self.net.bn1
155 | self.relu = self.net.relu
156 | self.maxpool = self.net.maxpool
157 | self.layer1 = self.net.layer1
158 | self.layer2 = self.net.layer2
159 | self.layer3 = self.net.layer3
160 | self.layer4 = self.net.layer4
161 |
162 | def forward(self, X):
163 | h = self.conv1(X)
164 | h = self.bn1(h)
165 | h = self.relu(h)
166 | h_relu1 = h
167 | h = self.maxpool(h)
168 | h = self.layer1(h)
169 | h_conv2 = h
170 | h = self.layer2(h)
171 | h_conv3 = h
172 | h = self.layer3(h)
173 | h_conv4 = h
174 | h = self.layer4(h)
175 | h_conv5 = h
176 |
177 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
178 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
179 |
180 | return out
181 |
--------------------------------------------------------------------------------
/_ntsccp/net/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc
2 | # All rights reserved.
3 |
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted (subject to the limitations in the disclaimer
6 | # below) provided that the following conditions are met:
7 |
8 | # * Redistributions of source code must retain the above copyright notice,
9 | # this list of conditions and the following disclaimer.
10 | # * Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | # * Neither the name of InterDigital Communications, Inc nor the names of its
14 | # contributors may be used to endorse or promote products derived from this
15 | # software without specific prior written permission.
16 |
17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
30 | import torch
31 | import torch.nn as nn
32 | from torch.autograd import Function
33 |
34 |
35 | def find_named_module(module, query):
36 | """Helper function to find a named module. Returns a `nn.Module` or `None`
37 |
38 | Args:
39 | module (nn.Module): the root module
40 | query (str): the module name to find
41 |
42 | Returns:
43 | nn.Module or None
44 | """
45 |
46 | return next((m for n, m in module.named_modules() if n == query), None)
47 |
48 |
49 | def find_named_buffer(module, query):
50 | """Helper function to find a named buffer. Returns a `torch.Tensor` or `None`
51 |
52 | Args:
53 | module (nn.Module): the root module
54 | query (str): the buffer name to find
55 |
56 | Returns:
57 | torch.Tensor or None
58 | """
59 | return next((b for n, b in module.named_buffers() if n == query), None)
60 |
61 |
62 | def _update_registered_buffer(
63 | module,
64 | buffer_name,
65 | state_dict_key,
66 | state_dict,
67 | policy="resize_if_empty",
68 | dtype=torch.int,
69 | ):
70 | new_size = state_dict[state_dict_key].size()
71 | registered_buf = find_named_buffer(module, buffer_name)
72 |
73 | if policy in ("resize_if_empty", "resize"):
74 | if registered_buf is None:
75 | raise RuntimeError(f'buffer "{buffer_name}" was not registered')
76 |
77 | if policy == "resize" or registered_buf.numel() == 0:
78 | registered_buf.resize_(new_size)
79 |
80 | elif policy == "register":
81 | if registered_buf is not None:
82 | raise RuntimeError(f'buffer "{buffer_name}" was already registered')
83 |
84 | module.register_buffer(buffer_name, torch.empty(new_size, dtype=dtype).fill_(0))
85 |
86 | else:
87 | raise ValueError(f'Invalid policy "{policy}"')
88 |
89 |
90 | def update_registered_buffers(
91 | module,
92 | module_name,
93 | buffer_names,
94 | state_dict,
95 | policy="resize_if_empty",
96 | dtype=torch.int,
97 | ):
98 | """Update the registered buffers in a module according to the tensors sized
99 | in a state_dict.
100 |
101 | (There's no way in torch to directly load a buffer with a dynamic size)
102 |
103 | Args:
104 | module (nn.Module): the module
105 | module_name (str): module name in the state dict
106 | buffer_names (list(str)): list of the buffer names to resize in the module
107 | state_dict (dict): the state dict
108 | policy (str): Update policy, choose from
109 | ('resize_if_empty', 'resize', 'register')
110 | dtype (dtype): Type of buffer to be registered (when policy is 'register')
111 | """
112 | valid_buffer_names = [n for n, _ in module.named_buffers()]
113 | for buffer_name in buffer_names:
114 | if buffer_name not in valid_buffer_names:
115 | raise ValueError(f'Invalid buffer name "{buffer_name}"')
116 |
117 | for buffer_name in buffer_names:
118 | _update_registered_buffer(
119 | module,
120 | buffer_name,
121 | f"{module_name}.{buffer_name}",
122 | state_dict,
123 | policy,
124 | dtype,
125 | )
126 |
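# Typical use (a sketch; the buffer names follow the CompressAI convention and may differ here):
#   update_registered_buffers(
#       net.entropy_bottleneck, "entropy_bottleneck",
#       ["_quantized_cdf", "_offset", "_cdf_length"], state_dict)
#   net.load_state_dict(state_dict)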
127 |
128 | def conv(in_channels, out_channels, kernel_size=5, stride=2):
129 | return nn.Conv2d(
130 | in_channels,
131 | out_channels,
132 | kernel_size=kernel_size,
133 | stride=stride,
134 | padding=kernel_size // 2,
135 | )
136 |
137 |
138 | def deconv(in_channels, out_channels, kernel_size=5, stride=2):
139 | return nn.ConvTranspose2d(
140 | in_channels,
141 | out_channels,
142 | kernel_size=kernel_size,
143 | stride=stride,
144 | output_padding=stride - 1,
145 | padding=kernel_size // 2,
146 | )
147 |
148 |
149 | def quantize_ste(x):
150 | """Differentiable quantization via the Straight-Through-Estimator."""
151 | # STE (straight-through estimator) trick: x_hard - x_soft.detach() + x_soft
152 | return (torch.round(x) - x).detach() + x
153 |
154 |
155 | def DEMUX(x):
156 | B, C, H, W = x.shape
157 | x_part1 = torch.ones_like(x)[:, :, :H // 2, :W]
158 | x_part1[:, :, :, 0::2] = x[:, :, 0::2, 0::2]
159 | x_part1[:, :, :, 1::2] = x[:, :, 1::2, 1::2]
160 |
161 | x_part2 = torch.ones_like(x)[:, :, :H // 2, :W]
162 | x_part2[:, :, :, 0::2] = x[:, :, 1::2, 0::2]
163 | x_part2[:, :, :, 1::2] = x[:, :, 0::2, 1::2]
164 | return x_part1, x_part2
165 |
166 |
167 | def MUX(x_part1, x_part2):
168 | # B, C, H_half, W = x_anchor.shape
169 | # H = H_half * 2
170 | x = torch.cat([torch.ones_like(x_part1), torch.ones_like(x_part1)], dim=2)
171 | x[:, :, 0::2, 0::2] = x_part1[:, :, :, 0::2]
172 | x[:, :, 1::2, 1::2] = x_part1[:, :, :, 1::2]
173 | x[:, :, 1::2, 0::2] = x_part2[:, :, :, 0::2]
174 | x[:, :, 0::2, 1::2] = x_part2[:, :, :, 1::2]
175 | return x
176 |
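# Round-trip check (added for clarity): DEMUX splits a feature map into two
# checkerboard-interleaved halves and MUX is its inverse, so for even H and W:
#   x = torch.randn(1, 8, 16, 16)
#   a, b = DEMUX(x)              # each of shape (1, 8, 8, 16)
#   assert torch.equal(MUX(a, b), x)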
177 |
178 | # pylint: disable=W0221
179 | class LowerBound(Function):
180 | @staticmethod
181 | def forward(ctx, inputs, bound):
182 | b = torch.ones_like(inputs) * bound
183 | ctx.save_for_backward(inputs, b)
184 | return torch.max(inputs, b)
185 |
186 | @staticmethod
187 | def backward(ctx, grad_output):
188 | inputs, b = ctx.saved_tensors
189 | pass_through_1 = inputs >= b
190 | pass_through_2 = grad_output < 0
191 |
192 | pass_through = pass_through_1 | pass_through_2
193 | return pass_through.type(grad_output.dtype) * grad_output, None
194 |
--------------------------------------------------------------------------------
/conditioning_method/diffcom.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | from utils import utils_model
6 |
7 | __CONDITIONING_METHOD__ = {}
8 |
9 |
10 | def register_conditioning_method(name: str):
11 | def wrapper(cls):
12 | if __CONDITIONING_METHOD__.get(name, None):
13 | raise NameError(f"Name {name} is already registered!")
14 | __CONDITIONING_METHOD__[name] = cls
15 | return cls
16 |
17 | return wrapper
18 |
19 |
20 | def get_conditioning_method(name: str, **kwargs):
21 | if __CONDITIONING_METHOD__.get(name, None) is None:
22 | raise NameError(f"Name {name} is not defined!")
23 | return __CONDITIONING_METHOD__[name](**kwargs)
24 |
25 |
26 | class ConsistencyLoss(nn.Module):
27 | def __init__(self, config, device):
28 | super().__init__()
29 | self.config = config
30 | zeta = config.diffcom_series[config.conditioning_method]['zeta']
31 | gamma = config.diffcom_series[config.conditioning_method]['gamma']
32 | self.weight = {
33 | 'x_mse': gamma,
34 | 'ofdm_sig': zeta,
35 | }
36 |
37 | def forward(self, measurement, x_0_hat, cof, operator, operation_mode):
38 | x_0_hat = (x_0_hat / 2 + 0.5) # .clip(0, 1)
39 | s = operator.encode(x_0_hat)
40 | if operation_mode == 'latent':
41 | recon_measurement = {
42 | 'ofdm_sig': operator.forward(s, cof)
43 | }
44 | elif operation_mode == 'pixel':
45 | recon_measurement = {
46 | 'x_mse': x_0_hat
47 | }
48 | elif operation_mode == 'joint':
49 | ofdm_sig = operator.forward(s, cof)
50 | s_hat = operator.transpose(ofdm_sig, cof)
51 | x_confirming = operator.decode(s_hat)
52 | recon_measurement = {
53 | 'ofdm_sig': ofdm_sig,
54 | 'x_mse': x_confirming
55 | }
56 | loss = {}
57 | for key in recon_measurement.keys():
58 | loss[key] = self.weight[key] * torch.linalg.norm(measurement[key] - recon_measurement[key])
59 | return loss
60 |
61 |
62 | def get_lr(config, t, T):
63 | lr_base = config['learning_rate']
64 | # exponential decay to 0
65 | if config['lr_schedule'] == 'exp':
66 | lr_min = config['lr_min']
67 | lr = lr_min + (lr_base - lr_min) * np.exp(-t / T)
68 | # linear decay
69 | elif config['lr_schedule'] == 'linear':
70 | lr_min = config['lr_min']
71 | lr = lr_min + (lr_base - lr_min) * (t / T)
72 | # constant
73 | else:
74 | lr = lr_base
75 | return lr
76 |
77 |
78 | @register_conditioning_method(name='diffcom')
79 | class DiffCom(nn.Module):
80 | def __init__(self):
81 | super().__init__()
82 | self.conditioning_method = 'latent'
83 |
84 | def conditioning(self, config, i, ns, x_t, h_t, power,
85 | measurement, unet, diffusion, operator, loss_wrapper, last_timestep):
86 | h_0_hat = h_t
87 | h_t_minus_1_prime = h_t
88 | h_t_minus_1 = h_t
89 |
90 | t_step = ns.seq[i]
91 | sigma_t = ns.reduced_alpha_cumprod[t_step].cpu().numpy()
92 | x_t = x_t.requires_grad_()
93 | x_t_minus_1_prime, x_0_hat, _ = utils_model.model_fn(x_t,
94 | noise_level=sigma_t * 255,
95 | model_out_type='pred_x_prev_and_start', \
96 | model_diffusion=unet,
97 | diffusion=diffusion,
98 | ddim_sample=config.ddim_sample)
99 | if last_timestep:
100 | loss = loss_wrapper.forward(measurement, x_0_hat, h_0_hat, operator, self.conditioning_method)
101 | return x_0_hat, h_0_hat, x_t_minus_1_prime, h_t_minus_1_prime, loss
102 | else:
103 | loss = loss_wrapper.forward(measurement, x_0_hat, h_t, operator, self.conditioning_method)
104 | total_loss = sum(loss.values())
105 | x_grad = torch.autograd.grad(outputs=total_loss, inputs=x_t)[0]
106 | learning_rate = get_lr(config.diffcom_series[config.conditioning_method], t_step,
107 | ns.t_start - 1)
108 | x_t_minus_1 = x_t_minus_1_prime - x_grad * learning_rate
109 | x_t_minus_1 = x_t_minus_1.detach_()
110 | return x_0_hat, h_0_hat, x_t_minus_1, h_t_minus_1, loss
111 |
112 |
113 | @register_conditioning_method(name='hifi_diffcom')
114 | class HiFiDiffCom(DiffCom):
115 | def __init__(self):
116 | super().__init__()
117 | self.conditioning_method = 'joint'
118 |
119 |
120 | @register_conditioning_method(name='blind_diffcom')
121 | class BlindDiffCom(DiffCom):
122 | def __init__(self):
123 | super().__init__()
124 |
125 | def conditioning(self, config, i, ns, x_t, h_t, power,
126 | measurement, unet, diffusion, operator, loss_wrapper, last_timestep):
127 | t_step = ns.seq[i]
128 | sigma_t = ns.reduced_alpha_cumprod[t_step].cpu().numpy()
129 | x_t = x_t.requires_grad_()
130 | x_t_minus_1_prime, x_0_hat, _ = utils_model.model_fn(x_t,
131 | noise_level=sigma_t * 255,
132 | model_out_type='pred_x_prev_and_start', \
133 | model_diffusion=unet,
134 | diffusion=diffusion,
135 | ddim_sample=config.ddim_sample)
136 |
137 | assert (config.conditioning_method == 'blind_diffcom')
138 |
139 | h_t = h_t.requires_grad_()
140 | h_score = - h_t / (power ** 2)
141 | h_0_hat = (1 / ns.alphas_cumprod[t_step]) * (
142 | h_t + ns.sqrt_1m_alphas_cumprod[t_step] * h_score)
143 | h_t_minus_1_prime = ns.posterior_mean_coef2[t_step] * h_t + ns.posterior_mean_coef1[t_step] * h_0_hat + \
144 | ns.posterior_variance[t_step] * (torch.randn_like(h_t) + 1j * torch.randn_like(h_t))
145 |
146 | if last_timestep:
147 | loss = loss_wrapper.forward(measurement, x_0_hat, h_0_hat, operator, self.conditioning_method)
148 | return x_0_hat, h_0_hat, x_t_minus_1_prime, h_t_minus_1_prime, loss
149 | else:
150 | loss = loss_wrapper.forward(measurement, x_0_hat, h_0_hat, operator, self.conditioning_method)
151 | total_loss = sum(loss.values())
152 | x_grad, h_t_grad = torch.autograd.grad(outputs=total_loss, inputs=[x_t, h_t])
153 | learning_rate = config.diffcom_series['blind_diffcom']['learning_rate']
154 | learning_rate = (learning_rate - 0) * (t_step / (ns.t_start - 1))
155 | x_t_minus_1 = x_t_minus_1_prime - x_grad * learning_rate
156 | x_t_minus_1 = x_t_minus_1.detach_()
157 | lr_h = config.diffcom_series['blind_diffcom']['h_lr']
158 | lr_h = (lr_h - 0) * (t_step / (ns.t_start - 1))
159 | h_t_minus_1 = h_t_minus_1_prime - h_t_grad * lr_h
160 | h_t_minus_1 = h_t_minus_1.detach_()
161 | return x_0_hat, h_0_hat, x_t_minus_1, h_t_minus_1, loss
162 |
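# How a sampling script is expected to drive these classes (a sketch; the config, noise
# schedule `ns`, operator, and measurement objects come from the caller):
#   cond_method = get_conditioning_method(config.conditioning_method)
#   loss_wrapper = ConsistencyLoss(config, device)
#   for i in range(len(ns.seq)):
#       last = (i == len(ns.seq) - 1)
#       x_0_hat, h_0_hat, x_t, h_t, loss = cond_method.conditioning(
#           config, i, ns, x_t, h_t, power, measurement,
#           unet, diffusion, operator, loss_wrapper, last_timestep=last)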
--------------------------------------------------------------------------------
/website/src/components/section3/Section3.js:
--------------------------------------------------------------------------------
1 | import React, {useState} from "react";
2 | import {Button, Grid, Stack, ToggleButton, ToggleButtonGroup} from '@mui/material';
3 | import ReactSwipe from 'react-swipe'
4 | import {ReactCompareSlider, ReactCompareSliderImage} from 'react-compare-slider';
5 | import {AiFillLeftCircle, AiFillRightCircle} from 'react-icons/ai'
6 |
7 | const CenterWrapper = (props) => {
8 | return (
9 |
10 |