├── E2E_framework_DDM_AWGN.ipynb
├── E2E_framework_DDM_Rayleigh.ipynb
├── E2E_framework_DDM_SSPA.ipynb
├── README.md
├── figures
│   ├── ECDF_histogram_AWGN.pdf
│   ├── ECDF_histogram_Rayleigh.pdf
│   ├── ECDF_histogram_SSPA.pdf
│   ├── EbN0_vs._testSER_sample_AWGN.pdf
│   ├── EbN0_vs._testSER_sample_Rayleigh.pdf
│   └── EbN0_vs._testSER_sample_SSPA.pdf
├── src
│   ├── .ipynb_checkpoints
│   │   ├── channel_models-checkpoint.py
│   │   ├── models-checkpoint.py
│   │   ├── trainer-checkpoint.py
│   │   ├── utils-checkpoint.py
│   │   └── utils_plot-checkpoint.py
│   ├── __pycache__
│   │   ├── channel_models.cpython-310.pyc
│   │   ├── channel_models.cpython-39.pyc
│   │   ├── ema.cpython-310.pyc
│   │   ├── ema.cpython-39.pyc
│   │   ├── models.cpython-310.pyc
│   │   ├── models.cpython-39.pyc
│   │   ├── trainer.cpython-310.pyc
│   │   ├── trainer.cpython-39.pyc
│   │   ├── utils.cpython-310.pyc
│   │   ├── utils.cpython-39.pyc
│   │   ├── utils_plot.cpython-310.pyc
│   │   └── utils_plot.cpython-39.pyc
│   ├── channel_models.py
│   ├── ema.py
│   ├── models.py
│   ├── trainer.py
│   ├── utils.py
│   └── utils_plot.py
└── state_dict
    ├── channel_gen_AWGN_M16N7_T100.pt
    ├── channel_gen_Rayleigh_M16N7_T100.pt
    ├── channel_gen_SSPA_M64N8_T100.pt
    ├── decoder_AWGN_M16N7_T100DDIM20.pt
    ├── decoder_Rayleigh_M16N7_DDPM_T100DDIM20.pt
    ├── decoder_SSPA_M64N8_T100DDPM1.pt
    ├── encoder_AWGN_M16N7_T100DDIM20.pt
    ├── encoder_Rayleigh_M16N7_DDPM_T100DDIM20.pt
    └── encoder_SSPA_M64N8_T100DDPM1.pt

/README.md:
--------------------------------------------------------------------------------
1 | # learning-E2E-channel-coding-with-diffusion-models
2 | This repository contains the simulations used in the paper [1]. It is open to the public, so feel free to play around with it or to use it for your research, study, education, etc.
3 | The repository provides notebook files implementing an end-to-end (E2E) learning framework for each of three channel models: the AWGN channel, a real Rayleigh fading channel, and a channel with a solid-state power amplifier (SSPA).
4 | The final goal of the E2E framework is to jointly learn a neural encoder and a neural decoder that achieve a low symbol error rate.
5 | In the real world, channel models are neither exactly known nor differentiable,
6 | so we learn the channel distribution with a diffusion model and replace the channel block by synthetic channel outputs sampled from the learned diffusion model.
7 | Please refer to our papers, linked below, for more details.
8 | The repository contains the source code to reproduce the simulations in the paper, together with some visualizations that aid understanding.
9 | 
10 | 
11 | ## Implementation Environment
12 | Python >= 3.7
13 | 
14 | PyTorch >= 1.6
15 | 
16 | CUDA 11.6 (the version used during development; the exact requirement was not verified, so nearby versions will likely work)
17 | 
18 | ## Parameters and Options Available
19 | The simulations are provided as .ipynb notebooks and are straightforward to follow. You can easily find the parameters and simulation settings in the top cells and change them as you want; a configuration sketch is given after the lists below. The adjustable parameters and available options are listed below.
20 | 
21 | ### Communication Channel
22 | * Channel models: AWGN, Rayleigh, SSPA
23 | * Parameters: cardinality of the message set 'M', block length 'n', training Eb/N0 'TRAINING_Eb/N0'.
24 | 
25 | ### Denoising Diffusion Models
26 | * Prediction variable and loss: 'epsilon', 'v'
27 | * Sampling algorithm: 'DDPM', 'DDIM'
28 | * Diffusion noise (beta) scheduling: sigmoid, cosine, no scheduling (constant).
29 | * Parameters: # of diffusion steps 'num_steps', step size of skipped sampling 'skip'.
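To make the options above concrete, here is a hypothetical configuration cell in the style of the notebooks' top cells. The variable names are illustrative (for instance, the training Eb/N0 must be spelled as a valid Python identifier such as `TRAINING_EbN0`), so check the actual top cell of each notebook:

```python
# Hypothetical top-cell configuration (illustrative names and values).
M = 16                    # cardinality of the message set
n = 7                     # block length (channel uses per message)
TRAINING_EbN0 = 4.0       # training Eb/N0 in dB (assumed name)
CHANNEL = 'AWGN'          # 'AWGN', 'Rayleigh', or 'SSPA' (assumed name)

pred_type = 'epsilon'     # prediction variable and loss: 'epsilon' or 'v'
denoising_alg = 'DDIM'    # sampling algorithm: 'DDPM' or 'DDIM'
schedule = 'sigmoid'      # beta schedule: 'sigmoid', 'cosine', or constant
num_steps = 100           # number of diffusion steps T
skip = 5                  # step size of skipped (accelerated) DDIM sampling
traj = range(0, num_steps, skip)  # sampling trajectory (assumed construction)
```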
30 | 
31 | ## Sources
32 | Our papers about this study: [conference version](https://scholar.google.com/citations?view_op=view_citation&hl=ko&user=VOl55dwAAAAJ&citation_for_view=VOl55dwAAAAJ:IjCSPb-OGe4C), [full journal version](https://arxiv.org/submit/5126965/view)
33 | 
34 | [hojonathanho/diffusion](https://github.com/hojonathanho/diffusion): the official implementation accompanying the original paper on denoising diffusion probabilistic models.
35 | 
36 | [azad-academy/denoising-diffusion-model](https://github.com/azad-academy/denoising-diffusion-model): we referred to this simple MLP-based implementation of diffusion models.
37 | 
38 | [karpathy/minGPT](https://github.com/karpathy/minGPT): we referred to this repository's structure to build the skeleton code, mainly for the autoencoder part.
39 | 
40 | ## Acknowledgement
41 | This repository was developed together with [Rick Fritschek](https://github.com/Fritschek/).
42 | 
--------------------------------------------------------------------------------
/figures/ECDF_histogram_AWGN.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/figures/ECDF_histogram_AWGN.pdf
--------------------------------------------------------------------------------
/figures/ECDF_histogram_Rayleigh.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/figures/ECDF_histogram_Rayleigh.pdf
--------------------------------------------------------------------------------
/figures/ECDF_histogram_SSPA.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/figures/ECDF_histogram_SSPA.pdf
--------------------------------------------------------------------------------
/figures/EbN0_vs._testSER_sample_AWGN.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/figures/EbN0_vs._testSER_sample_AWGN.pdf
--------------------------------------------------------------------------------
/figures/EbN0_vs._testSER_sample_Rayleigh.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/figures/EbN0_vs._testSER_sample_Rayleigh.pdf
--------------------------------------------------------------------------------
/figures/EbN0_vs._testSER_sample_SSPA.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/figures/EbN0_vs._testSER_sample_SSPA.pdf
--------------------------------------------------------------------------------
/src/.ipynb_checkpoints/channel_models-checkpoint.py:
--------------------------------------------------------------------------------
1 | # The channel models are defined here.
2 | 
3 | import torch
4 | import numpy as np
5 | 
6 | def ch_AWGN(ch_input, AWGN_std, device): # AWGN channel: y = x + n with n ~ N(0, AWGN_std^2) per component
7 |     ch_input = ch_input.to(device)
8 |     ch_output = ch_input + torch.normal(mean=0, std=AWGN_std, size = ch_input.size()).to(device)
9 |     return ch_output
10 | 
11 | 
12 | def ch_Rayleigh_AWGN(codeword, noise_std, dvc): # real Rayleigh fading plus AWGN: y = h * x + n, with E[h^2] = 1
13 |     codeword = codeword.to(dvc)
14 |     i,j = codeword.size()
15 |     h_ray = (1/np.sqrt(2))*torch.sqrt( torch.randn((i,j), device = dvc)**2 + torch.randn((i,j), device = dvc)**2)
16 | 
17 |     channel_noise = noise_std*torch.randn((i,j), device = dvc)
18 | 
19 |     return h_ray* codeword + channel_noise
20 | 
21 | 
22 | def ch_SSPA(x, sigma_n, device, p = 3.0, A_0 = 1.5, v = 5.): # Rapp SSPA model; A_0 is the limiting output amplitude, v the small-signal gain, p the smoothness.
23 |     assert x.size(1) % 2 == 0 # consecutive pairs of real dimensions are treated as complex symbols
24 | 
25 |     # x = ([1 2 3 4], [5 6 7 8])
26 |     dim = int(x.size(1) //2)
27 |     x_2d = x.reshape(-1,2) # x_2d = ([1 2], [3 4], [5 6], [7 8])
28 |     A = torch.sum(x_2d ** 2, dim=1) ** 0.5 # amplitude of each complex symbol
29 |     A_ratio = v / (1+ (v*A/A_0)**(2*p) )**(1/2/p) # g(A) / A, the amplitude gain of the amplifier
30 |     x_amp_2d = torch.mul(A_ratio.reshape(-1,1), x_2d)
31 |     x_amp = x_amp_2d.reshape(-1, 2*dim)
32 |     y = x_amp + (sigma_n/2**0.5) * torch.randn_like(x) # complex AWGN with per-component std sigma_n/sqrt(2)
33 | 
34 |     return y
35 | 
--------------------------------------------------------------------------------
/src/.ipynb_checkpoints/models-checkpoint.py:
--------------------------------------------------------------------------------
1 | # An implementation of the diffusion network model
2 | # Original source: https://github.com/acids-ircam/diffusion_models
3 | 
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch
7 | import numpy as np
8 | from channel_models import *
9 | 
10 | class ConditionalLinear(nn.Module):
11 |     def __init__(self, num_in, num_out, n_steps):
12 |         super(ConditionalLinear, self).__init__()
13 |         use_cuda = torch.cuda.is_available()
14 |         self.this_device = torch.device("cuda" if use_cuda else "cpu")
15 |         self.num_out = num_out
16 |         self.lin = nn.Linear(num_in, num_out, device =self.this_device )
17 |         self.embed = nn.Embedding(n_steps, num_out, device = self.this_device) # the noise distribution differs at each diffusion step, so each step index is embedded into its own scaling vector gamma
18 |         self.embed.weight.data.uniform_()
19 | 
20 |     def forward(self, x, y):
21 |         x = x.type(torch.FloatTensor).to(self.this_device)
22 |         out = self.lin(x).to(self.this_device)
23 |         gamma = self.embed(y).to(self.this_device)
24 |         out = gamma.view(-1, self.num_out).to(self.this_device) * out # multiplicative (FiLM-style) step conditioning: gamma gates each output unit
25 |         return out
26 | 
27 | class ConditionalModel(nn.Module):
28 |     def __init__(self, n_steps):
29 |         super(ConditionalModel, self).__init__()
30 |         use_cuda = torch.cuda.is_available()
31 |         self.this_device = torch.device("cuda" if use_cuda else "cpu")
32 |         self.lin1 = ConditionalLinear(2, 128, n_steps)
33 |         self.lin2 = ConditionalLinear(128, 128, n_steps)
34 |         self.lin3 = ConditionalLinear(128, 128, n_steps)
35 |         self.lin4 = nn.Linear(128, 2, device =self.this_device)
36 | 
37 |     def forward(self, x, y):
38 |         x = x.type(torch.FloatTensor).to(self.this_device)
39 |         x = F.softplus(self.lin1(x, y))
40 |         x = F.softplus(self.lin2(x, y))
41 |         x = F.softplus(self.lin3(x, y))
42 |         return self.lin4(x)
43 | 
44 | 
45 | class ConditionalLinear_w_Condition(nn.Module):
46 |     def __init__(self, num_in, num_out, n_steps):
47 |         super(ConditionalLinear_w_Condition, self).__init__()
48 |         use_cuda = torch.cuda.is_available()
49 |         self.this_device = torch.device("cuda" if use_cuda else "cpu")
50 |         self.num_out = num_out
51 |         self.lin = nn.Linear(num_in, num_out, device = self.this_device)
52 |         self.embed = nn.Embedding(n_steps, num_out, device = self.this_device) # the noise distribution differs at each diffusion step, so each step index is embedded into its own scaling vector gamma
53 |         self.embed.weight.data.uniform_()
54 | 
55 |     def forward(self, x, y, c):
56 |         x = x.type(torch.FloatTensor).to(self.this_device)
57 |         c = c.type(torch.FloatTensor).to(self.this_device)
58 |         out = self.lin(torch.cat((x,c), dim=1)).to(self.this_device)
59 |         gamma = self.embed(y).to(self.this_device)
60 |         out = gamma.view(-1, self.num_out).to(self.this_device) * out # multiplicative (FiLM-style) step conditioning: gamma gates each output unit
61 |         return out
62 | 
63 | 
64 | class ConditionalModel_w_Condition(nn.Module):
65 |     def __init__(self, n_steps, M, n, N):
66 |         super(ConditionalModel_w_Condition, self).__init__()
67 |         use_cuda = torch.cuda.is_available()
68 |         self.this_device = torch.device("cuda" if use_cuda else "cpu")
69 |         self.lin1 = ConditionalLinear_w_Condition(2*n+M, N, n_steps) # hidden layers of width N
70 |         self.lin2 = ConditionalLinear(N, N, n_steps)
71 |         self.lin3 = ConditionalLinear(N, N, n_steps)
72 |         self.lin4 = nn.Linear(N, n, device = self.this_device)
73 | 
74 |     def forward(self, x, y, c):
75 |         x = x.type(torch.FloatTensor).to(self.this_device)
76 |         c = c.type(torch.FloatTensor).to(self.this_device)
77 | 
78 |         x = F.softplus(self.lin1(x, y, c))
79 |         x = F.softplus(self.lin2(x, y))
80 |         x = F.softplus(self.lin3(x, y))
81 |         return self.lin4(x)
82 | 
83 | 
84 | 
85 | # Binary input encoders/decoders:
86 | class Encoder(nn.Module):
87 |     def __init__(self, k, M, n, **extra_kwargs):
88 |         super(Encoder, self).__init__()
89 | 
90 |         use_cuda = torch.cuda.is_available()
91 |         self.this_device = torch.device("cuda" if use_cuda else "cpu")
92 | 
93 |         self._f = nn.Sequential(
94 |             nn.Linear(k, M),
95 |             nn.ELU(),
96 |             nn.Linear(M,M),
97 |             nn.ELU(),
98 |             nn.Linear(M, n)#,norm_layer
99 |         )
100 | 
101 |     def power_constraint(self, codes):
102 |         codes_mean = torch.mean(codes)
103 |         codes_std = torch.std(codes)
104 |         codes_norm = (codes-codes_mean)/ codes_std # normalize to zero mean and unit variance across the batch
105 | 
106 |         return codes_norm
107 | 
108 |     def forward(self, inputs):
109 |         inputs = inputs.type(torch.FloatTensor).to(self.this_device)
110 |         _x = self._f(inputs)
111 |         x = self.power_constraint(_x)
112 |         return x
113 | 
114 | 
115 | class Decoder(nn.Module):
116 |     def __init__(self,k, M ,n, **extra_kwargs):
117 |         super(Decoder, self).__init__()
118 |         self._f = nn.Sequential(
119 |             nn.Linear(n, M),
120 |             nn.ELU(),
121 |             nn.Linear(M, M),
122 |             nn.ELU(),
123 |             nn.Linear(M, k),
124 |             nn.Sigmoid()
125 |         )
126 | 
127 |     def forward(self, inputs):
128 |         x = self._f(inputs)
129 |         return x
130 | 
131 | # one-hot encoders/decoders:
132 | 
133 | class Encoder_1h(nn.Module): # encoder with an average power constraint
134 |     def __init__(self, M, n, P_in, **extra_kwargs):
135 |         super(Encoder_1h, self).__init__()
136 | 
137 |         use_cuda = torch.cuda.is_available()
138 |         self.this_device = torch.device("cuda" if use_cuda else "cpu")
139 |         self.P_in = P_in
140 |         self.M = M
141 |         self._f = nn.Sequential(
142 |             nn.Linear(M, M),
143 |             nn.ELU(),
144 |             nn.Linear(M,M),
145 |             nn.ELU(),
146 |             nn.Linear(M, n)#,norm_layer
147 |         )
148 |     def power(self):
149 |         inputs = torch.eye(self.M).to(self.this_device)
150 |         codes = self._f(inputs)
151 |         P = torch.mean(torch.sum(codes**2, dim=1)) # average power over the full codebook
152 |         return P
153 | 
154 |     def power_constraint(self, codes):
155 |         P = self.power() #torch.mean(torch.sum(codes**2, dim=1))
156 |         codes_norm = codes * np.sqrt(self.P_in) / torch.sqrt(P) # rescale so the average codeword power equals P_in
157 |         return codes_norm
158 | 
159 |     def forward(self, inputs):
160 |         inputs = inputs.to(self.this_device)
161 |         _x = self._f(inputs)
162 |         x = self.power_constraint(_x)
163 |         return x
164 | 
165 | class Decoder_1h(nn.Module):
166 |     def __init__(self, M ,n, **extra_kwargs):
167 |         super(Decoder_1h, self).__init__()
168 |         self._f = nn.Sequential(
169 |             #nn.Linear(n, n),
170 |             #nn.ELU(),
171 |             nn.Linear(n, M),
172 |             nn.ELU(),
173 |             nn.Linear(M, M),
174 |             nn.ELU(),
175 |             nn.Linear(M, M),
176 |         )
177 | 
178 |     def forward(self, inputs):
179 |         x = self._f(inputs)
180 |         return x
181 | 
182 | 
183 | 
184 | 
185 | class Channel(torch.nn.Module):
186 |     def __init__(self, enc, dec):
187 |         super(Channel, self).__init__()
188 |         use_cuda = torch.cuda.is_available()
189 |         self.device = torch.device("cuda" if use_cuda else "cpu")
190 |         self.enc = enc
191 |         self.dec = dec
192 | 
193 |     def forward(self, inputs, noise_std):
194 |         # End-to-end framework: encoder -> AWGN channel -> decoder
195 |         codeword = self.enc(inputs)
196 | 
197 |         i,j = codeword.size()
198 |         channel_noise = noise_std*torch.randn((i, j)).to(self.device)
199 | 
200 |         rec_signal = codeword + channel_noise
201 |         dec_signal = self.dec(rec_signal)
202 | 
203 |         return dec_signal
204 | 
205 | 
206 | 
207 | class Channel_ray(torch.nn.Module):
208 |     def __init__(self, enc, dec):
209 |         super(Channel_ray, self).__init__()
210 |         use_cuda = torch.cuda.is_available()
211 |         self.device = torch.device("cuda" if use_cuda else "cpu")
212 |         self.enc = enc
213 |         self.dec = dec
214 | 
215 |     def forward(self, inputs, noise_std):
216 |         # End-to-end framework: encoder -> Rayleigh fading channel -> decoder
217 |         codeword = self.enc(inputs)
218 | 
219 |         i,j = codeword.size()
220 | 
221 |         h_ray = (1/np.sqrt(2))*torch.sqrt( torch.randn(size = codeword.size(), device = self.device)**2 +torch.randn(size = codeword.size(), device = self.device)**2) # Gaussian components (randn, not rand) give a Rayleigh envelope with E[h^2] = 1, matching ch_Rayleigh_AWGN
222 |         channel_noise = noise_std*torch.randn((i, j), device = self.device)
223 | 
224 |         rec_signal = codeword*h_ray + channel_noise
225 |         dec_signal = self.dec(rec_signal)
226 | 
227 |         return dec_signal
228 | 
229 | 
230 | class Channel_SSPA(torch.nn.Module):
231 |     def __init__(self, enc, dec):
232 |         super(Channel_SSPA, self).__init__()
233 |         use_cuda = torch.cuda.is_available()
234 |         self.device = torch.device("cuda" if use_cuda else "cpu")
235 |         self.enc = enc
236 |         self.dec = dec
237 |     def forward(self, inputs, noise_std):
238 |         codeword = self.enc(inputs)
239 | 
240 |         rec_signal = ch_SSPA(codeword, noise_std, self.device)
241 |         dec_signal = self.dec(rec_signal)
242 | 
243 |         return dec_signal
244 | 
245 | 
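Taken together, the classes above compose into a trainable communication chain. The following is a minimal usage sketch, not part of the repository: it assumes `src/` is on the Python path and that `src/models.py` and `src/utils.py` export the same `Encoder_1h`, `Decoder_1h`, `Channel`, and `EbNo_to_noise` as these checkpoint files. The batch size, `M`, `n`, the power budget `P_in`, and the Eb/N0 value are illustrative.

```python
# Minimal end-to-end sketch over the AWGN wrapper (illustrative values).
import torch
import torch.nn.functional as F
from models import Encoder_1h, Decoder_1h, Channel
from utils import EbNo_to_noise

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

M, n = 16, 7                                # message-set cardinality and block length
enc = Encoder_1h(M, n, P_in=n).to(device)   # P_in: average power budget (assumed choice)
dec = Decoder_1h(M, n).to(device)
chain = Channel(enc, dec)                   # encoder -> AWGN channel -> decoder

labels = torch.randint(M, (32,), device=device)   # batch of 32 random messages
m = F.one_hot(labels, num_classes=M).float()      # one-hot network input
noise_std = EbNo_to_noise(4.0, 4/7)               # Eb/N0 = 4 dB at rate k/n = 4/7
logits = chain(m, noise_std)                      # decoder outputs, shape (32, M)
loss = F.cross_entropy(logits, labels)            # criterion used by trainer.py
print(f"cross-entropy: {loss.item():.3f}")
```

For the Rayleigh and SSPA channels, `Channel_ray` and `Channel_SSPA` drop in for `Channel` with the same `(inputs, noise_std)` interface.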
-------------------------------------------------------------------------------- /src/.ipynb_checkpoints/trainer-checkpoint.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.optim.lr_scheduler import LambdaLR 11 | from torch.utils.data.dataloader import DataLoader 12 | from torch.nn import functional as F 13 | import utils 14 | 15 | 16 | class TrainerConfig: 17 | # optimization parameters 18 | max_epochs = 10 19 | batch_size = 500 20 | M=16 21 | learning_rate = 1e-3 22 | rate = 4/7 23 | num_workers = 0 # for DataLoader 24 | 25 | 26 | def __init__(self, **kwargs): 27 | for k,v in kwargs.items(): 28 | setattr(self, k, v) 29 | 30 | class Trainer: 31 | def __init__(self, model, train_dataset, config, one_hot): 32 | self.model = model 33 | self.train_dataset = train_dataset 34 | #self.test_dataset = test_dataset 35 | self.config = config 36 | self.one_hot = one_hot 37 | 38 | self.device = 'cpu' 39 | if torch.cuda.is_available(): 40 | self.device = torch.cuda.current_device() 41 | self.model = torch.nn.DataParallel(self.model).to(self.device) 42 | 43 | def train(self): 44 | model, config = self.model, self.config 45 | optimizer = torch.optim.NAdam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate) 46 | 47 | loss_CE = nn.CrossEntropyLoss() 48 | 49 | def run_epoch(): 50 | lr = config.learning_rate 51 | batch_size = config.batch_size 52 | M = config.M 53 | data = self.train_dataset 54 | loader = DataLoader(data, shuffle=True, pin_memory=False, batch_size=batch_size, 55 | num_workers=config.num_workers) 56 | losses, batch_BER, batch_SER = [], [], [] 57 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 58 | 59 | for it, x in pbar: 60 | labels = x.reshape((batch_size,1)).to(self.device) 61 | x = F.one_hot(labels.long(), num_classes=M) 62 | x = x.float().reshape(batch_size,M) 63 | 64 | optimizer.zero_grad() 65 | output = model(x, config.noise_std) 66 | if self.one_hot: 67 | loss = loss_CE(output, labels.long().squeeze(1)) 68 | else: 69 | loss = F.binary_cross_entropy(output, x) 70 | loss = loss.mean() 71 | loss.backward() 72 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.) 
73 | losses.append(loss.item()) 74 | optimizer.step() 75 | 76 | if self.one_hot: 77 | SER, _, _ = utils.SER(labels.long(), output, 0.9) 78 | batch_SER.append(SER.item()) 79 | if it%100==0: 80 | pbar.set_description(f"epoch {epoch+1}: loss {np.mean(losses):.2e} SER {np.mean(batch_SER):.2e}") 81 | else: 82 | batch_BER.append(utils.B_Ber_m(x, output).item()) 83 | pbar.set_description(f"epoch {epoch+1}: loss {np.mean(losses):.2e} BER {np.mean(batch_BER):.2e}") 84 | 85 | 86 | for epoch in range(config.max_epochs): 87 | run_epoch() 88 | 89 | def test(self, snr_range, one_hot, erasure_bound): 90 | model, config = self.model, self.config 91 | ser = [] 92 | 93 | def run_epoch(): 94 | data = self.train_dataset 95 | batch_size = config.batch_size 96 | M = config.M 97 | loader = DataLoader(data, shuffle=True, pin_memory=False, batch_size=batch_size, 98 | num_workers=config.num_workers) 99 | batch_BER = [] 100 | block_ER = [] 101 | erasures = [] 102 | batch_SER = [] 103 | 104 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 105 | 106 | for it, x in pbar: 107 | if self.one_hot: 108 | labels = x.reshape((batch_size,1)).to(self.device) 109 | x=F.one_hot(labels.long(), num_classes=M) 110 | x = x.float().reshape(batch_size,M) 111 | else: 112 | x = x.float().to(self.device) 113 | noise_std_ = utils.EbNo_to_noise(snr_range[epoch], config.rate) 114 | #noise_std_ = utils.SNR_to_noise(snr_range[epoch]) 115 | 116 | output = model(x, noise_std_) 117 | 118 | if self.one_hot: 119 | SER, _, _ = utils.SER(labels.long(), output, erasure_bound) 120 | batch_SER.append(SER.item()) 121 | #block_ER.append(bler) 122 | #erasures.append(Erasures) 123 | if it%100==0: 124 | pbar.set_description(f"SNR {snr_range[epoch]} iter {it}: SER {np.mean(batch_SER):.3e}") 125 | else: 126 | batch_BER.append(utils.B_Ber_m(x.float(), output).item()) 127 | block_ER.append(utils.Block_ER(x.float(), output)) 128 | pbar.set_description(f"SNR {snr_range[epoch]} iter {it}: avg. BER {np.mean(batch_BER):.8f} BLER {np.mean(block_ER):.7e}") 129 | 130 | return np.mean(batch_SER) 131 | 132 | num_samples = self.train_dataset.size(dim=0) 133 | for epoch in range(len(snr_range)): 134 | 135 | temp1 = 0 136 | it_in_while = 0 137 | 138 | while temp1 * num_samples *(it_in_while) < 1000: # To guarantee we have enough samples for the Monte-Carlo method. 139 | if it_in_while > 1 : 140 | print("The number of samples is not enough. One more epoch is running. Total # of samples used: " , num_samples * it_in_while) 141 | temp2 = run_epoch() 142 | temp1 = (it_in_while * temp1 + temp2)/(it_in_while+1) # taking average of the error probability. 
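                    # Keep drawing fresh epochs until roughly 1000 symbol errors are expected
                    # in total (avg. SER x samples/epoch x epochs), so the Monte-Carlo SER
                    # estimate stays reliable even at low error rates.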
143 | it_in_while += 1 144 | ser.append(temp1) 145 | 146 | return ser 147 | 148 | 149 | 150 | class TrainerConfig_DDM: 151 | # optimization parameters 152 | max_epochs = 10 153 | dataset_size = 1000000 154 | batch_size = 500 155 | noise_std = 1 156 | learning_rate = 1e-3 157 | M=16 158 | n=7 159 | num_steps = 50 160 | betas = torch.zeros([num_steps]) 161 | rate = 4/7 162 | num_workers = 0 # for DataLoader 163 | optim_gen = optim.Adam 164 | # TRANSFER_LEARNING= False 165 | IS_RES = False 166 | pred_type = 'epsilon' 167 | 168 | 169 | def __init__(self, **kwargs): 170 | for k,v in kwargs.items(): 171 | setattr(self, k, v) 172 | self.alphas = 1 - self.betas.to('cpu') 173 | self.alphas_prod = torch.cumprod(self.alphas, 0) 174 | self.alphas_prod_p = torch.cat([torch.tensor([1]), self.alphas_prod[:-1]], 0) 175 | self.alphas_bar_sqrt = torch.sqrt(self.alphas_prod) 176 | #one_minus_alphas_bar_log = torch.log(1 - alphas_prod) 177 | self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - self.alphas_prod) 178 | 179 | class Trainer_DDM: 180 | def __init__(self, channel_gen, ema, device, channel_model, tconf_gen): 181 | self.channel_gen = channel_gen 182 | self.ema = ema 183 | self.device = device 184 | self.channel_model = channel_model 185 | self.tconf_gen = tconf_gen 186 | 187 | def train_PreT(self): # pretraining algorithm 188 | channel_gen = self.channel_gen 189 | ema = self.ema 190 | device = self.device 191 | channel_model = self.channel_model 192 | tconf_gen = self.tconf_gen 193 | n = tconf_gen.n 194 | 195 | IS_RES = tconf_gen.IS_RES 196 | 197 | optimizer_ch = tconf_gen.optim_gen(channel_gen.parameters(), lr=tconf_gen.learning_rate) 198 | dataset_gen = torch.randn(tconf_gen.dataset_size, n) 199 | alphas_bar_sqrt = tconf_gen.alphas_bar_sqrt.to(device) 200 | one_minus_alphas_bar_sqrt = tconf_gen.one_minus_alphas_bar_sqrt.to(device) 201 | 202 | if tconf_gen.pred_type == 'epsilon': 203 | estimation_loss = utils.noise_estimation_loss 204 | elif tconf_gen.pred_type == 'v': 205 | estimation_loss = utils.v_estimation_loss 206 | else: 207 | assert(ValueError) 208 | 209 | 210 | loss_ep_gen = [] 211 | for t in range(tconf_gen.max_epochs): 212 | loader = DataLoader(dataset_gen, shuffle=True, pin_memory=False, batch_size=tconf_gen.batch_size) 213 | loss_batch, batch_BER, batch_SER = [], [], [] 214 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 215 | 216 | for it, x in pbar: 217 | x = x.to(device) 218 | y = channel_model(x, tconf_gen.noise_std, device) 219 | if IS_RES: 220 | y = y - x 221 | 222 | yx = torch.cat((y , x),dim=1) 223 | # Create EMA model 224 | loss = estimation_loss(channel_gen, yx, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, tconf_gen.num_steps, n) 225 | 226 | optimizer_ch.zero_grad() 227 | loss.backward() 228 | torch.nn.utils.clip_grad_norm_(channel_gen.parameters(), 1.) 
229 | optimizer_ch.step() 230 | # Update the exponential moving average 231 | ema.update(channel_gen) 232 | 233 | loss_batch.append(loss.item()) 234 | 235 | if it % 100 == 0: 236 | pbar.set_description(f"epoch {t+1}: loss {np.mean(loss.item()):.2e}") 237 | 238 | loss_ep_gen.append(np.mean(loss_batch)) 239 | return loss_ep_gen 240 | 241 | def train_itr(self, encoder): # iterative training 242 | 243 | channel_gen = self.channel_gen 244 | ema = self.ema 245 | device = self.device 246 | channel_model = self.channel_model 247 | tconf_gen = self.tconf_gen 248 | M = tconf_gen.M 249 | n = tconf_gen.n 250 | 251 | optimizer_ch = tconf_gen.optim_gen(channel_gen.parameters(), lr=tconf_gen.learning_rate) 252 | 253 | dataset_gen = np.random.randint(M, size=tconf_gen.dataset_size) 254 | 255 | alphas_bar_sqrt = tconf_gen.alphas_bar_sqrt.to(device) 256 | one_minus_alphas_bar_sqrt = tconf_gen.one_minus_alphas_bar_sqrt.to(device) 257 | 258 | if tconf_gen.pred_type == 'epsilon': 259 | estimation_loss = utils.noise_estimation_loss 260 | elif tconf_gen.pred_type == 'v': 261 | estimation_loss = utils.v_estimation_loss 262 | else: 263 | assert(ValueError) 264 | 265 | loss_ep_gen = [] 266 | for t in range(tconf_gen.max_epochs): 267 | loader = DataLoader(dataset_gen, shuffle=True, pin_memory=False, batch_size=tconf_gen.batch_size) 268 | 269 | loss_batch, batch_BER, batch_SER = [], [], [] 270 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 271 | 272 | for it, m in pbar: 273 | labels = m.reshape((-1,1)) 274 | m = F.one_hot(labels.long(), num_classes=M) 275 | m = m.float().reshape(-1,M).to(device) 276 | 277 | x = encoder(m) 278 | y = channel_model(x, tconf_gen.noise_std, device) 279 | yx = torch.cat((y , x ),dim=1) 280 | 281 | # Create EMA model 282 | loss = estimation_loss(channel_gen, yx, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, tconf_gen.num_steps, n) 283 | 284 | optimizer_ch.zero_grad() 285 | loss.backward() 286 | torch.nn.utils.clip_grad_norm_(channel_gen.parameters(), 1.) 
287 | optimizer_ch.step() 288 | 289 | # Update the exponential moving average 290 | ema.update(channel_gen) 291 | 292 | loss_batch.append(loss.item()) 293 | 294 | if it % 100 == 0: 295 | pbar.set_description(f"epoch {t+1}: loss {np.mean(loss.item()):.2e}") 296 | 297 | loss_ep_gen.append(np.mean(loss_batch)) 298 | return loss_ep_gen 299 | 300 | 301 | class TrainerConfig_AE_w_DDM: 302 | # optimization parameters 303 | max_epochs = 10 304 | dataset_size = 1000000 305 | batch_size = 500 306 | noise_std = 1 307 | learning_rate = 1e-3 308 | M=16 309 | n=7 310 | rate = 4/7 311 | num_steps = 50 312 | #alphas = torch.zeros([num_steps]) 313 | betas = torch.zeros([num_steps]) 314 | denoising_alg = 'DDPM' 315 | pred_type = 'epsilon' 316 | traj = range(num_steps) 317 | channel_model = 'None' 318 | 319 | optim_AE = optim.NAdam 320 | num_workers = 0 # for DataLoader 321 | 322 | def __init__(self, **kwargs): 323 | for k,v in kwargs.items(): 324 | setattr(self, k, v) 325 | self.alphas = 1 - self.betas.to('cpu') 326 | self.alphas_prod = torch.cumprod(self.alphas, 0) 327 | self.alphas_prod_p = torch.cat([torch.tensor([1]), self.alphas_prod[:-1]], 0) 328 | self.alphas_bar_sqrt = torch.sqrt(self.alphas_prod) 329 | #one_minus_alphas_bar_log = torch.log(1 - alphas_prod) 330 | self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - self.alphas_prod) 331 | 332 | class Trainer_AE_w_DDM: 333 | def __init__(self, encoder, decoder, channel_gen, device, tconf_AE): 334 | self.encoder = encoder 335 | self.decoder = decoder 336 | self.channel_gen = channel_gen 337 | self.device = device 338 | self.tconf_AE = tconf_AE 339 | 340 | def train_PreT(self):#(encoder, decoder, channel_gen, device, tconf_AE): 341 | encoder = self.encoder 342 | decoder = self.decoder 343 | channel_gen = self.channel_gen 344 | device = self.device 345 | tconf_AE = self.tconf_AE 346 | 347 | 348 | M = tconf_AE.M 349 | n = tconf_AE.n 350 | 351 | denoising_alg = tconf_AE.denoising_alg 352 | one_minus_alphas_bar_sqrt = tconf_AE.one_minus_alphas_bar_sqrt.to(device) 353 | betas = tconf_AE.betas.to(device) 354 | alphas = tconf_AE.alphas.to(device) 355 | 356 | dataset_AE = np.random.randint(M, size=tconf_AE.dataset_size) 357 | loss_CE = nn.CrossEntropyLoss() 358 | optimizer_AE = tconf_AE.optim_AE(list(encoder.parameters()) + list(decoder.parameters()), lr = tconf_AE.learning_rate) 359 | 360 | loss_ep_AE = [] 361 | 362 | for t in range(tconf_AE.max_epochs): 363 | loader = DataLoader(dataset_AE, shuffle=True, pin_memory=False, batch_size=tconf_AE.batch_size)#, num_workers=config.num_workers) 364 | loss_batch, batch_BER, batch_SER = [], [], [] 365 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 366 | 367 | for it, m in pbar: #enumerate(loader): 368 | labels = m.reshape((-1,1)).to(device) 369 | optimizer_AE.zero_grad() 370 | 371 | # Forward Pass: 372 | # Encoding 373 | m = F.one_hot(labels.long(), num_classes=M) 374 | m = m.float().reshape(-1,M) 375 | x = encoder(m) 376 | c = x 377 | 378 | # Channel 379 | if denoising_alg == 'DDPM': 380 | x_seq = utils.p_sample_loop_w_Condition(channel_gen, x.size(), tconf_AE.num_steps, 381 | tconf_AE.alphas, tconf_AE.betas, 382 | tconf_AE.alphas_bar_sqrt, 383 | tconf_AE.one_minus_alphas_bar_sqrt, c, 384 | pred_type=tconf_AE.pred_type) 385 | elif denoising_alg == 'DDIM': 386 | x_seq = utils.p_sample_loop_w_Condition_DDIM(channel_gen, x.size(), tconf_AE.traj, 387 | tconf_AE.alphas_prod, tconf_AE.alphas_bar_sqrt, 388 | tconf_AE.one_minus_alphas_bar_sqrt, c, 389 | pred_type=tconf_AE.pred_type) 390 | y = x_seq[-1] 391 | 
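                # y is the synthetic channel output produced by the diffusion sampler.
                # Because the sampling chain is conditioned on the codeword c = x and is
                # differentiable, the cross-entropy loss below backpropagates through the
                # generated channel all the way into the encoder's weights.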
392 | # Decoding 393 | m_d = decoder(y) 394 | 395 | # Backward pass: 396 | # Compute the loss and the gradient 397 | loss = loss_CE(m_d, labels.long().squeeze(1)) 398 | loss = loss.mean() 399 | loss.backward() 400 | 401 | torch.nn.utils.clip_grad_norm_(list(encoder.parameters())+list(decoder.parameters()), 1.) 402 | optimizer_AE.step() 403 | 404 | 405 | loss_batch.append(loss.item()) 406 | 407 | SER_tmp, _, _ = utils.SER(labels.long(), m_d, 0.9) 408 | batch_SER.append(SER_tmp.item()) 409 | 410 | if it%100==0: 411 | pbar.set_description(f"epoch {t+1}: loss {np.mean(loss_batch):.2e} SER {np.mean(batch_SER):.2e}") 412 | 413 | # Save epoch loss 414 | loss_ep_AE.append(np.mean(loss_batch)) 415 | return loss_ep_AE 416 | 417 | 418 | 419 | def train_PreT_enc(self): # (encoder, decoder, channel_gen, device, tconf_AE): 420 | encoder = self.encoder 421 | decoder = self.decoder 422 | channel_gen = self.channel_gen 423 | device = self.device 424 | tconf_AE = self.tconf_AE 425 | 426 | M = tconf_AE.M 427 | n = tconf_AE.n 428 | 429 | denoising_alg = tconf_AE.denoising_alg 430 | one_minus_alphas_bar_sqrt = tconf_AE.one_minus_alphas_bar_sqrt.to(device) 431 | betas = tconf_AE.betas.to(device) 432 | alphas = tconf_AE.alphas.to(device) 433 | 434 | dataset_AE = np.random.randint(M, size=tconf_AE.dataset_size) 435 | loss_CE = nn.CrossEntropyLoss() 436 | optimizer_AE = tconf_AE.optim_AE(encoder.parameters(), lr=tconf_AE.learning_rate) 437 | 438 | loss_ep_AE = [] 439 | SER_ep = [] 440 | 441 | for t in range(tconf_AE.max_epochs): 442 | loader = DataLoader(dataset_AE, shuffle=True, pin_memory=False, batch_size=tconf_AE.batch_size) 443 | loss_batch, batch_BER, batch_SER = [], [], [] 444 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 445 | 446 | for it, m in pbar: 447 | labels = m.reshape((-1, 1)).to(device) 448 | optimizer_AE.zero_grad() 449 | 450 | # Forward Pass: 451 | # Encoding 452 | m = F.one_hot(labels.long(), num_classes=M) 453 | m = m.float().reshape(-1, M) 454 | x = encoder(m) 455 | c = x 456 | 457 | # Channel 458 | 459 | if denoising_alg == 'DDPM': 460 | x_seq = utils.p_sample_loop_w_Condition(channel_gen, x.size(), tconf_AE.num_steps 461 | , tconf_AE.alphas, tconf_AE.betas, 462 | tconf_AE.alphas_bar_sqrt, 463 | tconf_AE.one_minus_alphas_bar_sqrt, c, 464 | pred_type=tconf_AE.pred_type) 465 | y = x_seq[-1] 466 | elif denoising_alg == 'DDIM': 467 | x_seq = utils.p_sample_loop_w_Condition_DDIM(channel_gen, x.size(), tconf_AE.traj, 468 | tconf_AE.alphas_prod, tconf_AE.alphas_bar_sqrt, 469 | tconf_AE.one_minus_alphas_bar_sqrt, c, 470 | pred_type=tconf_AE.pred_type) 471 | y = x_seq[-1] 472 | 473 | # Decoding 474 | m_d = decoder(y) 475 | 476 | # Backward pass: 477 | # Compute the loss and the gradient 478 | loss = loss_CE(m_d, labels.long().squeeze(1)) 479 | loss = loss.mean() 480 | loss.backward() 481 | 482 | torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.) 
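                # Encoder-only phase: only the encoder's gradients are clipped and updated
                # here; the decoder is refined separately in train_dec on the true channel.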
483 | optimizer_AE.step() 484 | 485 | loss_batch.append(loss.item()) 486 | 487 | SER_tmp, _, _ = utils.SER(labels.long(), m_d, 0.9) 488 | batch_SER.append(SER_tmp.item()) 489 | 490 | if it % 100 == 0: 491 | pbar.set_description(f"loss {np.mean(loss_batch):.2e} SER {np.mean(batch_SER):.2e}") 492 | 493 | # Save epoch loss 494 | loss_ep_AE.append(np.mean(loss_batch)) 495 | SER_ep.append(np.mean(batch_SER)) 496 | return loss_ep_AE, SER_ep 497 | 498 | def train_dec(self): 499 | encoder = self.encoder 500 | decoder = self.decoder 501 | device = self.device 502 | tconf_AE = self.tconf_AE 503 | 504 | 505 | M = tconf_AE.M 506 | n = tconf_AE.n 507 | 508 | dataset_AE = np.random.randint(M, size=tconf_AE.dataset_size) 509 | loss_CE = nn.CrossEntropyLoss() 510 | optimizer_AE = tconf_AE.optim_AE(decoder.parameters(), lr = tconf_AE.learning_rate) 511 | 512 | loss_ep_AE = [] 513 | SER_ep = [] 514 | 515 | for t in range(tconf_AE.max_epochs): 516 | loader = DataLoader(dataset_AE, shuffle=True, pin_memory=False, batch_size=tconf_AE.batch_size)#, num_workers=config.num_workers) 517 | loss_batch, batch_BER, batch_SER = [], [], [] 518 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 519 | 520 | for it, m in pbar: #enumerate(loader): 521 | labels = m.reshape((-1,1)).to(device) 522 | optimizer_AE.zero_grad() 523 | 524 | # Forward Pass: 525 | # Encoding 526 | m = F.one_hot(labels.long(), num_classes=M) 527 | m = m.float().reshape(-1,M) 528 | x = encoder(m) 529 | y = tconf_AE.channel_model(x, tconf_AE.noise_std, device) 530 | 531 | # Decoding 532 | m_d = decoder(y) 533 | 534 | # Backward pass: 535 | # Compute the loss and the gradient 536 | loss = loss_CE(m_d, labels.long().squeeze(1)) 537 | loss = loss.mean() 538 | loss.backward() 539 | 540 | torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1.) 
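                # Decoder-only phase: the forward pass used the true channel model
                # (tconf_AE.channel_model) rather than the diffusion generator, which is
                # fine because no gradient needs to pass through the channel to reach the
                # decoder's parameters.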
541 | optimizer_AE.step() 542 | 543 | 544 | loss_batch.append(loss.item()) 545 | 546 | SER_tmp, _, _ = utils.SER(labels.long(), m_d, 0.9) 547 | batch_SER.append(SER_tmp.item()) 548 | 549 | if it%100==0: 550 | pbar.set_description(f"loss {np.mean(loss_batch):.2e} SER {np.mean(batch_SER):.2e}") 551 | 552 | # Save epoch loss 553 | loss_ep_AE.append(np.mean(loss_batch)) 554 | SER_ep.append(np.mean(batch_SER)) 555 | return loss_ep_AE, SER_ep 556 | 557 | def train_itr(self):#(encoder, decoder, channel_gen, device, tconf_AE): 558 | encoder = self.encoder 559 | decoder = self.decoder 560 | channel_gen = self.channel_gen 561 | device = self.device 562 | tconf_AE = self.tconf_AE 563 | 564 | M = tconf_AE.M 565 | n = tconf_AE.n 566 | 567 | one_minus_alphas_bar_sqrt = tconf_AE.one_minus_alphas_bar_sqrt.to(device) 568 | betas = tconf_AE.betas.to(device) 569 | alphas = tconf_AE.alphas.to(device) 570 | 571 | dataset_AE = np.random.randint(M, size=tconf_AE.dataset_size) 572 | loss_CE = nn.CrossEntropyLoss() 573 | optimizer_AE = tconf_AE.optim_AE(list(encoder.parameters()) + list(decoder.parameters()), lr = tconf_AE.learning_rate) 574 | 575 | loss_ep_AE = [] 576 | SER_ep = [] 577 | 578 | for t in range(tconf_AE.max_epochs): 579 | loader = DataLoader(dataset_AE, shuffle=True, pin_memory=False, batch_size=tconf_AE.batch_size)#, num_workers=config.num_workers) 580 | loss_batch, batch_BER, batch_SER = [], [], [] 581 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 582 | 583 | for it, m in pbar: #enumerate(loader): 584 | labels = m.reshape((-1,1)).to(device) 585 | optimizer_AE.zero_grad() 586 | 587 | # Forward Pass: 588 | # Encoding 589 | m = F.one_hot(labels.long(), num_classes=M) 590 | m = m.float().reshape(-1,M) 591 | x = encoder(m) 592 | c = x 593 | # Channel 594 | 595 | x_seq = utils.p_sample_loop_w_Condition(channel_gen, [x.size()[0], n], tconf_AE.num_steps, alphas, 596 | betas, one_minus_alphas_bar_sqrt, c, 597 | pred_type=tconf_AE.pred_type, is_light=True) 598 | y = x_seq # Generated data after 100 time steps 599 | 600 | # Decoding 601 | m_d = decoder(y) 602 | 603 | # Backward pass: 604 | # Compute the loss and the gradient 605 | loss = loss_CE(m_d, labels.long().squeeze(1)) 606 | loss = loss.mean() 607 | loss.backward() 608 | 609 | torch.nn.utils.clip_grad_norm_(list(encoder.parameters())+list(decoder.parameters()), 1.) 
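                # Joint update of encoder and decoder through the generated channel; the
                # sampler above was called with is_light=True, so only the final sample is
                # kept rather than the full reverse trajectory, saving memory.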
610 | optimizer_AE.step() 611 | 612 | 613 | loss_batch.append(loss.item()) 614 | 615 | SER_tmp, _, _ = utils.SER(labels.long(), m_d, 0.9) 616 | batch_SER.append(SER_tmp.item()) 617 | 618 | if it%100==0: 619 | pbar.set_description(f"epoch {t+1}: loss {np.mean(loss_batch):.2e} SER {np.mean(batch_SER):.2e}") 620 | 621 | # Save epoch loss 622 | loss_ep_AE.append(np.mean(loss_batch)) 623 | SER_ep.append(np.mean(batch_SER)) 624 | return loss_ep_AE, SER_ep 625 | 626 | def test(self, channel_model): 627 | encoder = self.encoder 628 | decoder = self.decoder 629 | 630 | device = self.device 631 | tconf_AE = self.tconf_AE 632 | 633 | 634 | M = tconf_AE.M 635 | n = tconf_AE.n 636 | noise_std = tconf_AE.noise_std 637 | 638 | dataset_AE = np.random.randint(M, size=tconf_AE.dataset_size) 639 | 640 | SER_ep_AE = [] 641 | 642 | for t in range(1): 643 | loader = DataLoader(dataset_AE, shuffle=True, pin_memory=False, batch_size=tconf_AE.batch_size) 644 | batch_SER = [] 645 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 646 | 647 | for it, m in pbar: 648 | labels = m.reshape((-1,1)).to(device) 649 | # Forward Pass: 650 | # Encoding 651 | m = F.one_hot(labels.long(), num_classes=M) 652 | m = m.float().reshape(-1,M) 653 | x = encoder(m) 654 | 655 | # Channel 656 | y = channel_model(x, noise_std, device) 657 | 658 | # Decoding 659 | m_d = decoder(y) 660 | 661 | SER_tmp, _, _ = utils.SER(labels.long(), m_d, 0.9) 662 | batch_SER.append(SER_tmp.item()) 663 | 664 | if it%100==0: 665 | pbar.set_description(f"Validation SER {np.mean(batch_SER):.2e}") 666 | 667 | # Save epoch loss 668 | SER_ep_AE.append(np.mean(batch_SER)) 669 | return np.mean(SER_ep_AE) 670 | -------------------------------------------------------------------------------- /src/.ipynb_checkpoints/utils-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Loss Function for Diffusion Model 2 | # Original Source: https://github.com/acids-ircam/diffusion_models 3 | 4 | import torch 5 | import numpy as np 6 | import sys 7 | import math 8 | assert sys.version_info >= (3, 5) 9 | import matplotlib as mpl 10 | import matplotlib.pyplot as plt 11 | import torch.nn.functional as F 12 | from torch import nn 13 | np.random.seed(42) 14 | 15 | 16 | def make_beta_schedule(schedule='linear', n_timesteps=1000, start=1e-5, end=1e-2): 17 | if schedule == 'linear': 18 | betas = torch.linspace(start, end, n_timesteps) 19 | elif schedule == "quad": 20 | betas = torch.linspace(start ** 0.5, end ** 0.5, n_timesteps) ** 2 21 | elif schedule == "sigmoid": 22 | betas = torch.linspace(-6, 6, n_timesteps) 23 | betas = torch.sigmoid(betas) * (end - start) + start 24 | 25 | elif schedule == "cosine": 26 | i = torch.linspace(0, 1, n_timesteps+1) 27 | f_t = torch.cos( (i + 0.008) / (1+0.008) * math.pi / 2 ) ** 2 28 | alphas_bar = f_t / f_t[0] 29 | alphas_bar_before = alphas_bar[:-1] 30 | alphas_bar = alphas_bar[1:] 31 | 32 | betas = 1 - alphas_bar / alphas_bar_before 33 | for i in range(n_timesteps): 34 | if betas[i] > 0.999: 35 | betas[i] = 0.999 36 | elif schedule == "cosine-zf": 37 | i = torch.linspace(0, 1, n_timesteps+1) 38 | f_t = torch.cos( (i + 0.008) / (1+0.008) * math.pi / 2 ) ** 2 39 | alphas_bar = f_t / f_t[0] 40 | alphas_bar_before = alphas_bar[:-1] 41 | alphas_bar = alphas_bar[1:] 42 | 43 | betas = 1 - alphas_bar / alphas_bar_before 44 | 45 | return betas 46 | 47 | def extract(input, t, shape): 48 | out = torch.gather(input, 0, t.to(input.device)) 49 | reshape = [t.shape[0]] + [1] * (len(shape) - 1) 50 
| return out.reshape(*reshape).to(input.device) 51 | 52 | def q_posterior_mean_variance(x_0, x_t, t,posterior_mean_coef_1,posterior_mean_coef_2,posterior_log_variance_clipped): 53 | shape = x_0.shape 54 | coef_1 = extract(posterior_mean_coef_1, t, shape) 55 | coef_2 = extract(posterior_mean_coef_2, t, shape) 56 | mean = coef_1 * x_0 + coef_2 * x_t 57 | var = extract(posterior_log_variance_clipped, t, shape) 58 | return mean, var 59 | 60 | def p_mean_variance(model, x, t): 61 | # Go through model 62 | out = model(x, t) 63 | # Extract the mean and variance 64 | mean, log_var = torch.split(out, 2, dim=-1) 65 | var = torch.exp(log_var) 66 | return mean, log_var 67 | 68 | def p_sample_w_Condition(model, z, t, alphas, betas, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, c): 69 | device = model.this_device 70 | z = z.to(device) 71 | c = c.to(device) 72 | shape = z.shape 73 | t = torch.tensor([t]).to(device) 74 | 75 | ## The commented lines below are necessary only when the full version of sigma_t is used. 76 | #a = extract(alphas_bar_sqrt, t, shape) 77 | #am1 = extract(one_minus_alphas_bar_sqrt, t, shape) 78 | #if t > 0: 79 | # a_next = extract(alphas_bar_sqrt, t - 1, shape) 80 | # am1_next = extract(one_minus_alphas_bar_sqrt, t - 1, shape) 81 | #else: 82 | # a_next = torch.ones_like(z) 83 | # am1_next = torch.zeros_like(z) 84 | 85 | # Factor to the model output 86 | eps_factor = ((1 - extract(alphas, t, shape)) / extract(one_minus_alphas_bar_sqrt, t, shape)) 87 | eps_factor = eps_factor.to(device) 88 | # Model output 89 | eps_theta = model(z, t, c) 90 | # Final values 91 | mean = (1 / extract(alphas, t, shape).sqrt().to(device)) * (z - (eps_factor * eps_theta)) 92 | # Generate epsilon 93 | e = torch.randn_like(z).to(device) 94 | # Fixed sigma 95 | sigma_t = extract(betas, t, shape).sqrt().to(device) # Simplified version 96 | #sigma_t = ( 1 - a ** 2 / a_next ** 2) ** 0.5 * am1_next / am1 # full version 97 | 98 | sample = mean.to(device) + sigma_t.to(device) * e 99 | sample = sample.to(device) 100 | return (sample) 101 | 102 | 103 | def p_sample_w_Condition_v(model, z, t, alphas, betas, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, c): # stochastic sampling with condition 104 | 105 | device = model.this_device 106 | z = z.to(device) 107 | c = c.to(device) 108 | shape = z.shape 109 | t = torch.tensor([t]).to(device) 110 | 111 | a = extract(alphas_bar_sqrt, t, shape) 112 | am1 = extract(one_minus_alphas_bar_sqrt, t, shape) 113 | 114 | if t > 0: 115 | a_next = extract(alphas_bar_sqrt, t - 1, shape) 116 | am1_next = extract(one_minus_alphas_bar_sqrt, t - 1, shape) 117 | else: 118 | a_next = torch.ones_like(z) 119 | am1_next = torch.zeros_like(z) 120 | # Model output 121 | v_theta = model(z, t, c) 122 | e = torch.randn_like(z) 123 | 124 | # Generate Sample 125 | e_hat = am1 * z + a * v_theta 126 | x_hat = a * z - am1 * v_theta 127 | sample = a_next * x_hat + am1_next ** 2 * a / am1 / a_next * e_hat + ( 128 | 1 - a ** 2 / a_next ** 2) ** 0.5 * am1_next / am1 * e 129 | 130 | return sample 131 | 132 | 133 | def p_sample_loop_w_Condition(model, shape, n_steps, alphas, betas, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, c, 134 | pred_type='epsilon', is_light = False): 135 | device = model.this_device 136 | cur_x = torch.randn(shape).to(device) 137 | c = c.to(device) 138 | alphas = alphas.to(device) 139 | betas = betas.to(device) 140 | alphas_bar_sqrt = alphas_bar_sqrt.to(device) 141 | one_minus_alphas_bar_sqrt = one_minus_alphas_bar_sqrt.to(device) 142 | 143 | if pred_type == 'epsilon': 144 | p_sample_func = 
p_sample_w_Condition
145 |     elif pred_type == 'v':
146 |         p_sample_func = p_sample_w_Condition_v
147 |     if not is_light:
148 |         x_seq = [cur_x]
149 |     for i in reversed(range(n_steps)):
150 |         cur_x = p_sample_func(model, cur_x, i, alphas, betas, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, c)
151 |         if not is_light:
152 |             x_seq.append(cur_x)
153 |     if not is_light:
154 |         return x_seq
155 |     else: return cur_x
156 | 
157 | def p_sample_w_Condition_DDIM(model, xt, i, j, alphas_prod, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, shape, device, c):
158 |     at = extract(alphas_prod, torch.tensor([i]).to(device), shape).to(device)
159 |     if j > -1:
160 |         at_next = extract(alphas_prod, torch.tensor([j]), shape).to(device)
161 |     else:
162 |         at_next = torch.ones(shape).to(device) # a_0 = 1
163 | 
164 |     et = model(xt, torch.tensor([i]).to(device), c)
165 |     xt_next = at_next.sqrt() * (xt - et * (1 - at).sqrt()) / at.sqrt() + (1 - at_next).sqrt() * et
166 |     return xt_next
167 | 
168 | def p_sample_w_Condition_DDIM_v(model, xt, i, j, alphas_prod, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, shape, device, c):
169 |     a = extract(alphas_bar_sqrt, torch.tensor([i]).to(device), shape).to(device)
170 |     am1 = extract(one_minus_alphas_bar_sqrt, torch.tensor([i]).to(device), shape).to(device)
171 |     if j > -1:
172 |         a_next = extract(alphas_bar_sqrt, torch.tensor([j]), shape).to(device)
173 |         am1_next = extract(one_minus_alphas_bar_sqrt, torch.tensor([j]), shape).to(device)
174 |     else:
175 |         a_next = torch.ones(shape).to(device) # a_0 = 1
176 |         am1_next = torch.zeros(shape).to(device) # sqrt(1 - a_0) = 0
177 |     z = xt.to(device) # current noisy sample x_t
178 |     v_theta = model(z, torch.tensor([i]).to(device), c)
179 | 
180 |     # Generate Sample
181 |     x_hat = a * z - am1 * v_theta
182 |     z_next = a_next * x_hat + am1_next * (z - a * x_hat) / am1
183 |     return z_next
184 | def p_sample_loop_w_Condition_DDIM(model, shape, traj, alphas_prod, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, c,
185 |                                    pred_type='epsilon'):
186 |     device = model.this_device
187 |     cur_x = torch.randn(shape).to(device) # randomly sampled x_T
188 |     c = c.to(device)
189 |     alphas_prod = alphas_prod.to(device)
190 |     alphas_bar_sqrt = alphas_bar_sqrt.to(device)
191 |     one_minus_alphas_bar_sqrt = one_minus_alphas_bar_sqrt.to(device)
192 | 
193 | 
194 |     traj_next = [-1] + list(traj[:-1]) # t-1
195 | 
196 |     x_seq = [cur_x]
197 | 
198 |     if pred_type == 'epsilon':
199 |         p_sample_func_DDIM = p_sample_w_Condition_DDIM
200 |     elif pred_type == 'v':
201 |         p_sample_func_DDIM = p_sample_w_Condition_DDIM_v
202 | 
203 |     for i, j in zip(reversed(traj), reversed(traj_next)): # i = t, j = t-1
204 |         xt = x_seq[-1].to(device)
205 |         xt_next = p_sample_func_DDIM(model, xt, i, j, alphas_prod, alphas_bar_sqrt, one_minus_alphas_bar_sqrt,
206 |                                      shape, device, c)
207 |         x_seq.append(xt_next)
208 | 
209 |     return x_seq
210 | 
211 | def q_sample(x_0, t, alphas_bar_sqrt, one_minus_alphas_bar_sqrt ,noise=None):
212 |     shape = x_0.shape
213 |     if noise is None:
214 |         noise = torch.randn_like(x_0)
215 |     alphas_t = extract(alphas_bar_sqrt, t, shape)
216 |     alphas_1_m_t = extract(one_minus_alphas_bar_sqrt, t, shape)
217 |     return (alphas_t * x_0 + alphas_1_m_t * noise)
218 | 
219 | def noise_estimation_loss(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps, n):
220 |     device = model.this_device
221 |     shape = x_0.shape
222 |     batch_size = x_0.shape[0]
223 |     x = x_0[:, 0:n].to(device)
224 |     c = x_0[:, n:].to(device)
225 | 
226 |     # Select a random step for each example
227 |     t = torch.randint(0, n_steps, size=(batch_size,)).long().to(device)
228 |     # 
x0 multiplier 229 | a = extract(alphas_bar_sqrt, t, shape).to(device) 230 | # eps multiplier 231 | am1 = extract(one_minus_alphas_bar_sqrt, t, shape).to(device) 232 | e = torch.randn_like(x).to(device) 233 | # model input 234 | y = x * a + e * am1 235 | output = model(y, t, c) 236 | return (e - output).square().mean() 237 | 238 | 239 | def v_estimation_loss(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps, n, 240 | losstype='no_weighting'): # , prediction_var = 'x'): 241 | 242 | assert isinstance(losstype, str) 243 | # assert isinstance(prediction_var, str) 244 | 245 | device = model.this_device 246 | shape = x_0.shape 247 | batch_size = x_0.shape[0] 248 | x = x_0[:, 0:n].to(device) 249 | c = x_0[:, n:].to(device) 250 | 251 | # Select a random step for each example 252 | t = torch.randint(0, n_steps, size=(batch_size,)).long().to(device) 253 | # x0 multiplier 254 | a = extract(alphas_bar_sqrt, t, shape).to(device) 255 | # eps multiplier 256 | am1 = extract(one_minus_alphas_bar_sqrt, t, shape).to(device) 257 | e = torch.randn_like(x).to(device) 258 | 259 | # model input - the noisy sample 260 | y = x * a + e * am1 261 | output = model(y, t, c) 262 | 263 | # loss weighting 264 | if losstype == 'SNR': 265 | weight_t = a ** 2 / am1 ** 2 266 | elif losstype == 'SNRp1': 267 | weight_t = a ** 2 / am1 ** 2 + torch.ones_like(x) 268 | elif losstype == 'SNRtr': 269 | weight_t = a ** 2 / am1 ** 2 270 | weight_t[weight_t < 1] = 1 271 | elif losstype == 'no_weighting': 272 | weight_t = 1 273 | else: 274 | raise ValueError('Unrecognized weight type. Possible options are SNR, SNRp1, SNRtr, and no_weighting.') 275 | 276 | v = e * a - x * am1 277 | 278 | loss = (v - output).square() 279 | 280 | return (weight_t * loss).mean() # (v - output).square().mean() 281 | 282 | def EbNo_to_noise(ebnodb, rate): 283 | '''Transform EbNo[dB]/snr to noise power''' 284 | ebno = 10**(ebnodb/10) 285 | noise_std = 1/np.sqrt(2*rate*ebno) 286 | return noise_std 287 | 288 | def SNR_to_noise(snrdb): 289 | '''Transform EbNo[dB]/snr to noise power''' 290 | snr = 10**(snrdb/10) 291 | noise_std = 1/np.sqrt(2*snr) 292 | return noise_std 293 | 294 | def SER(input_msg, msg, erasure_bound=0.9): 295 | 296 | '''Calculate the Symbol Error Rate''' 297 | batch, _ = msg.size() 298 | 299 | # count erasures and give indices without erasures 300 | smax = torch.nn.Softmax(dim=1) 301 | out, _ = torch.topk(smax(msg), 1, axis=1) 302 | indices_erasures = (out <= erasure_bound).flatten()#.nonzero() 303 | #erasures = out.size()[0]-indices_non_erasures.size()[0] 304 | 305 | #erasures = torch.count_nonzero(out < erasure_bound).item() 306 | 307 | # cut erasures from prediction 308 | #msg_wo_er = torch.index_select(msg, 0, indices_non_erasures.flatten()) 309 | #input_msg_wo_er = torch.index_select(torch.flatten(input_msg), 0, indices_non_erasures.flatten()) 310 | 311 | #print("input_msg", input_msg) 312 | #print("msg", msg.argmax(dim=1)) 313 | 314 | input_msg_wo_er = torch.flatten(input_msg) 315 | msg_wo_er = msg 316 | 317 | pred_error = torch.ne(input_msg_wo_er, msg_wo_er.argmax(dim=1)) 318 | block_er = torch.sum(pred_error) 319 | ser = block_er/batch 320 | return ser, pred_error, indices_erasures 321 | 322 | def qam16_mapper_n(m): 323 | # m takes in a vector of messages 324 | x = np.linspace(-3,3,4) 325 | y = np.meshgrid(x,x) 326 | z = np.array(y).reshape(2,16) 327 | z = z / np.std(z) 328 | return np.array([[z[0][i],z[1][i]] for i in m]) 329 | 330 | class LoadModels: 331 | def __init__(self, path, Load_model_gen, model_gen, tag_load_gen, 
Load_model_AE, model_enc, model_dec, tag_load_AE):
332 |         self.Load_model_gen = Load_model_gen
333 |         self.model_gen = model_gen
334 |         self.dir_gen = path + '/channel_gen' + tag_load_gen
335 |         self.Load_model_AE = Load_model_AE
336 |         self.model_enc = model_enc
337 |         self.model_dec = model_dec
338 |         self.dir_enc = path + '/encoder' + tag_load_AE
339 |         self.dir_dec = path + '/decoder' + tag_load_AE
340 | 
341 |     def load(self):
342 |         if self.Load_model_gen:
343 |             self.model_gen.load_state_dict(torch.load(self.dir_gen))
344 |             self.model_gen.eval()
345 |         if self.Load_model_AE:
346 |             self.model_enc.load_state_dict(torch.load(self.dir_enc))
347 |             self.model_enc.eval()
348 |             self.model_dec.load_state_dict(torch.load(self.dir_dec))
349 |             self.model_dec.eval()
350 |         return self.model_gen, self.model_enc, self.model_dec
351 | 
352 | class SaveModels:
353 |     def __init__(self, path, Save_model_gen, model_gen, tag_save_gen, Save_model_AE, model_enc, model_dec, tag_save_AE):
354 |         self.Save_model_gen = Save_model_gen
355 |         self.model_gen = model_gen
356 |         self.dir_gen = path + '/channel_gen' + tag_save_gen
357 |         self.Save_model_AE = Save_model_AE
358 |         self.model_enc = model_enc
359 |         self.model_dec = model_dec
360 |         self.dir_enc = path + '/encoder' + tag_save_AE
361 |         self.dir_dec = path + '/decoder' + tag_save_AE
362 | 
363 |     def save_gen(self):
364 |         if self.Save_model_gen:
365 |             torch.save(self.model_gen.state_dict(), self.dir_gen)
366 |     def save_AE(self):
367 |         if self.Save_model_AE:
368 |             torch.save(self.model_enc.state_dict(), self.dir_enc)
369 |             torch.save(self.model_dec.state_dict(), self.dir_dec)
370 | 
371 | 
372 | 
373 | 
--------------------------------------------------------------------------------
/src/channel_models.py:
--------------------------------------------------------------------------------
1 | # The channel models are defined here.
2 | 
3 | import torch
4 | import numpy as np
5 | 
6 | def ch_AWGN(ch_input, AWGN_std, device):
7 |     ch_input = ch_input.to(device)
8 |     ch_output = ch_input + torch.normal(mean=0, std=AWGN_std, size = ch_input.size()).to(device)
9 |     return ch_output
10 | 
11 | 
12 | def ch_Rayleigh_AWGN(codeword, noise_std, dvc):
13 |     codeword = codeword.to(dvc)
14 |     i,j = codeword.size()
15 |     h_ray = (1/np.sqrt(2))*torch.sqrt( torch.randn((i,j), device = dvc)**2 + torch.randn((i,j), device = dvc)**2)
16 | 
17 |     channel_noise = noise_std*torch.randn((i,j), device = dvc)
18 | 
19 |     return h_ray* codeword + channel_noise
20 | 
21 | 
22 | def ch_SSPA(x, sigma_n, device, p = 3.0, A_0 = 1.5, v = 5.): # A_0 limiting output amplitude, v is small signal gain.
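    # Rapp SSPA model: each codeword is split into n/2 complex symbols, and the AM-AM curve
    # g(A) = v*A / (1 + (v*A/A_0)**(2*p))**(1/(2*p)) smoothly compresses amplitudes towards the
    # saturation level A_0 (no AM-PM distortion); AWGN is added after the nonlinearity.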
23 |     assert x.size(1) % 2 == 0
24 | 
25 |     # e.g. x = ([1 2 3 4], [5 6 7 8]): consecutive pairs of reals form complex symbols
26 |     dim = int(x.size(1) //2)
27 |     x_2d = x.reshape(-1,2) # x_2d = ([1 2], [3 4], [5 6], [7 8])
28 |     A = torch.sum(x_2d ** 2, dim=1) ** 0.5 # amplitude of each complex symbol
29 |     A_ratio = v / (1+ (v*A/A_0)**(2*p) )**(1/2/p) # Rapp AM-AM curve: g(A) / A
30 |     x_amp_2d = torch.mul(A_ratio.reshape(-1,1), x_2d)
31 |     x_amp = x_amp_2d.reshape(-1, 2*dim)
32 |     y = x_amp + (sigma_n/2**0.5) * torch.randn_like(x)
33 | 
34 |     return y
35 | 
36 | # Alias for the plotting utilities, which import the Rayleigh channel under this name.
37 | ch_Rayleigh_AWGN_n = ch_Rayleigh_AWGN
--------------------------------------------------------------------------------
/src/ema.py:
--------------------------------------------------------------------------------
1 | # Exponential Moving Average Class
2 | # Original source: https://github.com/acids-ircam/diffusion_models
3 | 
4 | 
5 | class EMA(object):
6 |     def __init__(self, mu=0.999):
7 |         self.mu = mu
8 |         self.shadow = {}
9 | 
10 |     def register(self, module):
11 |         for name, param in module.named_parameters():
12 |             if param.requires_grad:
13 |                 self.shadow[name] = param.data.clone()
14 | 
15 |     def update(self, module):
16 |         for name, param in module.named_parameters():
17 |             if param.requires_grad:
18 |                 self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data
19 | 
20 |     def ema(self, module):
21 |         for name, param in module.named_parameters():
22 |             if param.requires_grad:
23 |                 param.data.copy_(self.shadow[name].data)
24 | 
25 |     def ema_copy(self, module):
26 |         module_copy = type(module)(module.config).to(module.config.device)
27 |         module_copy.load_state_dict(module.state_dict())
28 |         self.ema(module_copy)
29 |         return module_copy
30 | 
31 |     def state_dict(self):
32 |         return self.shadow
33 | 
34 |     def load_state_dict(self, state_dict):
35 |         self.shadow = state_dict
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | # An implementation of the diffusion network model
2 | # Original source: https://github.com/acids-ircam/diffusion_models
3 | 
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch
7 | import numpy as np
8 | from channel_models import *
9 | 
10 | class ConditionalLinear(nn.Module):
11 |     def __init__(self, num_in, num_out, n_steps):
12 |         super(ConditionalLinear, self).__init__()
13 |         use_cuda = torch.cuda.is_available()
14 |         self.this_device = torch.device("cuda" if use_cuda else "cpu")
15 |         self.num_out = num_out
16 |         self.lin = nn.Linear(num_in, num_out, device =self.this_device )
17 |         self.embed = nn.Embedding(n_steps, num_out, device = self.this_device) # one embedding vector per diffusion step: each step gets its own feature-wise gain (gamma).
18 |         self.embed.weight.data.uniform_()
19 | 
20 |     def forward(self, x, y):
21 |         x = x.type(torch.FloatTensor).to(self.this_device)
22 |         out = self.lin(x).to(self.this_device)
23 |         gamma = self.embed(y).to(self.this_device)
24 |         out = gamma.view(-1, self.num_out).to(self.this_device) * out # multiplicative (FiLM-style) step conditioning: gamma rescales the hidden features.
25 | return out 26 | 27 | class ConditionalModel(nn.Module): 28 | def __init__(self, n_steps): 29 | super(ConditionalModel, self).__init__() 30 | use_cuda = torch.cuda.is_available() 31 | self.this_device = torch.device("cuda" if use_cuda else "cpu") 32 | self.lin1 = ConditionalLinear(2, 128, n_steps) 33 | self.lin2 = ConditionalLinear(128, 128, n_steps) 34 | self.lin3 = ConditionalLinear(128, 128, n_steps) 35 | self.lin4 = nn.Linear(128, 2, device =self.this_device) 36 | 37 | def forward(self, x, y): 38 | x = x.type(torch.FloatTensor).to(self.this_device) 39 | x = F.softplus(self.lin1(x, y)) 40 | x = F.softplus(self.lin2(x, y)) 41 | x = F.softplus(self.lin3(x, y)) 42 | return self.lin4(x) 43 | 44 | 45 | class ConditionalLinear_w_Condition(nn.Module): 46 | def __init__(self, num_in, num_out, n_steps): 47 | super(ConditionalLinear_w_Condition, self).__init__() 48 | use_cuda = torch.cuda.is_available() 49 | self.this_device = torch.device("cuda" if use_cuda else "cpu") 50 | self.num_out = num_out 51 | self.lin = nn.Linear(num_in, num_out, device = self.this_device) 52 | self.embed = nn.Embedding(n_steps, num_out, device = self.this_device) # for each step, the distribution is distinct. Given the number of steps, the embedding returns a gamma for it. 53 | self.embed.weight.data.uniform_() 54 | 55 | def forward(self, x, y, c): 56 | x = x.type(torch.FloatTensor).to(self.this_device) 57 | c = c.type(torch.FloatTensor).to(self.this_device) 58 | out = self.lin(torch.cat((x,c), dim=1)).to(self.this_device) 59 | gamma = self.embed(y).to(self.this_device) 60 | out = gamma.view(-1, self.num_out).to(self.this_device) * out # Why is it multiplicative? not additive? 61 | return out 62 | 63 | 64 | class ConditionalModel_w_Condition(nn.Module): 65 | def __init__(self, n_steps, M, n, N): 66 | super(ConditionalModel_w_Condition, self).__init__() 67 | use_cuda = torch.cuda.is_available() 68 | self.this_device = torch.device("cuda" if use_cuda else "cpu") 69 | self.lin1 = ConditionalLinear_w_Condition(2*n+M, N, n_steps) # hidden layer 128 70 | self.lin2 = ConditionalLinear(N, N, n_steps) 71 | self.lin3 = ConditionalLinear(N, N, n_steps) 72 | self.lin4 = nn.Linear(N, n, device = self.this_device) 73 | 74 | def forward(self, x, y, c): 75 | x = x.type(torch.FloatTensor).to(self.this_device) 76 | c = c.type(torch.FloatTensor).to(self.this_device) 77 | 78 | x = F.softplus(self.lin1(x, y, c)) 79 | x = F.softplus(self.lin2(x, y)) 80 | x = F.softplus(self.lin3(x, y)) 81 | return self.lin4(x) 82 | 83 | 84 | 85 | # Binary Input Encoders/decoders: 86 | class Encoder(nn.Module): 87 | def __init__(self, k, M, n, **extra_kwargs): 88 | super(Encoder, self).__init__() 89 | 90 | use_cuda = torch.cuda.is_available() 91 | self.this_device = torch.device("cuda" if use_cuda else "cpu") 92 | 93 | self._f = nn.Sequential( 94 | nn.Linear(k, M), 95 | nn.ELU(), 96 | nn.Linear(M,M), 97 | nn.ELU(), 98 | nn.Linear(M, n)#,norm_layer 99 | ) 100 | 101 | def power_constraint(self, codes): 102 | codes_mean = torch.mean(codes) 103 | codes_std = torch.std(codes) 104 | codes_norm = (codes-codes_mean)/ codes_std 105 | 106 | return codes_norm 107 | 108 | def forward(self, inputs): 109 | inputs = inputs.type(torch.FloatTensor).to(self.this_device) 110 | _x = self._f(inputs) 111 | x = self.power_constraint(_x) 112 | return x 113 | 114 | 115 | class Decoder(nn.Module): 116 | def __init__(self,k, M ,n, **extra_kwargs): 117 | super(Decoder, self).__init__() 118 | self._f = nn.Sequential( 119 | nn.Linear(n, M), 120 | nn.ELU(), 121 | nn.Linear(M, M), 122 
|         nn.ELU(),
123 |         nn.Linear(M, k),
124 |         nn.Sigmoid()
125 |         )
126 | 
127 |     def forward(self, inputs):
128 |         x = self._f(inputs)
129 |         return x
130 | 
131 | # one-hot encoders/decoders:
132 | 
133 | class Encoder_1h(nn.Module): # power constraint
134 |     def __init__(self, M, n, P_in, **extra_kwargs):
135 |         super(Encoder_1h, self).__init__()
136 | 
137 |         use_cuda = torch.cuda.is_available()
138 |         self.this_device = torch.device("cuda" if use_cuda else "cpu")
139 |         self.P_in = P_in
140 |         self.M = M
141 |         self._f = nn.Sequential(
142 |             nn.Linear(M, M),
143 |             nn.ELU(),
144 |             nn.Linear(M,M),
145 |             nn.ELU(),
146 |             nn.Linear(M, n)#,norm_layer
147 |         )
148 |     def power(self):
149 |         inputs = torch.eye(self.M).to(self.this_device)
150 |         codes = self._f(inputs)
151 |         P = torch.mean(torch.sum(codes**2, dim=1))
152 |         return P
153 | 
154 |     def power_constraint(self, codes):
155 |         P = self.power() #torch.mean(torch.sum(codes**2, dim=1))
156 |         codes_norm = codes * np.sqrt(self.P_in) / torch.sqrt(P)
157 |         return codes_norm
158 | 
159 |     def forward(self, inputs):
160 |         inputs = inputs.to(self.this_device)
161 |         _x = self._f(inputs)
162 |         x = self.power_constraint(_x)
163 |         return x
164 | 
165 | class Decoder_1h(nn.Module):
166 |     def __init__(self, M ,n, **extra_kwargs):
167 |         super(Decoder_1h, self).__init__()
168 |         self._f = nn.Sequential(
169 |             #nn.Linear(n, n),
170 |             #nn.ELU(),
171 |             nn.Linear(n, M),
172 |             nn.ELU(),
173 |             nn.Linear(M, M),
174 |             nn.ELU(),
175 |             nn.Linear(M, M),
176 |         )
177 | 
178 |     def forward(self, inputs):
179 |         x = self._f(inputs)
180 |         return x
181 | 
182 | 
183 | 
184 | 
185 | class Channel(torch.nn.Module):
186 |     def __init__(self, enc, dec):
187 |         super(Channel, self).__init__()
188 |         use_cuda = torch.cuda.is_available()
189 |         self.device = torch.device("cuda" if use_cuda else "cpu")
190 |         self.enc = enc
191 |         self.dec = dec
192 | 
193 |     def forward(self, inputs, noise_std):
194 |         # End-to-End Framework
195 |         codeword = self.enc(inputs)
196 | 
197 |         i,j = codeword.size()
198 |         channel_noise = noise_std*torch.randn((i, j)).to(self.device)
199 | 
200 |         rec_signal = codeword + channel_noise
201 |         dec_signal = self.dec(rec_signal)
202 | 
203 |         return dec_signal
204 | 
205 | 
206 | 
207 | class Channel_ray(torch.nn.Module):
208 |     def __init__(self, enc, dec):
209 |         super(Channel_ray, self).__init__()
210 |         use_cuda = torch.cuda.is_available()
211 |         self.device = torch.device("cuda" if use_cuda else "cpu")
212 |         self.enc = enc
213 |         self.dec = dec
214 | 
215 |     def forward(self, inputs, noise_std):
216 |         # End-to-End Framework
217 |         codeword = self.enc(inputs)
218 | 
219 |         i,j = codeword.size()
220 | 
221 |         h_ray = (1/np.sqrt(2))*torch.sqrt( torch.randn(size = codeword.size(), device = self.device)**2 +torch.randn(size = codeword.size(), device = self.device)**2)
222 |         channel_noise = noise_std*torch.randn((i, j), device = self.device)
223 | 
224 |         rec_signal = codeword*h_ray + channel_noise
225 |         dec_signal = self.dec(rec_signal)
226 | 
227 |         return dec_signal
228 | 
229 | 
230 | class Channel_SSPA(torch.nn.Module):
231 |     def __init__(self, enc, dec):
232 |         super(Channel_SSPA, self).__init__()
233 |         use_cuda = torch.cuda.is_available()
234 |         self.device = torch.device("cuda" if use_cuda else "cpu")
235 |         self.enc = enc
236 |         self.dec = dec
237 |     def forward(self, inputs, noise_std):
238 |         codeword = self.enc(inputs)
239 | 
240 |         rec_signal = ch_SSPA(codeword, noise_std, self.device)
241 |         dec_signal = self.dec(rec_signal)
242 | 
243 |         return dec_signal
244 | 
245 | 
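# Usage sketch (illustrative only): how the one-hot autoencoder blocks above compose with the
# SSPA channel wrapper. The values of M, n, P_in, noise_std and the batch size below are
# placeholders chosen for demonstration, not the settings used in the paper.
if __name__ == "__main__":
    M, n, P_in, noise_std = 16, 8, 8.0, 0.5        # n must be even for ch_SSPA
    enc = Encoder_1h(M, n, P_in)
    dec = Decoder_1h(M, n).to(enc.this_device)     # keep both parts on the same device
    ae = Channel_SSPA(enc, dec)

    labels = torch.randint(M, (32,))               # a batch of random messages
    m_1h = F.one_hot(labels, num_classes=M).float()
    logits = ae(m_1h, noise_std)                   # (batch, M) unnormalized class scores
    ser = (logits.argmax(dim=1).cpu() != labels).float().mean()
    print(f"untrained SER ~ {ser:.2f}")            # close to (M-1)/M before any training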
-------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | 4 | from tqdm import tqdm 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.optim.lr_scheduler import LambdaLR 11 | from torch.utils.data.dataloader import DataLoader 12 | from torch.nn import functional as F 13 | import utils 14 | 15 | 16 | class TrainerConfig: 17 | # optimization parameters 18 | max_epochs = 10 19 | batch_size = 500 20 | M=16 21 | learning_rate = 1e-3 22 | rate = 4/7 23 | num_workers = 0 # for DataLoader 24 | 25 | 26 | def __init__(self, **kwargs): 27 | for k,v in kwargs.items(): 28 | setattr(self, k, v) 29 | 30 | class Trainer: 31 | def __init__(self, model, train_dataset, config, one_hot): 32 | self.model = model 33 | self.train_dataset = train_dataset 34 | #self.test_dataset = test_dataset 35 | self.config = config 36 | self.one_hot = one_hot 37 | 38 | self.device = 'cpu' 39 | if torch.cuda.is_available(): 40 | self.device = torch.cuda.current_device() 41 | self.model = torch.nn.DataParallel(self.model).to(self.device) 42 | 43 | def train(self): 44 | model, config = self.model, self.config 45 | optimizer = torch.optim.NAdam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate) 46 | 47 | loss_CE = nn.CrossEntropyLoss() 48 | 49 | def run_epoch(): 50 | lr = config.learning_rate 51 | batch_size = config.batch_size 52 | M = config.M 53 | data = self.train_dataset 54 | loader = DataLoader(data, shuffle=True, pin_memory=False, batch_size=batch_size, 55 | num_workers=config.num_workers) 56 | losses, batch_BER, batch_SER = [], [], [] 57 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 58 | 59 | for it, x in pbar: 60 | labels = x.reshape((batch_size,1)).to(self.device) 61 | x = F.one_hot(labels.long(), num_classes=M) 62 | x = x.float().reshape(batch_size,M) 63 | 64 | optimizer.zero_grad() 65 | output = model(x, config.noise_std) 66 | if self.one_hot: 67 | loss = loss_CE(output, labels.long().squeeze(1)) 68 | else: 69 | loss = F.binary_cross_entropy(output, x) 70 | loss = loss.mean() 71 | loss.backward() 72 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.) 
73 | losses.append(loss.item()) 74 | optimizer.step() 75 | 76 | if self.one_hot: 77 | SER, _, _ = utils.SER(labels.long(), output, 0.9) 78 | batch_SER.append(SER.item()) 79 | if it%100==0: 80 | pbar.set_description(f"epoch {epoch+1}: loss {np.mean(losses):.2e} SER {np.mean(batch_SER):.2e}") 81 | else: 82 | batch_BER.append(utils.B_Ber_m(x, output).item()) 83 | pbar.set_description(f"epoch {epoch+1}: loss {np.mean(losses):.2e} BER {np.mean(batch_BER):.2e}") 84 | 85 | 86 | for epoch in range(config.max_epochs): 87 | run_epoch() 88 | 89 | def test(self, snr_range, one_hot, erasure_bound): 90 | model, config = self.model, self.config 91 | ser = [] 92 | 93 | def run_epoch(): 94 | data = self.train_dataset 95 | batch_size = config.batch_size 96 | M = config.M 97 | loader = DataLoader(data, shuffle=True, pin_memory=False, batch_size=batch_size, 98 | num_workers=config.num_workers) 99 | batch_BER = [] 100 | block_ER = [] 101 | erasures = [] 102 | batch_SER = [] 103 | 104 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 105 | 106 | for it, x in pbar: 107 | if self.one_hot: 108 | labels = x.reshape((batch_size,1)).to(self.device) 109 | x=F.one_hot(labels.long(), num_classes=M) 110 | x = x.float().reshape(batch_size,M) 111 | else: 112 | x = x.float().to(self.device) 113 | noise_std_ = utils.EbNo_to_noise(snr_range[epoch], config.rate) 114 | #noise_std_ = utils.SNR_to_noise(snr_range[epoch]) 115 | 116 | output = model(x, noise_std_) 117 | 118 | if self.one_hot: 119 | SER, _, _ = utils.SER(labels.long(), output, erasure_bound) 120 | batch_SER.append(SER.item()) 121 | #block_ER.append(bler) 122 | #erasures.append(Erasures) 123 | if it%100==0: 124 | pbar.set_description(f"SNR {snr_range[epoch]} iter {it}: SER {np.mean(batch_SER):.3e}") 125 | else: 126 | batch_BER.append(utils.B_Ber_m(x.float(), output).item()) 127 | block_ER.append(utils.Block_ER(x.float(), output)) 128 | pbar.set_description(f"SNR {snr_range[epoch]} iter {it}: avg. BER {np.mean(batch_BER):.8f} BLER {np.mean(block_ER):.7e}") 129 | 130 | return np.mean(batch_SER) 131 | 132 | num_samples = self.train_dataset.size(dim=0) 133 | for epoch in range(len(snr_range)): 134 | 135 | temp1 = 0 136 | it_in_while = 0 137 | 138 | while temp1 * num_samples *(it_in_while) < 1000: # To guarantee we have enough samples for the Monte-Carlo method. 139 | if it_in_while > 1 : 140 | print("The number of samples is not enough. One more epoch is running. Total # of samples used: " , num_samples * it_in_while) 141 | temp2 = run_epoch() 142 | temp1 = (it_in_while * temp1 + temp2)/(it_in_while+1) # taking average of the error probability. 
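                # Run additional epochs until roughly 1000 symbol errors have been observed in total
                # (temp1 is the running SER estimate), so the Monte-Carlo estimate at this Eb/N0 is reliable.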
143 |                 it_in_while += 1
144 |             ser.append(temp1)
145 | 
146 |         return ser
147 | 
148 | 
149 | 
150 | class TrainerConfig_DDM:
151 |     # optimization parameters
152 |     max_epochs = 10
153 |     dataset_size = 1000000
154 |     batch_size = 500
155 |     noise_std = 1
156 |     learning_rate = 1e-3
157 |     M=16
158 |     n=7
159 |     num_steps = 50
160 |     betas = torch.zeros([num_steps])
161 |     rate = 4/7
162 |     num_workers = 0 # for DataLoader
163 |     optim_gen = optim.Adam
164 |     # TRANSFER_LEARNING= False
165 |     IS_RES = False
166 |     pred_type = 'epsilon'
167 | 
168 | 
169 |     def __init__(self, **kwargs):
170 |         for k,v in kwargs.items():
171 |             setattr(self, k, v)
172 |         self.alphas = 1 - self.betas.to('cpu')
173 |         self.alphas_prod = torch.cumprod(self.alphas, 0)
174 |         self.alphas_prod_p = torch.cat([torch.tensor([1]), self.alphas_prod[:-1]], 0)
175 |         self.alphas_bar_sqrt = torch.sqrt(self.alphas_prod)
176 |         #one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
177 |         self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - self.alphas_prod)
178 | 
179 | class Trainer_DDM:
180 |     def __init__(self, channel_gen, ema, device, channel_model, tconf_gen):
181 |         self.channel_gen = channel_gen
182 |         self.ema = ema
183 |         self.device = device
184 |         self.channel_model = channel_model
185 |         self.tconf_gen = tconf_gen
186 | 
187 |     def train_PreT(self): # pretraining algorithm
188 |         channel_gen = self.channel_gen
189 |         ema = self.ema
190 |         device = self.device
191 |         channel_model = self.channel_model
192 |         tconf_gen = self.tconf_gen
193 |         n = tconf_gen.n
194 | 
195 |         IS_RES = tconf_gen.IS_RES
196 | 
197 |         optimizer_ch = tconf_gen.optim_gen(channel_gen.parameters(), lr=tconf_gen.learning_rate)
198 |         dataset_gen = torch.randn(tconf_gen.dataset_size, n)
199 |         alphas_bar_sqrt = tconf_gen.alphas_bar_sqrt.to(device)
200 |         one_minus_alphas_bar_sqrt = tconf_gen.one_minus_alphas_bar_sqrt.to(device)
201 | 
202 |         if tconf_gen.pred_type == 'epsilon':
203 |             estimation_loss = utils.noise_estimation_loss
204 |         elif tconf_gen.pred_type == 'v':
205 |             estimation_loss = utils.v_estimation_loss
206 |         else:
207 |             raise ValueError('Unrecognized pred_type. Available options: epsilon, v.')
208 | 
209 | 
210 |         loss_ep_gen = []
211 |         for t in range(tconf_gen.max_epochs):
212 |             loader = DataLoader(dataset_gen, shuffle=True, pin_memory=False, batch_size=tconf_gen.batch_size)
213 |             loss_batch, batch_BER, batch_SER = [], [], []
214 |             pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5)
215 | 
216 |             for it, x in pbar:
217 |                 x = x.to(device)
218 |                 y = channel_model(x, tconf_gen.noise_std, device)
219 |                 if IS_RES:
220 |                     y = y - x
221 | 
222 |                 yx = torch.cat((y , x),dim=1)
223 |                 # Diffusion loss: predict the noise (or v) for channel outputs y conditioned on inputs x
224 |                 loss = estimation_loss(channel_gen, yx, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, tconf_gen.num_steps, n)
225 | 
226 |                 optimizer_ch.zero_grad()
227 |                 loss.backward()
228 |                 torch.nn.utils.clip_grad_norm_(channel_gen.parameters(), 1.)
229 |                 optimizer_ch.step()
230 |                 # Update the exponential moving average
231 |                 ema.update(channel_gen)
232 | 
233 |                 loss_batch.append(loss.item())
234 | 
235 |                 if it % 100 == 0:
236 |                     pbar.set_description(f"epoch {t+1}: loss {np.mean(loss_batch):.2e}")
237 | 
238 |             loss_ep_gen.append(np.mean(loss_batch))
239 |         return loss_ep_gen
240 | 
241 |     def train_itr(self, encoder): # iterative training
242 | 
243 |         channel_gen = self.channel_gen
244 |         ema = self.ema
245 |         device = self.device
246 |         channel_model = self.channel_model
247 |         tconf_gen = self.tconf_gen
248 |         M = tconf_gen.M
249 |         n = tconf_gen.n
250 | 
251 |         optimizer_ch = tconf_gen.optim_gen(channel_gen.parameters(), lr=tconf_gen.learning_rate)
252 | 
253 |         dataset_gen = np.random.randint(M, size=tconf_gen.dataset_size)
254 | 
255 |         alphas_bar_sqrt = tconf_gen.alphas_bar_sqrt.to(device)
256 |         one_minus_alphas_bar_sqrt = tconf_gen.one_minus_alphas_bar_sqrt.to(device)
257 | 
258 |         if tconf_gen.pred_type == 'epsilon':
259 |             estimation_loss = utils.noise_estimation_loss
260 |         elif tconf_gen.pred_type == 'v':
261 |             estimation_loss = utils.v_estimation_loss
262 |         else:
263 |             raise ValueError('Unrecognized pred_type. Available options: epsilon, v.')
264 | 
265 |         loss_ep_gen = []
266 |         for t in range(tconf_gen.max_epochs):
267 |             loader = DataLoader(dataset_gen, shuffle=True, pin_memory=False, batch_size=tconf_gen.batch_size)
268 | 
269 |             loss_batch, batch_BER, batch_SER = [], [], []
270 |             pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5)
271 | 
272 |             for it, m in pbar:
273 |                 labels = m.reshape((-1,1))
274 |                 m = F.one_hot(labels.long(), num_classes=M)
275 |                 m = m.float().reshape(-1,M).to(device)
276 | 
277 |                 x = encoder(m)
278 |                 y = channel_model(x, tconf_gen.noise_std, device)
279 |                 yx = torch.cat((y , x ),dim=1)
280 | 
281 |                 # Diffusion loss on (y, x) pairs drawn with the current encoder's constellation
282 |                 loss = estimation_loss(channel_gen, yx, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, tconf_gen.num_steps, n)
283 | 
284 |                 optimizer_ch.zero_grad()
285 |                 loss.backward()
286 |                 torch.nn.utils.clip_grad_norm_(channel_gen.parameters(), 1.)
287 | optimizer_ch.step() 288 | 289 | # Update the exponential moving average 290 | ema.update(channel_gen) 291 | 292 | loss_batch.append(loss.item()) 293 | 294 | if it % 100 == 0: 295 | pbar.set_description(f"epoch {t+1}: loss {np.mean(loss.item()):.2e}") 296 | 297 | loss_ep_gen.append(np.mean(loss_batch)) 298 | return loss_ep_gen 299 | 300 | 301 | class TrainerConfig_AE_w_DDM: 302 | # optimization parameters 303 | max_epochs = 10 304 | dataset_size = 1000000 305 | batch_size = 500 306 | noise_std = 1 307 | learning_rate = 1e-3 308 | M=16 309 | n=7 310 | rate = 4/7 311 | num_steps = 50 312 | #alphas = torch.zeros([num_steps]) 313 | betas = torch.zeros([num_steps]) 314 | denoising_alg = 'DDPM' 315 | pred_type = 'epsilon' 316 | traj = range(num_steps) 317 | channel_model = 'None' 318 | 319 | optim_AE = optim.NAdam 320 | num_workers = 0 # for DataLoader 321 | 322 | def __init__(self, **kwargs): 323 | for k,v in kwargs.items(): 324 | setattr(self, k, v) 325 | self.alphas = 1 - self.betas.to('cpu') 326 | self.alphas_prod = torch.cumprod(self.alphas, 0) 327 | self.alphas_prod_p = torch.cat([torch.tensor([1]), self.alphas_prod[:-1]], 0) 328 | self.alphas_bar_sqrt = torch.sqrt(self.alphas_prod) 329 | #one_minus_alphas_bar_log = torch.log(1 - alphas_prod) 330 | self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - self.alphas_prod) 331 | 332 | class Trainer_AE_w_DDM: 333 | def __init__(self, encoder, decoder, channel_gen, device, tconf_AE): 334 | self.encoder = encoder 335 | self.decoder = decoder 336 | self.channel_gen = channel_gen 337 | self.device = device 338 | self.tconf_AE = tconf_AE 339 | 340 | def train_PreT(self):#(encoder, decoder, channel_gen, device, tconf_AE): 341 | encoder = self.encoder 342 | decoder = self.decoder 343 | channel_gen = self.channel_gen 344 | device = self.device 345 | tconf_AE = self.tconf_AE 346 | 347 | 348 | M = tconf_AE.M 349 | n = tconf_AE.n 350 | 351 | denoising_alg = tconf_AE.denoising_alg 352 | one_minus_alphas_bar_sqrt = tconf_AE.one_minus_alphas_bar_sqrt.to(device) 353 | betas = tconf_AE.betas.to(device) 354 | alphas = tconf_AE.alphas.to(device) 355 | 356 | dataset_AE = np.random.randint(M, size=tconf_AE.dataset_size) 357 | loss_CE = nn.CrossEntropyLoss() 358 | optimizer_AE = tconf_AE.optim_AE(list(encoder.parameters()) + list(decoder.parameters()), lr = tconf_AE.learning_rate) 359 | 360 | loss_ep_AE = [] 361 | 362 | for t in range(tconf_AE.max_epochs): 363 | loader = DataLoader(dataset_AE, shuffle=True, pin_memory=False, batch_size=tconf_AE.batch_size)#, num_workers=config.num_workers) 364 | loss_batch, batch_BER, batch_SER = [], [], [] 365 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 366 | 367 | for it, m in pbar: #enumerate(loader): 368 | labels = m.reshape((-1,1)).to(device) 369 | optimizer_AE.zero_grad() 370 | 371 | # Forward Pass: 372 | # Encoding 373 | m = F.one_hot(labels.long(), num_classes=M) 374 | m = m.float().reshape(-1,M) 375 | x = encoder(m) 376 | c = x 377 | 378 | # Channel 379 | if denoising_alg == 'DDPM': 380 | x_seq = utils.p_sample_loop_w_Condition(channel_gen, x.size(), tconf_AE.num_steps, 381 | tconf_AE.alphas, tconf_AE.betas, 382 | tconf_AE.alphas_bar_sqrt, 383 | tconf_AE.one_minus_alphas_bar_sqrt, c, 384 | pred_type=tconf_AE.pred_type) 385 | elif denoising_alg == 'DDIM': 386 | x_seq = utils.p_sample_loop_w_Condition_DDIM(channel_gen, x.size(), tconf_AE.traj, 387 | tconf_AE.alphas_prod, tconf_AE.alphas_bar_sqrt, 388 | tconf_AE.one_minus_alphas_bar_sqrt, c, 389 | pred_type=tconf_AE.pred_type) 390 | y = x_seq[-1] 391 | 
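                # y = x_seq[-1] above is the final denoised sample from the diffusion model; it acts as a
                # differentiable surrogate channel output, so the loss below can backpropagate into the encoder.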
392 | # Decoding 393 | m_d = decoder(y) 394 | 395 | # Backward pass: 396 | # Compute the loss and the gradient 397 | loss = loss_CE(m_d, labels.long().squeeze(1)) 398 | loss = loss.mean() 399 | loss.backward() 400 | 401 | torch.nn.utils.clip_grad_norm_(list(encoder.parameters())+list(decoder.parameters()), 1.) 402 | optimizer_AE.step() 403 | 404 | 405 | loss_batch.append(loss.item()) 406 | 407 | SER_tmp, _, _ = utils.SER(labels.long(), m_d, 0.9) 408 | batch_SER.append(SER_tmp.item()) 409 | 410 | if it%100==0: 411 | pbar.set_description(f"epoch {t+1}: loss {np.mean(loss_batch):.2e} SER {np.mean(batch_SER):.2e}") 412 | 413 | # Save epoch loss 414 | loss_ep_AE.append(np.mean(loss_batch)) 415 | return loss_ep_AE 416 | 417 | 418 | 419 | def train_PreT_enc(self): # (encoder, decoder, channel_gen, device, tconf_AE): 420 | encoder = self.encoder 421 | decoder = self.decoder 422 | channel_gen = self.channel_gen 423 | device = self.device 424 | tconf_AE = self.tconf_AE 425 | 426 | M = tconf_AE.M 427 | n = tconf_AE.n 428 | 429 | denoising_alg = tconf_AE.denoising_alg 430 | one_minus_alphas_bar_sqrt = tconf_AE.one_minus_alphas_bar_sqrt.to(device) 431 | betas = tconf_AE.betas.to(device) 432 | alphas = tconf_AE.alphas.to(device) 433 | 434 | dataset_AE = np.random.randint(M, size=tconf_AE.dataset_size) 435 | loss_CE = nn.CrossEntropyLoss() 436 | optimizer_AE = tconf_AE.optim_AE(encoder.parameters(), lr=tconf_AE.learning_rate) 437 | 438 | loss_ep_AE = [] 439 | SER_ep = [] 440 | 441 | for t in range(tconf_AE.max_epochs): 442 | loader = DataLoader(dataset_AE, shuffle=True, pin_memory=False, batch_size=tconf_AE.batch_size) 443 | loss_batch, batch_BER, batch_SER = [], [], [] 444 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 445 | 446 | for it, m in pbar: 447 | labels = m.reshape((-1, 1)).to(device) 448 | optimizer_AE.zero_grad() 449 | 450 | # Forward Pass: 451 | # Encoding 452 | m = F.one_hot(labels.long(), num_classes=M) 453 | m = m.float().reshape(-1, M) 454 | x = encoder(m) 455 | c = x 456 | 457 | # Channel 458 | 459 | if denoising_alg == 'DDPM': 460 | x_seq = utils.p_sample_loop_w_Condition(channel_gen, x.size(), tconf_AE.num_steps 461 | , tconf_AE.alphas, tconf_AE.betas, 462 | tconf_AE.alphas_bar_sqrt, 463 | tconf_AE.one_minus_alphas_bar_sqrt, c, 464 | pred_type=tconf_AE.pred_type) 465 | y = x_seq[-1] 466 | elif denoising_alg == 'DDIM': 467 | x_seq = utils.p_sample_loop_w_Condition_DDIM(channel_gen, x.size(), tconf_AE.traj, 468 | tconf_AE.alphas_prod, tconf_AE.alphas_bar_sqrt, 469 | tconf_AE.one_minus_alphas_bar_sqrt, c, 470 | pred_type=tconf_AE.pred_type) 471 | y = x_seq[-1] 472 | 473 | # Decoding 474 | m_d = decoder(y) 475 | 476 | # Backward pass: 477 | # Compute the loss and the gradient 478 | loss = loss_CE(m_d, labels.long().squeeze(1)) 479 | loss = loss.mean() 480 | loss.backward() 481 | 482 | torch.nn.utils.clip_grad_norm_(encoder.parameters(), 1.) 
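                # Note: optimizer_AE here wraps only encoder.parameters(), so this routine updates the
                # encoder alone; the decoder is refined separately on the true channel via train_dec().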
483 | optimizer_AE.step() 484 | 485 | loss_batch.append(loss.item()) 486 | 487 | SER_tmp, _, _ = utils.SER(labels.long(), m_d, 0.9) 488 | batch_SER.append(SER_tmp.item()) 489 | 490 | if it % 100 == 0: 491 | pbar.set_description(f"loss {np.mean(loss_batch):.2e} SER {np.mean(batch_SER):.2e}") 492 | 493 | # Save epoch loss 494 | loss_ep_AE.append(np.mean(loss_batch)) 495 | SER_ep.append(np.mean(batch_SER)) 496 | return loss_ep_AE, SER_ep 497 | 498 | def train_dec(self): 499 | encoder = self.encoder 500 | decoder = self.decoder 501 | device = self.device 502 | tconf_AE = self.tconf_AE 503 | 504 | 505 | M = tconf_AE.M 506 | n = tconf_AE.n 507 | 508 | dataset_AE = np.random.randint(M, size=tconf_AE.dataset_size) 509 | loss_CE = nn.CrossEntropyLoss() 510 | optimizer_AE = tconf_AE.optim_AE(decoder.parameters(), lr = tconf_AE.learning_rate) 511 | 512 | loss_ep_AE = [] 513 | SER_ep = [] 514 | 515 | for t in range(tconf_AE.max_epochs): 516 | loader = DataLoader(dataset_AE, shuffle=True, pin_memory=False, batch_size=tconf_AE.batch_size)#, num_workers=config.num_workers) 517 | loss_batch, batch_BER, batch_SER = [], [], [] 518 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 519 | 520 | for it, m in pbar: #enumerate(loader): 521 | labels = m.reshape((-1,1)).to(device) 522 | optimizer_AE.zero_grad() 523 | 524 | # Forward Pass: 525 | # Encoding 526 | m = F.one_hot(labels.long(), num_classes=M) 527 | m = m.float().reshape(-1,M) 528 | x = encoder(m) 529 | y = tconf_AE.channel_model(x, tconf_AE.noise_std, device) 530 | 531 | # Decoding 532 | m_d = decoder(y) 533 | 534 | # Backward pass: 535 | # Compute the loss and the gradient 536 | loss = loss_CE(m_d, labels.long().squeeze(1)) 537 | loss = loss.mean() 538 | loss.backward() 539 | 540 | torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1.) 
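                # Decoder-only fine-tuning: codewords come from the encoder and pass through the true
                # channel model, but only decoder.parameters() are updated by the optimizer here.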
541 |                 optimizer_AE.step()
542 | 
543 | 
544 |                 loss_batch.append(loss.item())
545 | 
546 |                 SER_tmp, _, _ = utils.SER(labels.long(), m_d, 0.9)
547 |                 batch_SER.append(SER_tmp.item())
548 | 
549 |                 if it%100==0:
550 |                     pbar.set_description(f"loss {np.mean(loss_batch):.2e} SER {np.mean(batch_SER):.2e}")
551 | 
552 |             # Save epoch loss
553 |             loss_ep_AE.append(np.mean(loss_batch))
554 |             SER_ep.append(np.mean(batch_SER))
555 |         return loss_ep_AE, SER_ep
556 | 
557 |     def train_itr(self):#(encoder, decoder, channel_gen, device, tconf_AE):
558 |         encoder = self.encoder
559 |         decoder = self.decoder
560 |         channel_gen = self.channel_gen
561 |         device = self.device
562 |         tconf_AE = self.tconf_AE
563 | 
564 |         M = tconf_AE.M
565 |         n = tconf_AE.n
566 | 
567 |         one_minus_alphas_bar_sqrt = tconf_AE.one_minus_alphas_bar_sqrt.to(device)
568 |         betas = tconf_AE.betas.to(device)
569 |         alphas = tconf_AE.alphas.to(device)
570 | 
571 |         dataset_AE = np.random.randint(M, size=tconf_AE.dataset_size)
572 |         loss_CE = nn.CrossEntropyLoss()
573 |         optimizer_AE = tconf_AE.optim_AE(list(encoder.parameters()) + list(decoder.parameters()), lr = tconf_AE.learning_rate)
574 | 
575 |         loss_ep_AE = []
576 |         SER_ep = []
577 | 
578 |         for t in range(tconf_AE.max_epochs):
579 |             loader = DataLoader(dataset_AE, shuffle=True, pin_memory=False, batch_size=tconf_AE.batch_size)#, num_workers=config.num_workers)
580 |             loss_batch, batch_BER, batch_SER = [], [], []
581 |             pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5)
582 | 
583 |             for it, m in pbar: #enumerate(loader):
584 |                 labels = m.reshape((-1,1)).to(device)
585 |                 optimizer_AE.zero_grad()
586 | 
587 |                 # Forward Pass:
588 |                 # Encoding
589 |                 m = F.one_hot(labels.long(), num_classes=M)
590 |                 m = m.float().reshape(-1,M)
591 |                 x = encoder(m)
592 |                 c = x
593 |                 # Channel
594 | 
595 |                 x_seq = utils.p_sample_loop_w_Condition(channel_gen, [x.size()[0], n], tconf_AE.num_steps, alphas,
596 |                                                         betas, tconf_AE.alphas_bar_sqrt.to(device), one_minus_alphas_bar_sqrt, c,
597 |                                                         pred_type=tconf_AE.pred_type, is_light=True)
598 |                 y = x_seq # generated channel output; with is_light=True the sampler returns only the final denoised sample
599 | 
600 |                 # Decoding
601 |                 m_d = decoder(y)
602 | 
603 |                 # Backward pass:
604 |                 # Compute the loss and the gradient
605 |                 loss = loss_CE(m_d, labels.long().squeeze(1))
606 |                 loss = loss.mean()
607 |                 loss.backward()
608 | 
609 |                 torch.nn.utils.clip_grad_norm_(list(encoder.parameters())+list(decoder.parameters()), 1.)
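                # Iterative E2E step: encoder and decoder are updated jointly through the synthesized
                # (diffusion-generated) channel rather than the true channel model.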
610 | optimizer_AE.step() 611 | 612 | 613 | loss_batch.append(loss.item()) 614 | 615 | SER_tmp, _, _ = utils.SER(labels.long(), m_d, 0.9) 616 | batch_SER.append(SER_tmp.item()) 617 | 618 | if it%100==0: 619 | pbar.set_description(f"epoch {t+1}: loss {np.mean(loss_batch):.2e} SER {np.mean(batch_SER):.2e}") 620 | 621 | # Save epoch loss 622 | loss_ep_AE.append(np.mean(loss_batch)) 623 | SER_ep.append(np.mean(batch_SER)) 624 | return loss_ep_AE, SER_ep 625 | 626 | def test(self, channel_model): 627 | encoder = self.encoder 628 | decoder = self.decoder 629 | 630 | device = self.device 631 | tconf_AE = self.tconf_AE 632 | 633 | 634 | M = tconf_AE.M 635 | n = tconf_AE.n 636 | noise_std = tconf_AE.noise_std 637 | 638 | dataset_AE = np.random.randint(M, size=tconf_AE.dataset_size) 639 | 640 | SER_ep_AE = [] 641 | 642 | for t in range(1): 643 | loader = DataLoader(dataset_AE, shuffle=True, pin_memory=False, batch_size=tconf_AE.batch_size) 644 | batch_SER = [] 645 | pbar = tqdm(enumerate(loader), total=len(loader), mininterval=0.5) 646 | 647 | for it, m in pbar: 648 | labels = m.reshape((-1,1)).to(device) 649 | # Forward Pass: 650 | # Encoding 651 | m = F.one_hot(labels.long(), num_classes=M) 652 | m = m.float().reshape(-1,M) 653 | x = encoder(m) 654 | 655 | # Channel 656 | y = channel_model(x, noise_std, device) 657 | 658 | # Decoding 659 | m_d = decoder(y) 660 | 661 | SER_tmp, _, _ = utils.SER(labels.long(), m_d, 0.9) 662 | batch_SER.append(SER_tmp.item()) 663 | 664 | if it%100==0: 665 | pbar.set_description(f"Validation SER {np.mean(batch_SER):.2e}") 666 | 667 | # Save epoch loss 668 | SER_ep_AE.append(np.mean(batch_SER)) 669 | return np.mean(SER_ep_AE) 670 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Loss Function for Diffusion Model 2 | # Original Source: https://github.com/acids-ircam/diffusion_models 3 | 4 | import torch 5 | import numpy as np 6 | import sys 7 | import math 8 | assert sys.version_info >= (3, 5) 9 | import matplotlib as mpl 10 | import matplotlib.pyplot as plt 11 | import torch.nn.functional as F 12 | from torch import nn 13 | np.random.seed(42) 14 | 15 | 16 | def make_beta_schedule(schedule='linear', n_timesteps=1000, start=1e-5, end=1e-2): 17 | if schedule == 'linear': 18 | betas = torch.linspace(start, end, n_timesteps) 19 | elif schedule == "quad": 20 | betas = torch.linspace(start ** 0.5, end ** 0.5, n_timesteps) ** 2 21 | elif schedule == "sigmoid": 22 | betas = torch.linspace(-6, 6, n_timesteps) 23 | betas = torch.sigmoid(betas) * (end - start) + start 24 | 25 | elif schedule == "cosine": 26 | i = torch.linspace(0, 1, n_timesteps+1) 27 | f_t = torch.cos( (i + 0.008) / (1+0.008) * math.pi / 2 ) ** 2 28 | alphas_bar = f_t / f_t[0] 29 | alphas_bar_before = alphas_bar[:-1] 30 | alphas_bar = alphas_bar[1:] 31 | 32 | betas = 1 - alphas_bar / alphas_bar_before 33 | for i in range(n_timesteps): 34 | if betas[i] > 0.999: 35 | betas[i] = 0.999 36 | elif schedule == "cosine-zf": 37 | i = torch.linspace(0, 1, n_timesteps+1) 38 | f_t = torch.cos( (i + 0.008) / (1+0.008) * math.pi / 2 ) ** 2 39 | alphas_bar = f_t / f_t[0] 40 | alphas_bar_before = alphas_bar[:-1] 41 | alphas_bar = alphas_bar[1:] 42 | 43 | betas = 1 - alphas_bar / alphas_bar_before 44 | 45 | return betas 46 | 47 | def extract(input, t, shape): 48 | out = torch.gather(input, 0, t.to(input.device)) 49 | reshape = [t.shape[0]] + [1] * (len(shape) - 1) 50 | return 
out.reshape(*reshape).to(input.device) 51 | 52 | def q_posterior_mean_variance(x_0, x_t, t,posterior_mean_coef_1,posterior_mean_coef_2,posterior_log_variance_clipped): 53 | shape = x_0.shape 54 | coef_1 = extract(posterior_mean_coef_1, t, shape) 55 | coef_2 = extract(posterior_mean_coef_2, t, shape) 56 | mean = coef_1 * x_0 + coef_2 * x_t 57 | var = extract(posterior_log_variance_clipped, t, shape) 58 | return mean, var 59 | 60 | def p_mean_variance(model, x, t): 61 | # Go through model 62 | out = model(x, t) 63 | # Extract the mean and variance 64 | mean, log_var = torch.split(out, 2, dim=-1) 65 | var = torch.exp(log_var) 66 | return mean, log_var 67 | 68 | def p_sample_w_Condition(model, z, t, alphas, betas, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, c): 69 | device = model.this_device 70 | z = z.to(device) 71 | c = c.to(device) 72 | shape = z.shape 73 | t = torch.tensor([t]).to(device) 74 | 75 | ## The commented lines below are necessary only when the full version of sigma_t is used. 76 | #a = extract(alphas_bar_sqrt, t, shape) 77 | #am1 = extract(one_minus_alphas_bar_sqrt, t, shape) 78 | #if t > 0: 79 | # a_next = extract(alphas_bar_sqrt, t - 1, shape) 80 | # am1_next = extract(one_minus_alphas_bar_sqrt, t - 1, shape) 81 | #else: 82 | # a_next = torch.ones_like(z) 83 | # am1_next = torch.zeros_like(z) 84 | 85 | # Factor to the model output 86 | eps_factor = ((1 - extract(alphas, t, shape)) / extract(one_minus_alphas_bar_sqrt, t, shape)) 87 | eps_factor = eps_factor.to(device) 88 | # Model output 89 | eps_theta = model(z, t, c) 90 | # Final values 91 | mean = (1 / extract(alphas, t, shape).sqrt().to(device)) * (z - (eps_factor * eps_theta)) 92 | # Generate epsilon 93 | e = torch.randn_like(z).to(device) 94 | # Fixed sigma 95 | sigma_t = extract(betas, t, shape).sqrt().to(device) # Simplified version 96 | #sigma_t = ( 1 - a ** 2 / a_next ** 2) ** 0.5 * am1_next / am1 # full version 97 | 98 | sample = mean.to(device) + sigma_t.to(device) * e 99 | sample = sample.to(device) 100 | return (sample) 101 | 102 | 103 | def p_sample_w_Condition_v(model, z, t, alphas, betas, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, c): # stochastic sampling with condition 104 | 105 | device = model.this_device 106 | z = z.to(device) 107 | c = c.to(device) 108 | shape = z.shape 109 | t = torch.tensor([t]).to(device) 110 | 111 | a = extract(alphas_bar_sqrt, t, shape) 112 | am1 = extract(one_minus_alphas_bar_sqrt, t, shape) 113 | 114 | if t > 0: 115 | a_next = extract(alphas_bar_sqrt, t - 1, shape) 116 | am1_next = extract(one_minus_alphas_bar_sqrt, t - 1, shape) 117 | else: 118 | a_next = torch.ones_like(z) 119 | am1_next = torch.zeros_like(z) 120 | # Model output 121 | v_theta = model(z, t, c) 122 | e = torch.randn_like(z) 123 | 124 | # Generate Sample 125 | e_hat = am1 * z + a * v_theta 126 | x_hat = a * z - am1 * v_theta 127 | sample = a_next * x_hat + am1_next ** 2 * a / am1 / a_next * e_hat + ( 128 | 1 - a ** 2 / a_next ** 2) ** 0.5 * am1_next / am1 * e 129 | 130 | return sample 131 | 132 | 133 | def p_sample_loop_w_Condition(model, shape, n_steps, alphas, betas, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, c, 134 | pred_type='epsilon', is_light = False): 135 | device = model.this_device 136 | cur_x = torch.randn(shape).to(device) 137 | c = c.to(device) 138 | alphas = alphas.to(device) 139 | betas = betas.to(device) 140 | alphas_bar_sqrt = alphas_bar_sqrt.to(device) 141 | one_minus_alphas_bar_sqrt = one_minus_alphas_bar_sqrt.to(device) 142 | 143 | if pred_type == 'epsilon': 144 | p_sample_func = 
p_sample_w_Condition
145 |     elif pred_type == 'v':
146 |         p_sample_func = p_sample_w_Condition_v
147 |     if not is_light:
148 |         x_seq = [cur_x]
149 |     for i in reversed(range(n_steps)):
150 |         cur_x = p_sample_func(model, cur_x, i, alphas, betas, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, c)
151 |         if not is_light:
152 |             x_seq.append(cur_x)
153 |     if not is_light:
154 |         return x_seq
155 |     else: return cur_x
156 | 
157 | def p_sample_w_Condition_DDIM(model, xt, i, j, alphas_prod, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, shape, device, c):
158 |     at = extract(alphas_prod, torch.tensor([i]).to(device), shape).to(device)
159 |     if j > -1:
160 |         at_next = extract(alphas_prod, torch.tensor([j]), shape).to(device)
161 |     else:
162 |         at_next = torch.ones(shape).to(device) # a_0 = 1
163 | 
164 |     et = model(xt, torch.tensor([i]).to(device), c)
165 |     xt_next = at_next.sqrt() * (xt - et * (1 - at).sqrt()) / at.sqrt() + (1 - at_next).sqrt() * et
166 |     return xt_next
167 | 
168 | def p_sample_w_Condition_DDIM_v(model, xt, i, j, alphas_prod, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, shape, device, c):
169 |     a = extract(alphas_bar_sqrt, torch.tensor([i]).to(device), shape).to(device)
170 |     am1 = extract(one_minus_alphas_bar_sqrt, torch.tensor([i]).to(device), shape).to(device)
171 |     if j > -1:
172 |         a_next = extract(alphas_bar_sqrt, torch.tensor([j]), shape).to(device)
173 |         am1_next = extract(one_minus_alphas_bar_sqrt, torch.tensor([j]), shape).to(device)
174 |     else:
175 |         a_next = torch.ones(shape).to(device) # a_0 = 1
176 |         am1_next = torch.zeros(shape).to(device) # 1-sqrt(a_0) = 0
177 |     z = xt.to(device)
178 |     v_theta = model(z, torch.tensor([i]).to(device), c)
179 | 
180 |     # Generate Sample
181 |     x_hat = a * z - am1 * v_theta
182 |     z_next = a_next * x_hat + am1_next * (z - a * x_hat) / am1
183 |     return z_next
184 | def p_sample_loop_w_Condition_DDIM(model, shape, traj, alphas_prod, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, c,
185 |                                    pred_type='epsilon'):
186 |     device = model.this_device
187 |     cur_x = torch.randn(shape).to(device) # randomly sampled x_T
188 |     c = c.to(device)
189 |     alphas_prod = alphas_prod.to(device)
190 |     alphas_bar_sqrt = alphas_bar_sqrt.to(device)
191 |     one_minus_alphas_bar_sqrt = one_minus_alphas_bar_sqrt.to(device)
192 | 
193 | 
194 |     traj_next = [-1] + list(traj[:-1]) # t-1
195 | 
196 |     x_seq = [cur_x]
197 | 
198 |     if pred_type == 'epsilon':
199 |         p_sample_func_DDIM = p_sample_w_Condition_DDIM
200 |     elif pred_type == 'v':
201 |         p_sample_func_DDIM = p_sample_w_Condition_DDIM_v
202 | 
203 |     for i, j in zip(reversed(traj), reversed(traj_next)): # i = t, j = t-1
204 |         xt = x_seq[-1].to(device)
205 |         xt_next = p_sample_func_DDIM(model, xt, i, j, alphas_prod, alphas_bar_sqrt, one_minus_alphas_bar_sqrt,
206 |                                      shape, device, c)
207 |         x_seq.append(xt_next)
208 | 
209 |     return x_seq
210 | 
211 | def q_sample(x_0, t, alphas_bar_sqrt, one_minus_alphas_bar_sqrt ,noise=None):
212 |     shape = x_0.shape
213 |     if noise is None:
214 |         noise = torch.randn_like(x_0)
215 |     alphas_t = extract(alphas_bar_sqrt, t, shape)
216 |     alphas_1_m_t = extract(one_minus_alphas_bar_sqrt, t, shape)
217 |     return (alphas_t * x_0 + alphas_1_m_t * noise)
218 | 
219 | def noise_estimation_loss(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps, n):
220 |     device = model.this_device
221 |     shape = x_0.shape
222 |     batch_size = x_0.shape[0]
223 |     x = x_0[:, 0:n].to(device)
224 |     c = x_0[:, n:].to(device)
225 | 
226 |     # Select a random step for each example
227 |     t = torch.randint(0, n_steps, size=(batch_size,)).long().to(device)
228 |     # 
228 |     # x0 multiplier
229 |     a = extract(alphas_bar_sqrt, t, shape).to(device)
230 |     # eps multiplier
231 |     am1 = extract(one_minus_alphas_bar_sqrt, t, shape).to(device)
232 |     e = torch.randn_like(x).to(device)
233 |     # model input
234 |     y = x * a + e * am1
235 |     output = model(y, t, c)
236 |     return (e - output).square().mean()
237 | 
238 | 
239 | def v_estimation_loss(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps, n,
240 |                       losstype='no_weighting'):
241 | 
242 |     assert isinstance(losstype, str)
243 | 
244 | 
245 |     device = model.this_device
246 |     shape = x_0.shape
247 |     batch_size = x_0.shape[0]
248 |     x = x_0[:, 0:n].to(device)
249 |     c = x_0[:, n:].to(device)
250 | 
251 |     # Select a random step for each example
252 |     t = torch.randint(0, n_steps, size=(batch_size,)).long().to(device)
253 |     # x0 multiplier
254 |     a = extract(alphas_bar_sqrt, t, shape).to(device)
255 |     # eps multiplier
256 |     am1 = extract(one_minus_alphas_bar_sqrt, t, shape).to(device)
257 |     e = torch.randn_like(x).to(device)
258 | 
259 |     # model input - the noisy sample
260 |     y = x * a + e * am1
261 |     output = model(y, t, c)
262 | 
263 |     # loss weighting
264 |     if losstype == 'SNR':
265 |         weight_t = a ** 2 / am1 ** 2
266 |     elif losstype == 'SNRp1':
267 |         weight_t = a ** 2 / am1 ** 2 + torch.ones_like(x)
268 |     elif losstype == 'SNRtr':
269 |         weight_t = a ** 2 / am1 ** 2
270 |         weight_t[weight_t < 1] = 1
271 |     elif losstype == 'no_weighting':
272 |         weight_t = 1
273 |     else:
274 |         raise ValueError('Unrecognized weight type. Possible options are SNR, SNRp1, SNRtr, and no_weighting.')
275 | 
276 |     v = e * a - x * am1  # v-target: v = sqrt(a_bar_t) * eps - sqrt(1 - a_bar_t) * x_0
277 | 
278 |     loss = (v - output).square()
279 | 
280 |     return (weight_t * loss).mean()
281 | 
282 | def EbNo_to_noise(ebnodb, rate):
283 |     '''Transform Eb/N0 [dB] into the per-component noise standard deviation.'''
284 |     ebno = 10**(ebnodb/10)
285 |     noise_std = 1/np.sqrt(2*rate*ebno)
286 |     return noise_std
287 | 
288 | def SNR_to_noise(snrdb):
289 |     '''Transform SNR [dB] into the per-component noise standard deviation.'''
290 |     snr = 10**(snrdb/10)
291 |     noise_std = 1/np.sqrt(2*snr)
292 |     return noise_std
293 | 
294 | def SER(input_msg, msg, erasure_bound=0.9):
295 | 
296 |     '''Calculate the symbol error rate and flag low-confidence decisions as erasures.'''
297 |     batch, _ = msg.size()
298 | 
299 |     # Flag erasures: outputs whose largest softmax probability does not exceed erasure_bound.
300 |     smax = torch.nn.Softmax(dim=1)
301 |     out, _ = torch.topk(smax(msg), 1, dim=1)
302 |     indices_erasures = (out <= erasure_bound).flatten()
303 | 
313 |     # Erasure cutting is disabled; every symbol is compared against the hard decision.
314 |     input_msg_wo_er = torch.flatten(input_msg)
315 |     msg_wo_er = msg
316 | 
317 |     pred_error = torch.ne(input_msg_wo_er, msg_wo_er.argmax(dim=1))
318 |     block_er = torch.sum(pred_error)
319 |     ser = block_er/batch
320 |     return ser, pred_error, indices_erasures
321 | 
322 | def qam16_mapper_n(m):
323 |     # m is a vector of message indices; returns the corresponding normalized 16-QAM points.
324 |     x = np.linspace(-3,3,4)
325 |     y = np.meshgrid(x,x)
326 |     z = np.array(y).reshape(2,16)
327 |     z = z / np.std(z)
328 |     return np.array([[z[0][i],z[1][i]] for i in m])
329 | 
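# Worked example (illustrative): with M = 16 messages and block length n = 7, as
# in the AWGN and Rayleigh setups of this repository, and assuming the usual
# rate definition log2(M)/n = 4/7 bits per channel use:
#
#     rate = np.log2(16) / 7                  # ~0.571
#     noise_std = EbNo_to_noise(4.0, rate)    # ~0.59 at Eb/N0 = 4 dB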
330 | class LoadModels:
331 |     def __init__(self, path, Load_model_gen, model_gen, tag_load_gen, Load_model_AE, model_enc, model_dec, tag_load_AE):
332 |         self.Load_model_gen = Load_model_gen
333 |         self.model_gen = model_gen
334 |         self.dir_gen = path + '/channel_gen' + tag_load_gen  # e.g. state_dict/channel_gen_AWGN_M16N7_T100.pt
335 |         self.Load_model_AE = Load_model_AE
336 |         self.model_enc = model_enc
337 |         self.model_dec = model_dec
338 |         self.dir_enc = path + '/encoder' + tag_load_AE
339 |         self.dir_dec = path + '/decoder' + tag_load_AE
340 | 
341 |     def load(self):
342 |         if self.Load_model_gen:
343 |             self.model_gen.load_state_dict(torch.load(self.dir_gen))
344 |             self.model_gen.eval()
345 |         if self.Load_model_AE:
346 |             self.model_enc.load_state_dict(torch.load(self.dir_enc))
347 |             self.model_enc.eval()
348 |             self.model_dec.load_state_dict(torch.load(self.dir_dec))
349 |             self.model_dec.eval()
350 |         return self.model_gen, self.model_enc, self.model_dec
351 | 
352 | class SaveModels:
353 |     def __init__(self, path, Save_model_gen, model_gen, tag_save_gen, Save_model_AE, model_enc, model_dec, tag_save_AE):
354 |         self.Save_model_gen = Save_model_gen
355 |         self.model_gen = model_gen
356 |         self.dir_gen = path + '/channel_gen' + tag_save_gen
357 |         self.Save_model_AE = Save_model_AE
358 |         self.model_enc = model_enc
359 |         self.model_dec = model_dec
360 |         self.dir_enc = path + '/encoder' + tag_save_AE
361 |         self.dir_dec = path + '/decoder' + tag_save_AE
362 | 
363 |     def save_gen(self):
364 |         if self.Save_model_gen:
365 |             torch.save(self.model_gen.state_dict(), self.dir_gen)
366 |     def save_AE(self):
367 |         if self.Save_model_AE:
368 |             torch.save(self.model_enc.state_dict(), self.dir_enc)
369 |             torch.save(self.model_dec.state_dict(), self.dir_dec)
370 | 
371 | 
372 | 
373 | 
--------------------------------------------------------------------------------
/src/utils_plot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | import numpy as np
5 | import sys
6 | import math
7 | assert sys.version_info >= (3, 5)
8 | import matplotlib as mpl
9 | import matplotlib.pyplot as plt
10 | from utils import p_sample_loop_w_Condition, p_sample_loop_w_Condition_DDIM
11 | from statsmodels.distributions.empirical_distribution import ECDF
12 | 
13 | mpl.rcParams.update(mpl.rcParamsDefault)
14 | mpl.rcParams['pdf.fonttype'] = 42
15 | mpl.rcParams['ps.fonttype'] = 42
16 | mpl.rcParams['font.family'] = "serif"
17 | 
18 | mpl.rc('axes', labelsize=14)
19 | mpl.rc('xtick', labelsize=12)
20 | mpl.rc('ytick', labelsize=12)
21 | np.random.seed(42)
22 | from channel_models import *
23 | 
24 | 
25 | def Constellation_InOut_m(enc, ch_model, ch_model_type, device, alphas, betas, alphas_bar_sqrt,
26 |                           one_minus_alphas_bar_sqrt,
27 |                           num_steps, M, num_samples, noise_std, max_amp, m=0, PreT = False,
28 |                           denoising_alg = 'DDPM', traj = range(1), IS_RES = False, pred_type='epsilon'):
29 |     labels = m * torch.ones(1)
30 |     msg_1h = F.one_hot(labels.long(), num_classes=M)
31 |     msg_1h = msg_1h.float().reshape(-1,M)
32 |     x = enc(msg_1h).repeat(num_samples,1)
33 | 
34 |     n = x.shape[1]  # channel dimension
35 | 
36 |     if ch_model_type == 'AWGN':
37 |         channel_model = ch_AWGN
38 |     elif ch_model_type == 'Rayleigh':
39 |         channel_model = ch_Rayleigh_AWGN_n
40 |     elif ch_model_type == 'SSPA':
41 |         channel_model = ch_SSPA
42 | 
43 |     else:
44 |         raise Exception('Unrecognized channel model. Available models: AWGN, Rayleigh, SSPA.')
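    # Draw the ground-truth output y from the true channel and, below, the
    # synthetic output y_gen from the learned diffusion model; both are
    # conditioned on the same encoder output x (with the message label
    # prepended to the condition unless the diffusion model was pre-trained,
    # i.e. PreT=True).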
45 | 
46 |     y = channel_model(x,noise_std,device)
47 | 
48 |     if PreT:
49 |         c = x
50 |     else:
51 |         c = torch.cat((labels.repeat(num_samples,1).to(device), x), dim=1)
52 |     if denoising_alg == 'DDPM':
53 |         x_seq = p_sample_loop_w_Condition(ch_model, x.size(), num_steps, alphas, betas, alphas_bar_sqrt,
54 |                                           one_minus_alphas_bar_sqrt, c, pred_type = pred_type)
55 |     elif denoising_alg == 'DDIM':
56 |         alphas_prod = torch.cumprod(alphas, 0)
57 |         x_seq = p_sample_loop_w_Condition_DDIM(ch_model, x.size(), traj, alphas_prod, alphas_bar_sqrt,
58 |                                                one_minus_alphas_bar_sqrt, c, pred_type = pred_type)
59 |     else: raise ValueError('Unrecognized denoising algorithm. Available options: DDPM, DDIM.')
60 |     y_gen = x_seq[-1]
61 |     if IS_RES:
62 |         y_gen = y_gen + c
63 | 
65 | 
66 |     x = x.detach().cpu().numpy()
67 |     y = y.detach().cpu().numpy()
68 |     y_gen = y_gen.detach().cpu().numpy()
69 | 
70 |     print("Scatter plot of the generated data compared with the ground truth; message", m+1, flush=True)
71 | 
72 |     if n//2 == 1:
73 |         fig = plt.figure(figsize=(4,4))
74 |         plt.scatter(y_gen[:,0], y_gen[:,1], edgecolor=(1,0,0,0.8), label='generated', facecolor=(1,1,1,0.5))
75 |         plt.scatter(y[:,0], y[:,1], edgecolor=(0,0.3,0.3,0.7), label='ground truth', facecolor=(1,1,1,0.5))
76 |         plt.scatter(x[:,0], x[:,1], edgecolor=(0,0,0,0.8), label='channel input', facecolor=(1,1,1,0.5))
77 |         plt.xlim([-max_amp, max_amp])
78 |         plt.ylim([-max_amp, max_amp])
79 |         plt.xlabel('$Re(y)$')
80 |         plt.ylabel('$Im(y)$')
81 |         title = 'message '+ str(m+1)
82 |         plt.title(title)
83 |         plt.legend()
84 |     else:
85 |         fig, ax = plt.subplots(1,n//2,figsize=(10,4))
86 |         for j in range(n // 2):
87 |             ax[j].scatter(y_gen[:,2*j], y_gen[:,2*j+1], edgecolor=(1,0,0,0.8), label='generated', facecolor=(1,1,1,0.5))
88 |             ax[j].scatter(y[:,2*j], y[:,2*j+1], edgecolor=(0,0.3,0.3,0.7), label='ground truth', facecolor=(1,1,1,0.5))
89 |             ax[j].scatter(x[:,2*j], x[:,2*j+1], edgecolor=(0,0,0,0.8), label='channel input', facecolor=(1,1,1,0.5))
90 |             ax[j].set_xlim([-max_amp, max_amp])
91 |             ax[j].set_ylim([-max_amp, max_amp])
92 |             ax[j].set_xlabel('$Re(y_'+ str(j+1)+ ')$')
93 |             ax[j].set_ylabel('$Im(y_'+ str(j+1)+')$')
94 |             title = '$m = '+ str(m+1)+'$'
95 |             ax[j].set_title(title)
96 |             ax[j].legend()
97 |     plt.show()
98 |     return
99 | 
100 | 
101 | def Constellation_InOut(enc, ch_model, ch_model_type, device, alphas, betas, alphas_bar_sqrt,
102 |                         one_minus_alphas_bar_sqrt,
103 |                         num_steps, M, num_samples, noise_std, max_amp, PreT = False,
104 |                         denoising_alg = 'DDPM', traj = range(1), cmap_str = 'nipy_spectral',
105 |                         IS_RES = False, pred_type = 'epsilon'):
106 | 
107 |     labels = torch.empty((num_samples*M,1), dtype = int)
108 |     for i in range(M):
109 |         labels[num_samples*i:num_samples*(i+1),:] = torch.full([num_samples,1],i,dtype=int)
110 | 
111 |     m = F.one_hot(labels, num_classes=M)
112 |     m = m.float().reshape(-1,M)
113 |     x = enc(m)
114 |     n = x.shape[1]  # channel dimension
115 |     colors = 100//M*np.arange(M)/100.
116 |     cmap = plt.cm.get_cmap(cmap_str)
117 |     colors_rgba = cmap(colors)
118 | 
119 | 
120 |     if ch_model_type == 'AWGN':
121 |         channel_model = ch_AWGN
122 |     elif ch_model_type == 'Rayleigh':
123 |         channel_model = ch_Rayleigh_AWGN_n
124 |     elif ch_model_type == 'SSPA':
125 |         channel_model = ch_SSPA
126 |     else:
127 |         raise Exception('Unrecognized channel model. Available models: AWGN, Rayleigh, SSPA.')
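    # Note on `traj` (used by the DDIM branch below): it is the increasing
    # sequence of timesteps retained for skipped sampling; a typical choice
    # matching the notebooks' `skip` option (an assumption here) would be
    # traj = range(0, num_steps, skip). The default traj=range(1) performs a
    # single denoising step at t = 0.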
128 | 
129 |     y = channel_model(x,noise_std,device)
130 | 
131 |     if PreT:
132 |         c = x
133 |     else:
134 |         c = torch.cat((labels.to(device), x), dim=1)
135 | 
136 |     if denoising_alg == 'DDPM':
137 |         x_seq = p_sample_loop_w_Condition(ch_model, x.size(), num_steps, alphas, betas, alphas_bar_sqrt,
138 |                                           one_minus_alphas_bar_sqrt, c, pred_type = pred_type)
139 |     elif denoising_alg == 'DDIM':
140 |         alphas_prod = torch.cumprod(alphas, 0)
141 |         x_seq = p_sample_loop_w_Condition_DDIM(ch_model, x.size(), traj, alphas_prod, alphas_bar_sqrt,
142 |                                                one_minus_alphas_bar_sqrt, c, pred_type = pred_type)
143 |     else:
144 |         raise ValueError('Unrecognized denoising algorithm. Available options: DDPM, DDIM.')
145 | 
146 |     if IS_RES:
147 |         y_gen = x_seq[-1] + c
148 |     else: y_gen = x_seq[-1]
149 | 
152 |     x = x.detach().cpu().numpy()
153 |     y = y.detach().cpu().numpy()
154 |     y_gen = y_gen.detach().cpu().numpy()
155 |     if n // 2 == 1:
156 |         fig = plt.figure(figsize=(7, 7))
157 |     else:
158 |         fig, ax = plt.subplots(1, n // 2, figsize=(3 * n // 2, 3))
159 | 
160 |     if n//2 == 1:
161 |         plt.scatter(y_gen[:,0], y_gen[:,1], edgecolor = np.repeat(colors_rgba, num_samples, axis=0),
162 |                     marker="^", s = 25, facecolor='none', label = 'generated')
163 |         plt.scatter(y[:,0], y[:,1], color = np.repeat(colors_rgba, num_samples, axis=0), marker="x",
164 |                     s = 25, label='ground truth')
165 |         plt.scatter(x[:,0], x[:,1], color = np.repeat(colors_rgba, num_samples, axis=0),
166 |                     edgecolor = 'black', label= 'channel input', s = 25)
167 |         plt.legend()
168 | 
169 |     else:
170 |         for j in range(n // 2):
171 |             ax[j].scatter(y_gen[:,2*j], y_gen[:,2*j+1], edgecolor = np.repeat(colors_rgba, num_samples, axis=0),
172 |                           marker="^", s = 25, facecolor='none', label = 'generated')
173 |             ax[j].scatter(y[:,2*j], y[:,2*j+1], color = np.repeat(colors_rgba, num_samples, axis=0),
174 |                           marker="x", s = 25, label='ground truth')
175 |             ax[j].scatter(x[:,2*j], x[:,2*j+1], color = np.repeat(colors_rgba, num_samples, axis=0),
176 |                           edgecolor = 'black', label= 'channel input', s = 25)
177 |             ax[j].legend()
178 | 
179 |     if n // 2 == 1:
180 |         plt.xlim([-max_amp, max_amp])
181 |         plt.ylim([-max_amp, max_amp])
182 |         plt.xlabel('$Re(y)$')
183 |         plt.ylabel('$Im(y)$')
184 |         title = 'Channel Output Constellation'
185 |         plt.title(title)
186 |     else:
187 |         for j in range(n // 2):
188 |             ax[j].set_xlim([-max_amp, max_amp])
189 |             ax[j].set_ylim([-max_amp, max_amp])
190 |             ax[j].set_xlabel('$Re(y_' + str(j + 1) + ')$')
191 |             ax[j].set_ylabel('$Im(y_' + str(j + 1) + ')$')
192 |             title = 'Channel Output Constellation'
193 |             ax[j].set_title(title)
194 |     plt.show()
195 | 
196 |     fig.savefig('figures/constellation_inout_'+ch_model_type+'.pdf', format='pdf')
197 |     return
198 | 
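# Example call (illustrative; the argument values are assumptions in the spirit
# of the notebooks, not taken from them):
#
#     Constellation_InOut(encoder, channel_gen, 'AWGN', device, alphas, betas,
#                         alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps,
#                         M, num_samples=100, noise_std=noise_std, max_amp=3,
#                         denoising_alg='DDIM', traj=range(0, num_steps, skip))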
199 | def Constellation_Enc(enc, M):
200 |     # Show the constellation of the learned encoder, one scatter plot per complex channel use.
201 | 
204 |     labels = torch.arange(0,M, dtype=int)
205 |     m = F.one_hot(labels, num_classes=M)
206 |     m = m.float().reshape(-1,M)
207 |     x = enc(m).detach().cpu().numpy()
208 | 
209 |     n = x.shape[1]
210 | 
211 |     print("Constellation of the learned encoder", flush=True)
212 |     if n//2 == 1:
213 |         plt.figure(figsize=(4,4))
214 |         for j in range(M):
215 |             plt.scatter(x[j,0],x[j,1])
216 |         plt.xlabel('$Re(x)$')
217 |         plt.ylabel('$Im(x)$')
218 |     else:
219 |         fig, axes = plt.subplots(1,n//2,figsize=(15,4))
220 |         for i in range(n //2):
221 |             for j in range(M):
222 |                 axes[i].scatter(x[j,2*i],x[j,2*i+1])
223 |             axes[i].set_xlabel('$Re(x_' + str(i+1) + ')$')
224 |             axes[i].set_ylabel('$Im(x_'+ str(i+1) +')$')
225 |             axes[i].set_title("channel #"+ str(i+1))
226 | 
227 |     plt.show()
228 | 
229 |     return
230 | 
231 | def ECDF_histogram_m(enc, ch_model, ch_model_type, device, alphas, betas, alphas_bar_sqrt,
232 |                      one_minus_alphas_bar_sqrt,
233 |                      num_steps, M, num_samples, batch_size, noise_std, max_amp, m=0,
234 |                      PreT = False, denoising_alg = 'DDPM', traj = range(1), pred_type='epsilon'):
235 |     figsize = (6.5, 2.5)
236 |     # Test of the channel generation: compare the distribution of the output amplitude |y|
237 |     # under the true channel and under the learned diffusion model, for one fixed message.
238 | 
240 |     # Arrays for plotting
241 |     z = np.linspace(0,max_amp,100)
242 |     bins = 30
243 |     alphas = alphas.to(device)
244 |     betas = betas.to(device)
245 |     one_minus_alphas_bar_sqrt = one_minus_alphas_bar_sqrt.to(device)
246 |     alphas_bar_sqrt = alphas_bar_sqrt.to(device)  # moved to device as well, for consistency
247 | 
248 |     y_amp = []
249 | 
250 |     labels = torch.full([num_samples,1],m,dtype=int)
251 | 
252 |     msg_1h = F.one_hot(labels, num_classes=M)
253 |     msg_1h = msg_1h.float().reshape(-1,M)
254 |     x = enc(msg_1h)
255 | 
256 |     # True channel output y
257 | 
258 |     if ch_model_type == 'AWGN':
259 |         channel_model = ch_AWGN
260 |     elif ch_model_type == 'Rayleigh':
261 |         channel_model = ch_Rayleigh_AWGN_n
262 |     elif ch_model_type == 'SSPA':
263 |         channel_model = ch_SSPA
264 |     else:
265 |         raise Exception('Unrecognized channel model. Available models: AWGN, Rayleigh, SSPA.')
266 | 
267 |     y = channel_model(x,noise_std,device)
268 | 
269 |     y_amp.append(torch.sqrt(torch.sum(torch.pow(y,2), dim=1)))
270 |     y_amp = [i.detach().cpu().numpy() for i in y_amp]
271 |     y_amp = np.asarray(y_amp).reshape((-1,))
272 | 
273 |     y_gen_amp = []
274 |     for i_btch in range(num_samples // batch_size):
275 | 
276 |         # Generated channel output y
277 |         labels_batch = labels[batch_size*i_btch:batch_size*(i_btch+1)].to(device)
278 |         x_batch = x[batch_size*i_btch:batch_size*(i_btch+1)].to(device)
279 | 
280 |         if PreT:
281 |             c = x_batch
282 |         else:
283 |             c = torch.cat((labels_batch.reshape([-1,1]).to(device),x_batch), dim=1).to(device)
284 | 
285 |         if denoising_alg == 'DDPM':
286 |             x_seq = p_sample_loop_w_Condition(ch_model, x_batch.size(), num_steps, alphas, betas, alphas_bar_sqrt,
287 |                                               one_minus_alphas_bar_sqrt, c, pred_type=pred_type)
288 |         elif denoising_alg == 'DDIM':
289 |             alphas_prod = torch.cumprod(alphas, 0)
290 |             x_seq = p_sample_loop_w_Condition_DDIM(ch_model, x_batch.size(), traj, alphas_prod, alphas_bar_sqrt,
291 |                                                    one_minus_alphas_bar_sqrt, c, pred_type=pred_type)
292 |         else:
293 |             raise ValueError('Unrecognized denoising algorithm. Available options: DDPM, DDIM.')
294 | 
295 |         y_gen = x_seq[-1].detach().cpu()
296 | 
298 |         # Convert to the amplitude (polar radius) and append to the list
299 |         y_gen_amp.append(torch.sqrt(torch.sum(torch.pow(y_gen,2), dim=1)).numpy())
300 | 
301 |     # List of arrays -> one flat numpy array
302 | 
303 |     y_gen_amp = np.asarray(y_gen_amp).reshape((-1,))
304 | 
305 |     print("Evaluation of the generated channel against the true channel; m =", m+1, flush=True)
306 |     # Compute the empirical CDFs of the ground-truth and the generated amplitudes
307 |     y_ecdf = ECDF(y_amp)
308 |     y_gen_ecdf = ECDF(y_gen_amp)
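    # Left panel: histogram of the amplitude |y|; right panel: its empirical CDF;
    # each compares the true channel with the diffusion-generated one. The figure
    # is saved as figures/ECDF_histogram_<channel>.pdf.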
309 |     fig, axes = plt.subplots(1,2,figsize=figsize, constrained_layout=True)
310 | 
311 | 
312 |     axes[1].plot(z, y_gen_ecdf(z), label = 'Generated', color = (1,0,0,0.8))
313 |     axes[1].plot(z, y_ecdf(z), '--', label = 'True', color = (0.3,0.3,0.3,0.7))
314 |     axes[1].legend()
315 |     axes[1].set_xlabel('$|y|$')
316 |     axes[1].set_ylabel('Empirical CDF')
317 | 
318 |     axes[0].hist(y_gen_amp, bins=bins, edgecolor=(1,0,0,0.8), label='Generated', facecolor=(1,1,1,0.5))
319 |     axes[0].hist(y_amp, bins=bins, edgecolor=(0.3,0.3,0.3,0.7), linestyle = '--', label='True', facecolor=(1,1,1,0.5))
320 |     axes[0].set_ylabel('Frequency')
321 | 
322 |     fig.tight_layout()
323 |     plt.legend()
324 |     plt.savefig('figures/ECDF_histogram_'+ch_model_type+'.pdf', bbox_inches = 'tight')
325 |     plt.show()
326 |     return
327 | 
328 | 
329 | def show_denoising_m(channel_gen, encoder, device, alphas, betas, alphas_bar_sqrt,
330 |                      one_minus_alphas_bar_sqrt, M=16, m=0,
331 |                      num_samples_dn=1000, num_steps=50, max_amp=3, PreT = False,
332 |                      denoising_alg = 'DDPM', traj = range(1), pred_type = 'epsilon'):
333 |     print('The denoising process is plotted for message '+str(m+1)+' in the first two dimensions.')
334 |     if m >= M:
335 |         raise Exception("The message index m is zero-based and should be smaller than " +str(M)+".")
336 | 
337 |     labels = m * torch.ones(1)
338 |     msg_1h = F.one_hot(labels.long(), num_classes=M)
339 |     msg_1h = msg_1h.float().reshape(-1,M)
340 |     x = encoder(msg_1h).repeat(num_samples_dn,1)
341 | 
342 | 
343 |     if PreT:
344 |         c = x
345 |     else:
346 |         c = torch.cat((labels.repeat(num_samples_dn,1).to(device), x), dim=1)
347 | 
348 |     if denoising_alg == 'DDPM':
349 |         x_seq = p_sample_loop_w_Condition(channel_gen, x.size(), num_steps, alphas, betas, alphas_bar_sqrt,
350 |                                           one_minus_alphas_bar_sqrt, c, pred_type = pred_type)
351 |     elif denoising_alg == 'DDIM':
352 |         alphas_prod = torch.cumprod(alphas, 0)
353 |         x_seq = p_sample_loop_w_Condition_DDIM(channel_gen, x.size(), traj, alphas_prod, alphas_bar_sqrt,
354 |                                                one_minus_alphas_bar_sqrt, c, pred_type = pred_type)
355 |     else: raise ValueError('Unrecognized denoising algorithm. Available options: DDPM, DDIM.')
356 |     len_seq = len(x_seq) - 1
357 |     max_norm = 0.
358 |     fig, axs = plt.subplots(1, 10, figsize=(28, 3))
359 |     fig1, axs1 = plt.subplots(1, 10, figsize=(28, 3))
360 |     for i in range(1, 11):
361 |         cur_x = x_seq[i * len_seq // 10].detach().cpu()
362 |         axs[i-1].scatter(cur_x[:, 0], cur_x[:, 1], color='white', edgecolor='gray', s=5)
363 |         #axs[i-1].set_axis_off();
364 |         axs[i-1].set_title('$q(\mathbf{x}_{'+str(i*len_seq//10)+'})$')
365 |         amp = torch.norm(cur_x, dim=1).numpy()
366 |         max_norm = max(max_norm, np.max(amp))
367 |         axs[i-1].set_xlim([-max_amp, max_amp])
368 |         axs[i-1].set_ylim([-max_amp, max_amp])
369 |         axs1[i-1].hist(amp, bins=20, edgecolor='gray')
370 |         #axs1[i-1].set_axis_off();
371 | 
372 |         axs1[i-1].set_title('$q(\mathbf{x}_{'+str(i*len_seq//10)+'})$')
373 | 
374 |     for i in range(1,11):
375 |         axs1[i-1].set_xlim([0, max_norm])
376 |     plt.show()
377 | 
378 | 
379 | def show_diffusion_m(channel_model, q_x, encoder, device, noise_std, M=16, m=0, num_samples_df=1000,
380 |                      num_steps=50, max_amp=3):
381 |     if m >= M:
382 |         raise Exception("The message index m is zero-based and should be smaller than " +str(M)+".")
383 | 
384 |     print("Diffusion process for message " + str(m+1) + " in the first two dimensions.")
385 |     msg_1h = torch.zeros([num_samples_df, M]).to(device)
386 |     msg_1h[:,m] = 1  # one-hot encoding for message index m
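    # `q_x` is the forward-diffusion sampler supplied by the caller (e.g., a
    # closure around q_sample from utils.py) that noises a clean batch to the
    # requested timestep.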
387 | 
388 |     dataset_test_diff = channel_model(encoder(msg_1h), noise_std, device)
389 | 
390 |     fig, axs = plt.subplots(1, 10, figsize=(28, 3))
391 |     for i in range(10):
392 |         q_i = q_x(dataset_test_diff, torch.tensor([i * num_steps//10]).to(device))
393 |         q_i = q_i.detach().cpu().numpy()
394 |         axs[i].scatter(q_i[:, 0], q_i[:, 1], color='white', edgecolor='gray', s=5)
395 |         #axs[i].set_axis_off(); axs[i].set_title('$q(\mathbf{x}_{'+str(i*10)+'})$')
396 |         axs[i].set_xlim([-max_amp, max_amp])
397 |         axs[i].set_ylim([-max_amp, max_amp])
398 |     plt.show()
399 | 
400 |     print("Adjust beta according to the diffusion results: if the samples need to be noisier, increase beta; if they become noisy too soon, decrease beta.")
401 | 
402 | 
403 | 
404 | 
--------------------------------------------------------------------------------
/state_dict/channel_gen_AWGN_M16N7_T100.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/state_dict/channel_gen_AWGN_M16N7_T100.pt
--------------------------------------------------------------------------------
/state_dict/channel_gen_Rayleigh_M16N7_T100.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/state_dict/channel_gen_Rayleigh_M16N7_T100.pt
--------------------------------------------------------------------------------
/state_dict/channel_gen_SSPA_M64N8_T100.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/state_dict/channel_gen_SSPA_M64N8_T100.pt
--------------------------------------------------------------------------------
/state_dict/decoder_AWGN_M16N7_T100DDIM20.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/state_dict/decoder_AWGN_M16N7_T100DDIM20.pt
--------------------------------------------------------------------------------
/state_dict/decoder_Rayleigh_M16N7_DDPM_T100DDIM20.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/state_dict/decoder_Rayleigh_M16N7_DDPM_T100DDIM20.pt
--------------------------------------------------------------------------------
/state_dict/decoder_SSPA_M64N8_T100DDPM1.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/state_dict/decoder_SSPA_M64N8_T100DDPM1.pt
--------------------------------------------------------------------------------
/state_dict/encoder_AWGN_M16N7_T100DDIM20.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/state_dict/encoder_AWGN_M16N7_T100DDIM20.pt
--------------------------------------------------------------------------------
/state_dict/encoder_Rayleigh_M16N7_DDPM_T100DDIM20.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/state_dict/encoder_Rayleigh_M16N7_DDPM_T100DDIM20.pt -------------------------------------------------------------------------------- /state_dict/encoder_SSPA_M64N8_T100DDPM1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muahkim/learning-E2E-channel-coding-with-diffusion-models/92c2219fc625c3553f02e7d5fda323bf9533f355/state_dict/encoder_SSPA_M64N8_T100DDPM1.pt --------------------------------------------------------------------------------