├── _ntsccp
    ├── net
    │   ├── __init__.py
    │   └── utils.py
    └── layer
    │   └── __init__.py
├── _pdjscc
    ├── __init__.py
    ├── net
    │   ├── __init__.py
    │   ├── loss
    │   │   ├── __init__.py
    │   │   └── distortion.py
    │   ├── normalization
    │   │   ├── __init__.py
    │   │   ├── channel.py
    │   │   └── GDN.py
    │   ├── attention.py
    │   ├── channel.py
    │   ├── encoder.py
    │   ├── decoder.py
    │   ├── discriminator.py
    │   └── network.py
    ├── data
    │   ├── __init__.py
    │   └── dataset.py
    ├── loss_utils
    │   ├── __init__.py
    │   ├── loss
    │   │   ├── __init__.py
    │   │   ├── vgg_loss.py
    │   │   ├── gan_loss.py
    │   │   └── distortion.py
    │   └── perceptual_similarity
    │   │   ├── __init__.py
    │   │   ├── weights
    │   │   └── v0.1
    │   │   │   ├── alex.pth
    │   │   │   ├── vgg.pth
    │   │   │   └── squeeze.pth
    │   │   ├── dists_loss
    │   │   ├── weights.pt
    │   │   ├── DISTS
    │   │   │   ├── weights.pt
    │   │   │   ├── main.py
    │   │   │   └── DISTS_pt.py
    │   │   ├── main.py
    │   │   └── DISTS_pt.py
    │   │   ├── base_model.py
    │   │   ├── perceptual_loss.py
    │   │   ├── pretrained_networks.py
    │   │   └── networks_basic.py
    ├── demo
    │   ├── 69037.png
    │   ├── 69133.png
    │   ├── 69367.png
    │   ├── 69887.png
    │   └── 69929.png
    ├── README.md
    ├── config.py
    ├── utils.py
    ├── test.py
    └── train.py
├── channel
    ├── __init__.py
    ├── Pilot_bit.pt
    └── channel.py
├── conditioning_method
    ├── __init__.py
    └── diffcom.py
├── _djscc
    └── ckpt
    │   └── put adjscc weights here
├── imgs
    ├── Fig_RDP.png
    ├── Fig_ce_free.png
    ├── Fig_framework.png
    └── Fig_generalization.png
├── website
    ├── public
    │   ├── cover.webp
    │   ├── favicon.ico
    │   ├── logo192.png
    │   ├── logo512.png
    │   ├── robots.txt
    │   ├── imgs
    │   │   ├── overview.png
    │   │   ├── Fig_generalization.png
    │   │   └── results
    │   │   │   ├── NTSCC
    │   │   │   └── SNR1
    │   │   │   │   ├── ori_0.png
    │   │   │   │   ├── ori_1.png
    │   │   │   │   ├── ori_2.png
    │   │   │   │   ├── vtm_0.png
    │   │   │   │   ├── vtm_1.png
    │   │   │   │   ├── vtm_2.png
    │   │   │   │   ├── input_0.png
    │   │   │   │   ├── input_1.png
    │   │   │   │   ├── input_2.png
    │   │   │   │   ├── recon_0.png
    │   │   │   │   ├── recon_1.png
    │   │   │   │   └── recon_2.png
    │   │   │   └── DeepJSCC
    │   │   │   └── SNR1
    │   │   │   ├── input_0.png
    │   │   │   ├── input_1.png
    │   │   │   ├── input_2.png
    │   │   │   ├── ori_0.png
    │   │   │   ├── ori_1.png
    │   │   │   ├── ori_2.png
    │   │   │   ├── recon_0.png
    │   │   │   ├── recon_1.png
    │   │   │   ├── recon_2.png
    │   │   │   ├── vtm_0.png
    │   │   │   ├── vtm_1.png
    │   │   │   └── vtm_2.png
    │   ├── manifest.json
    │   └── index.html
    ├── README.md
    ├── src
    │   ├── setupTests.js
    │   ├── App.test.js
    │   ├── reportWebVitals.js
    │   ├── index.css
    │   ├── index.js
    │   ├── App.js
    │   ├── components
    │   │   ├── section2
    │   │   │   └── Section2.js
    │   │   ├── footer
    │   │   │   └── Footer.js
    │   │   ├── header
    │   │   │   └── Header.js
    │   │   ├── section1
    │   │   │   └── Section1.js
    │   │   ├── section3
    │   │   │   └── Section3.js
    │   │   └── section4
    │   │   │   └── Section4.js
    │   ├── logo.svg
    │   └── App.css
    ├── Dockerfile
    ├── .gitignore
    ├── convert_webp.py
    └── package.json
├── model_zoo
    └── README.md
├── data
    └── datasets.py
├── utils
    └── utils_logger.py
├── configs
    └── diffcom.yaml
├── guided_diffusion
    ├── losses.py
    ├── noise_schedule.py
    ├── respace.py
    ├── nn.py
    └── fp16_util.py
└── README.md
--------------------------------------------------------------------------------
/_ntsccp/net/__init__.py:
--------------------------------------------------------------------------------
 1 |
--------------------------------------------------------------------------------
/_pdjscc/__init__.py:
--------------------------------------------------------------------------------
 1 |
--------------------------------------------------------------------------------
/_pdjscc/net/__init__.py:
--------------------------------------------------------------------------------
 1 |
--------------------------------------------------------------------------------
/channel/__init__.py:
--------------------------------------------------------------------------------
 1 |
-------------------------------------------------------------------------------- /_ntsccp/layer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_pdjscc/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_pdjscc/net/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /conditioning_method/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_djscc/ckpt/put adjscc weights here: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_pdjscc/net/normalization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/Fig_RDP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/imgs/Fig_RDP.png -------------------------------------------------------------------------------- /channel/Pilot_bit.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/channel/Pilot_bit.pt -------------------------------------------------------------------------------- /imgs/Fig_ce_free.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/imgs/Fig_ce_free.png -------------------------------------------------------------------------------- /_pdjscc/demo/69037.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/_pdjscc/demo/69037.png -------------------------------------------------------------------------------- /_pdjscc/demo/69133.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/_pdjscc/demo/69133.png -------------------------------------------------------------------------------- /_pdjscc/demo/69367.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/_pdjscc/demo/69367.png -------------------------------------------------------------------------------- /_pdjscc/demo/69887.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/_pdjscc/demo/69887.png -------------------------------------------------------------------------------- /_pdjscc/demo/69929.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/_pdjscc/demo/69929.png -------------------------------------------------------------------------------- /imgs/Fig_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/imgs/Fig_framework.png -------------------------------------------------------------------------------- /website/public/cover.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/cover.webp -------------------------------------------------------------------------------- /imgs/Fig_generalization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/imgs/Fig_generalization.png -------------------------------------------------------------------------------- /website/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/favicon.ico -------------------------------------------------------------------------------- /website/public/logo192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/logo192.png -------------------------------------------------------------------------------- /website/public/logo512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/logo512.png -------------------------------------------------------------------------------- /website/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /website/public/imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/overview.png -------------------------------------------------------------------------------- /website/public/imgs/Fig_generalization.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/Fig_generalization.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/ori_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/ori_0.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/ori_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/ori_1.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/ori_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/ori_2.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/vtm_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/vtm_0.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/vtm_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/vtm_1.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/vtm_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/vtm_2.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/input_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/input_0.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/input_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/input_1.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/input_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/input_2.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/recon_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/recon_0.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/recon_1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/recon_1.png -------------------------------------------------------------------------------- /website/public/imgs/results/NTSCC/SNR1/recon_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/NTSCC/SNR1/recon_2.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/input_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/input_0.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/input_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/input_1.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/input_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/input_2.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/ori_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/ori_0.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/ori_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/ori_1.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/ori_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/ori_2.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/recon_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/recon_0.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/recon_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/recon_1.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/recon_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/recon_2.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/vtm_0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/vtm_0.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/vtm_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/vtm_1.png -------------------------------------------------------------------------------- /website/public/imgs/results/DeepJSCC/SNR1/vtm_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/website/public/imgs/results/DeepJSCC/SNR1/vtm_2.png -------------------------------------------------------------------------------- /website/README.md: -------------------------------------------------------------------------------- 1 | # Getting Started with Create React App 2 | 3 | This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app). 4 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/dists_loss/weights.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/_pdjscc/loss_utils/perceptual_similarity/dists_loss/weights.pt -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/dists_loss/DISTS/weights.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/diffcom/HEAD/_pdjscc/loss_utils/perceptual_similarity/dists_loss/DISTS/weights.pt -------------------------------------------------------------------------------- /website/src/setupTests.js: -------------------------------------------------------------------------------- 1 | // jest-dom adds custom jest matchers for asserting on DOM nodes. 
2 | // allows you to do things like: 3 | // expect(element).toHaveTextContent(/react/i) 4 | // learn more: https://github.com/testing-library/jest-dom 5 | import '@testing-library/jest-dom'; 6 | -------------------------------------------------------------------------------- /website/src/App.test.js: -------------------------------------------------------------------------------- 1 | import { render, screen } from '@testing-library/react'; 2 | import App from './App'; 3 | 4 | test('renders learn react link', () => { 5 | render(); 6 | const linkElement = screen.getByText(/learn react/i); 7 | expect(linkElement).toBeInTheDocument(); 8 | }); 9 | -------------------------------------------------------------------------------- /model_zoo/README.md: -------------------------------------------------------------------------------- 1 | |Model|Download link| 2 | |---|:--:| 3 | |256x256_diffusion_uncond.pt(ILSVRC 2012 subset of ImageNet)| [download link](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt) | 4 | |ffhq_10m.pt| [download link](https://drive.google.com/drive/folders/1jElnRoFv7b31fG0v6pTSQkelbSX3xGZh?usp=sharing) | 5 | 6 | -------------------------------------------------------------------------------- /website/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM node 2 | WORKDIR /usr/src/app 3 | 4 | # git configuration 5 | ARG SSH_PRIVATE_KEY 6 | RUN mkdir ~/.ssh/ 7 | RUN echo "$SSH_PRIVATE_KEY" >> ~/.ssh/id_rsa && chmod 600 ~/.ssh/id_rsa 8 | 9 | # npm package install 10 | COPY package*.json ./ 11 | RUN npm install --silent 12 | COPY . . 13 | CMD ["npm", "start"] 14 | EXPOSE 3000 15 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/dists_loss/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, sys 3 | import torch 4 | from torchvision import models, transforms 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | print(np.hanning(5)) 9 | a = np.hanning(5)[1:-1] 10 | print(a) 11 | g = torch.Tensor(a[:, None] * a[None, :]) 12 | print(g) 13 | g = g / torch.sum(g) 14 | 15 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/dists_loss/DISTS/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, sys 3 | import torch 4 | from torchvision import models, transforms 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | print(np.hanning(5)) 9 | a = np.hanning(5)[1:-1] 10 | print(a) 11 | g = torch.Tensor(a[:, None] * a[None, :]) 12 | print(g) 13 | g = g / torch.sum(g) 14 | 15 | -------------------------------------------------------------------------------- /website/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | 14 | # misc 15 | .DS_Store 16 | .env.local 17 | .env.development.local 18 | .env.test.local 19 | .env.production.local 20 | 21 | npm-debug.log* 22 | yarn-debug.log* 23 | yarn-error.log* 24 | -------------------------------------------------------------------------------- /website/src/reportWebVitals.js: -------------------------------------------------------------------------------- 1 | const reportWebVitals = onPerfEntry => { 2 | if (onPerfEntry && onPerfEntry instanceof Function) { 3 | import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => { 4 | getCLS(onPerfEntry); 5 | getFID(onPerfEntry); 6 | getFCP(onPerfEntry); 7 | getLCP(onPerfEntry); 8 | getTTFB(onPerfEntry); 9 | }); 10 | } 11 | }; 12 | 13 | export default reportWebVitals; 14 | -------------------------------------------------------------------------------- /website/convert_webp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from PIL import Image 4 | import shutil 5 | 6 | def convert_to_webp(file_path: Path): 7 | img = Image.open(file_path).convert('RGB') 8 | new_path = str(file_path).replace(file_path.suffix, '.webp') 9 | img.save(new_path, 'webp') 10 | 11 | all_files = Path('public/imgs').glob('**/*') 12 | all_imgs = [x for x in all_files if str(x).endswith(".png") or str(x).endswith(".jpg")] 13 | 14 | for img_path in all_imgs: 15 | convert_to_webp(img_path) 16 | os.remove(img_path) 17 | -------------------------------------------------------------------------------- /website/src/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 4 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', 5 | sans-serif; 6 | -webkit-font-smoothing: antialiased; 7 | -moz-osx-font-smoothing: grayscale; 8 | } 9 | 10 | code { 11 | font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', 12 | monospace; 13 | } 14 | 15 | .ddiffcom { 16 | font-style: italic; 17 | font-weight: bold; 18 | } 19 | 20 | .bold { 21 | /* bold text */ 22 | font-weight: bold; 23 | } -------------------------------------------------------------------------------- /website/src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom/client'; 3 | import './index.css'; 4 | import App from './App'; 5 | import reportWebVitals from './reportWebVitals'; 6 | 7 | const root = ReactDOM.createRoot(document.getElementById('root')); 8 | root.render( 9 | 10 | 11 | 12 | ); 13 | 14 | // If you want to start measuring performance in your app, pass a function 15 | // to log results (for example: reportWebVitals(console.log)) 16 | // or send to an analytics endpoint. 
Learn more: https://bit.ly/CRA-vitals 17 | reportWebVitals(); 18 | -------------------------------------------------------------------------------- /website/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "React App", 3 | "name": "Create React App Sample", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | }, 10 | { 11 | "src": "logo192.png", 12 | "type": "image/png", 13 | "sizes": "192x192" 14 | }, 15 | { 16 | "src": "logo512.png", 17 | "type": "image/png", 18 | "sizes": "512x512" 19 | } 20 | ], 21 | "start_url": ".", 22 | "display": "standalone", 23 | "theme_color": "#000000", 24 | "background_color": "#ffffff" 25 | } 26 | -------------------------------------------------------------------------------- /_pdjscc/net/loss/distortion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MSE(torch.nn.Module): 5 | def __init__(self): 6 | super(MSE, self).__init__() 7 | self.squared_difference = torch.nn.MSELoss(reduction='none') 8 | 9 | def forward(self, X, Y): 10 | return torch.mean(self.squared_difference(X * 255., Y * 255.)) # / 255. 11 | 12 | 13 | class Distortion(torch.nn.Module): 14 | def __init__(self, config): 15 | super(Distortion, self).__init__() 16 | if config.distortion_metric == 'MSE': 17 | self.dist = MSE() 18 | else: 19 | print("Unknown distortion type!") 20 | raise ValueError 21 | 22 | def forward(self, X, Y): 23 | return self.dist.forward(X, Y) # / 255. 24 | -------------------------------------------------------------------------------- /_pdjscc/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Perceptual Learned Source-Channel Coding for High-Fidelity Image Semantic Transmission 3 | This is the repository of the 4 | paper "[Perceptual Learned Source-Channel Coding for High-Fidelity Image Semantic Transmission](https://arxiv.org/abs/2205.13120)". 
5 | 6 | ## Pretrained weights (AWGN 10dB) 7 | LINK: [Google Drive](https://drive.google.com/drive/folders/1Bgaons0M8kaz8yYyGgmiDi-9b_3mEgIj?usp=sharing) 8 | 9 | ## Citation 10 | 11 | If you find this code useful for your research, please cite our paper 12 | 13 | ``` 14 | @inproceedings{wang2022perceptual, 15 | title={Perceptual learned source-channel coding for high-fidelity image semantic transmission}, 16 | author={Wang, Jun and Wang, Sixian and Dai, Jincheng and Si, Zhongwei and Zhou, Dekun and Niu, Kai}, 17 | booktitle={GLOBECOM 2022-2022 IEEE Global Communications Conference}, 18 | pages={3959--3964}, 19 | year={2022}, 20 | organization={IEEE} 21 | } 22 | ``` -------------------------------------------------------------------------------- /_pdjscc/net/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import torch.nn.functional as F 6 | from timm.models.layers import trunc_normal_ 7 | 8 | 9 | class SNRAttention(nn.Module): 10 | 11 | def __init__(self, C): 12 | super(SNRAttention, self).__init__() 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | 15 | self.fc1 = nn.Linear(C + 1, (C + 1) // 16) 16 | self.relu1 = nn.ReLU() 17 | self.fc2 = nn.Linear((C + 1) // 16, C) 18 | self.sigmoid = nn.Sigmoid() 19 | 20 | def forward(self, x, SNR): 21 | feature_pooling = self.avg_pool(x) 22 | [b, c, _, _] = feature_pooling.shape 23 | context_information = torch.cat((SNR, feature_pooling.reshape(b, c)), 1) 24 | scale_factor = self.sigmoid(self.fc2(self.relu1(self.fc1(context_information)))) 25 | out = torch.mul(x, scale_factor.unsqueeze(2).unsqueeze(3)) 26 | return out 27 | -------------------------------------------------------------------------------- /website/src/App.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import './App.css'; 3 | 4 | import Header from "./components/header/Header"; 5 | import Section1 from "./components/section1/Section1"; 6 | import Section2 from "./components/section2/Section2"; 7 | import Section3 from "./components/section3/Section3"; 8 | import Section4 from "./components/section4/Section4"; 9 | // import Section5 from "./components/section5/Section5"; 10 | import Footer from "./components/footer/Footer"; 11 | 12 | import { MathJaxContext } from 'better-react-mathjax'; 13 | 14 | const config = { 15 | loader: { load: ["[tex]/html"]}, 16 | tex: { 17 | packages: { "[+]": ["html"] }, 18 | inlineMath: [ 19 | ["$", "$"], 20 | ["\\(", "\\)"] 21 | ], 22 | displayMath: [ 23 | ["$$", "$$"], 24 | ["\\[", "\\]"] 25 | ] 26 | } 27 | }; 28 | 29 | function App() { 30 | return ( 31 | 32 |
33 |
34 | 35 | 36 | 37 | 38 | {/**/} 39 |
40 |
41 |
42 | ); 43 | } 44 | 45 | export default App; 46 | -------------------------------------------------------------------------------- /website/src/components/section2/Section2.js: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | 3 | const CenterWrapper = (props) => { 4 | return ( 5 |
6 |
7 |
8 |
9 | {props.content} 10 |
11 |
12 |
13 |
14 | ); 15 | } 16 | 17 | const Content = () => { 18 | return ( 19 |
20 |

DiffCom is Robust to Unexpected Transmission Degradations

21 | {"loading.."}/ 25 |

A visual comparison illustrating the impact of several unexpected 26 | transmission degradations: ① unseen channel fading, ② PAPR reduction, ③ 27 | with ISI (removed CP symbols), and ④ very low CSNR (0dB).

28 |
29 | ) 30 | } 31 | 32 | const Section2 = () => { 33 | return ( 34 | }/> 35 | ); 36 | } 37 | 38 | export default Section2 39 | -------------------------------------------------------------------------------- /_pdjscc/config.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | 4 | class config: 5 | CUDA = True 6 | gpu_list = '2' 7 | logger = False 8 | pass_channel = True 9 | channel = {"type": 'awgn', 'chan_param': 10} 10 | 11 | img_size = (3, 256, 256) 12 | norm = normalize = False 13 | # dataset_dir = '/media/D/Dataset/CIFAR10' 14 | filename = datetime.now().__str__()[:-7] 15 | workdir = './history/{}'.format(filename) 16 | log = workdir + '/Log_{}.log'.format(filename) 17 | samples = workdir + '/samples' 18 | models = workdir + '/models' 19 | train_data_dir = ['path of train dataset'] 20 | test_data_dir = ["path of test dataset"] 21 | 22 | C = 2 23 | kdivn = 2/96 24 | 25 | K_P = 1.0 26 | K_M = 0.01 27 | K_S = 0.0 28 | beta = 5.0 29 | 30 | image_dims = (3, 256, 256) 31 | use_discriminator = False 32 | gan_loss_type = 'non_saturating' 33 | discriminator_steps = 1 34 | generator_steps = 1 35 | dis_acc = 1.0 36 | # Parameters Setting for Training 37 | epochs = 1000 38 | batch_size = 8 39 | test_batch_size = 1 40 | print_step = 50 41 | test_step = 1000 42 | g_learning_rate = 1e-4 43 | d_learning_rate = 1e-4 44 | distortion_metric = 'MSE' 45 | multiple_snr = [1] 46 | save_epoch = 1 47 | predict = 'The path of the pretrained model' -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | from torch.utils.data.dataset import Dataset 9 | from torchvision import transforms 10 | 11 | 12 | class Datasets(Dataset): 13 | def __init__(self, dataset_path): 14 | self.data_dir = dataset_path 15 | self.imgs = [] 16 | self.imgs += glob(os.path.join(self.data_dir, '*.jpg')) 17 | self.imgs += glob(os.path.join(self.data_dir, '*.png')) 18 | self.imgs.sort() 19 | if len(self.imgs) == 0: 20 | raise ValueError(f"Dataset path {self.data_dir} is empty") 21 | self.transform = transforms.Compose([ 22 | transforms.ToTensor()]) 23 | 24 | def __getitem__(self, item): 25 | image_path = self.imgs[item] 26 | name = os.path.basename(image_path) 27 | image = Image.open(image_path).convert('RGB') 28 | img = self.transform(image) 29 | return img, name 30 | 31 | def __len__(self): 32 | return len(self.imgs) 33 | 34 | 35 | def get_test_loader(test_dir, batch_size=1, shuffle=False): 36 | test_dataset = Datasets(test_dir) 37 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 38 | batch_size=batch_size, 39 | shuffle=shuffle) 40 | return test_loader 41 | -------------------------------------------------------------------------------- /website/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "diffcom-app", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "@emotion/react": "^11.10.6", 7 | "@emotion/styled": "^11.10.6", 8 | "@mui/material": "^5.11.11", 9 | "@testing-library/jest-dom": "^5.16.5", 10 | "@testing-library/react": "^13.4.0", 11 | "@testing-library/user-event": "^13.5.0", 12 | "apexcharts": "^3.37.1", 13 | "better-react-mathjax": "^2.0.2", 14 | "lodash": "^4.17.21", 15 | "react": 
"^18.2.0", 16 | "react-apexcharts": "^1.4.0", 17 | "react-compare-slider": "^2.2.0", 18 | "react-dom": "^18.2.0", 19 | "react-icons": "^4.7.1", 20 | "react-responsive": "^9.0.2", 21 | "react-responsive-carousel": "^3.2.23", 22 | "react-scripts": "5.0.1", 23 | "react-swipe": "^6.0.4", 24 | "web-vitals": "^2.1.4" 25 | }, 26 | "scripts": { 27 | "start": "react-scripts start", 28 | "build": "react-scripts build", 29 | "test": "react-scripts test", 30 | "eject": "react-scripts eject", 31 | "predeploy": "npm run build", 32 | "deploy": "gh-pages -d build" 33 | }, 34 | "eslintConfig": { 35 | "extends": [ 36 | "react-app", 37 | "react-app/jest" 38 | ] 39 | }, 40 | "browserslist": { 41 | "production": [ 42 | ">0.2%", 43 | "not dead", 44 | "not op_mini all" 45 | ], 46 | "development": [ 47 | "last 1 chrome version", 48 | "last 1 firefox version", 49 | "last 1 safari version" 50 | ] 51 | }, 52 | "devDependencies": { 53 | "gh-pages": "^4.0.0" 54 | }, 55 | "homepage": "https://semcomm.github.io/DiffCom" 56 | } 57 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | 5 | class BaseModel(): 6 | def __init__(self): 7 | pass 8 | 9 | def name(self): 10 | return 'BaseModel' 11 | 12 | def initialize(self, use_gpu=True, gpu_ids=[0]): 13 | self.use_gpu = use_gpu 14 | self.gpu_ids = gpu_ids 15 | 16 | def forward(self): 17 | pass 18 | 19 | def get_image_paths(self): 20 | pass 21 | 22 | def optimize_parameters(self): 23 | pass 24 | 25 | def get_current_visuals(self): 26 | return self.input 27 | 28 | def get_current_errors(self): 29 | return {} 30 | 31 | def save(self, label): 32 | pass 33 | 34 | # helper saving function that can be used by subclasses 35 | def save_network(self, network, path, network_label, epoch_label): 36 | save_filename = f'{epoch_label}_net_{network_label}' 37 | save_path = os.path.join(path, save_filename) 38 | torch.save(network.state_dict(), save_path) 39 | 40 | # helper loading function that can be used by subclasses 41 | def load_network(self, network, network_label, epoch_label): 42 | save_filename = f'{epoch_label}_net_{network_label}' 43 | save_path = os.path.join(self.save_dir, save_filename) 44 | print(f'Loading network from {save_path}') 45 | network.load_state_dict(torch.load(save_path)) 46 | 47 | def get_image_paths(self): 48 | return self.image_paths 49 | 50 | def save_done(self, flag=False): 51 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 52 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 53 | 54 | -------------------------------------------------------------------------------- /utils/utils_logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import datetime 4 | import logging 5 | 6 | 7 | ''' 8 | modified by Kai Zhang (github: https://github.com/cszn) 9 | 03/03/2019 10 | https://github.com/xinntao/BasicSR 11 | ''' 12 | 13 | 14 | def log(*args, **kwargs): 15 | print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) 16 | 17 | 18 | ''' 19 | # =============================== 20 | # logger 21 | # logger_name = None = 'base' ??? 
22 | # =============================== 23 | ''' 24 | 25 | 26 | def logger_info(logger_name, log_path='default_logger.log'): 27 | ''' set up logger 28 | modified by Kai Zhang (github: https://github.com/cszn) 29 | ''' 30 | log = logging.getLogger(logger_name) 31 | if log.hasHandlers(): 32 | print('LogHandlers exists!') 33 | else: 34 | print('LogHandlers setup!') 35 | level = logging.INFO 36 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S') 37 | fh = logging.FileHandler(log_path, mode='a') 38 | fh.setFormatter(formatter) 39 | log.setLevel(level) 40 | log.addHandler(fh) 41 | # print(len(log.handlers)) 42 | 43 | sh = logging.StreamHandler() 44 | sh.setFormatter(formatter) 45 | log.addHandler(sh) 46 | 47 | 48 | ''' 49 | # =============================== 50 | # print to file and std_out simultaneously 51 | # =============================== 52 | ''' 53 | 54 | 55 | class logger_print(object): 56 | def __init__(self, log_path="default.log"): 57 | self.terminal = sys.stdout 58 | self.log = open(log_path, 'a') 59 | 60 | def write(self, message): 61 | self.terminal.write(message) 62 | self.log.write(message) # write the message 63 | 64 | def flush(self): 65 | pass 66 | -------------------------------------------------------------------------------- /website/src/components/footer/Footer.js: -------------------------------------------------------------------------------- 1 | import React, {Component} from "react"; 2 | import { IconButton } from "@mui/material"; 3 | import { VscGithub } from "react-icons/vsc" 4 | import {FaFilePdf} from "react-icons/fa" 5 | 6 | const LinkButton = (props) => ( 7 | 8 | {props.icon} 9 | 10 | ); 11 | 12 | export default class Footer extends Component{ 13 | render(){ 14 | return ( 15 |
16 |
17 |
18 | } text="Paper"/> 19 | } text="Code"/> 20 |
21 |
22 |
23 |
24 |

25 | This website is licensed under a Creative 27 | Commons Attribution-ShareAlike 4.0 International License. 28 |

29 |

30 | This means you are free to borrow the source code of this website, 32 | we just ask that you link back to this page in the footer. 33 | Please remember to remove the analytics code included in the header of the website which 34 | you do not want on your website. 35 |

36 |
37 |
38 |
39 |
40 |
41 | ); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/loss/vgg_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import models 3 | import numpy as np 4 | 5 | 6 | class Vgg19(torch.nn.Module): 7 | """ 8 | Vgg19 network for perceptual loss. See Sec 3.3. 9 | """ 10 | def __init__(self, requires_grad=False): 11 | super(Vgg19, self).__init__() 12 | vgg_pretrained_features = models.vgg19(pretrained=True).features 13 | self.slice1 = torch.nn.Sequential() 14 | self.slice2 = torch.nn.Sequential() 15 | self.slice3 = torch.nn.Sequential() 16 | self.slice4 = torch.nn.Sequential() 17 | self.slice5 = torch.nn.Sequential() 18 | for x in range(2): 19 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 20 | for x in range(2, 7): 21 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 22 | for x in range(7, 12): 23 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(12, 21): 25 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 26 | for x in range(21, 30): 27 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 28 | 29 | self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), 30 | requires_grad=False) 31 | self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), 32 | requires_grad=False) 33 | 34 | if not requires_grad: 35 | for param in self.parameters(): 36 | param.requires_grad = False 37 | 38 | def forward(self, X): 39 | X = (X - self.mean) / self.std 40 | h_relu1 = self.slice1(X) 41 | h_relu2 = self.slice2(h_relu1) 42 | h_relu3 = self.slice3(h_relu2) 43 | h_relu4 = self.slice4(h_relu3) 44 | h_relu5 = self.slice5(h_relu4) 45 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 46 | return out -------------------------------------------------------------------------------- /website/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 12 | 13 | 17 | 18 | 27 | DiffCom: Channel Received Signal is a Natural Condition to Guide Diffusion Posterior Sampling 28 | 29 | 30 | 31 | 32 | 33 | 34 |
35 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /_pdjscc/net/normalization/channel.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.nn import Parameter 7 | 8 | def InstanceNorm2D_wrap(input_channels, momentum=0.1, affine=True, 9 | track_running_stats=False, **kwargs): 10 | """ 11 | Wrapper around default Torch instancenorm 12 | """ 13 | instance_norm_layer = nn.InstanceNorm2d(input_channels, 14 | momentum=momentum, affine=affine, 15 | track_running_stats=track_running_stats) 16 | return instance_norm_layer 17 | 18 | def ChannelNorm2D_wrap(input_channels, momentum=0.1, affine=True, 19 | track_running_stats=False, **kwargs): 20 | """ 21 | Wrapper around Channel Norm module 22 | """ 23 | channel_norm_layer = ChannelNorm2D(input_channels, 24 | momentum=momentum, affine=affine, 25 | track_running_stats=track_running_stats) 26 | 27 | return channel_norm_layer 28 | 29 | class ChannelNorm2D(nn.Module): 30 | """ 31 | Similar to default Torch instanceNorm2D but calculates 32 | moments over channel dimension instead of spatial dims. 33 | Expects input_dim in format (B,C,H,W) 34 | """ 35 | 36 | def __init__(self, input_channels, momentum=0.1, eps=1e-3, 37 | affine=True, **kwargs): 38 | super(ChannelNorm2D, self).__init__() 39 | 40 | self.momentum = momentum 41 | self.eps = eps 42 | self.affine = affine 43 | 44 | if affine is True: 45 | self.gamma = nn.Parameter(torch.ones(1, input_channels, 1, 1)) 46 | self.beta = nn.Parameter(torch.zeros(1, input_channels, 1, 1)) 47 | 48 | def forward(self, x): 49 | """ 50 | Calculate moments over channel dim, normalize. 51 | x: Image tensor, shape (B,C,H,W) 52 | """ 53 | mu, var = torch.mean(x, dim=1, keepdim=True), torch.var(x, dim=1, keepdim=True) 54 | 55 | x_normed = (x - mu) * torch.rsqrt(var + self.eps) 56 | 57 | if self.affine is True: 58 | x_normed = self.gamma * x_normed + self.beta 59 | return x_normed 60 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/loss/gan_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | def _non_saturating_loss(D_real_logits, D_gen_logits, D_real=None, D_gen=None): 8 | D_loss_real = F.binary_cross_entropy_with_logits(input=D_real_logits, 9 | target=torch.ones_like(D_real_logits)) 10 | D_loss_gen = F.binary_cross_entropy_with_logits(input=D_gen_logits, 11 | target=torch.zeros_like(D_gen_logits)) 12 | D_loss = D_loss_real + D_loss_gen 13 | 14 | G_loss = F.binary_cross_entropy_with_logits(input=D_gen_logits, 15 | target=torch.ones_like(D_gen_logits)) 16 | 17 | return D_loss, G_loss 18 | 19 | 20 | def _least_squares_loss(D_real, D_gen, D_real_logits=None, D_gen_logits=None): 21 | D_loss_real = torch.mean(torch.square(D_real - 1.0)) 22 | D_loss_gen = torch.mean(torch.square(D_gen)) 23 | D_loss = 0.5 * (D_loss_real + D_loss_gen) 24 | 25 | G_loss = 0.5 * torch.mean(torch.square(D_gen - 1.0)) 26 | 27 | return D_loss, G_loss 28 | 29 | def relativistic_least_squares_loss(D_real, D_gen, D_real_logits=None, D_gen_logits=None): 30 | Relativisitc_loss_real = D_real - D_gen.mean(0, keepdim=True) 31 | Relativisitc_loss_gen = D_gen - D_real.mean(0, keepdim=True) 32 | 33 | D_loss = torch.mean(torch.square(Relativisitc_loss_real - 1.0)) + 
torch.mean(torch.square(Relativisitc_loss_gen + 1.0)) 34 | 35 | G_loss = torch.mean(torch.square(Relativisitc_loss_real + 1.0)) + torch.mean(torch.square(Relativisitc_loss_gen - 1.0)) 36 | 37 | return D_loss, G_loss 38 | 39 | def gan_loss(gan_loss_type, disc_out, mode='generator_loss'): 40 | if gan_loss_type == 'non_saturating': 41 | loss_fn = _non_saturating_loss 42 | elif gan_loss_type == 'least_squares': 43 | loss_fn = _least_squares_loss 44 | elif gan_loss_type == 'relative_least_squares': 45 | loss_fn = relativistic_least_squares_loss 46 | else: 47 | raise ValueError('Invalid GAN loss') 48 | 49 | D_loss, G_loss = loss_fn(D_real=disc_out.D_real, D_gen=disc_out.D_gen, 50 | D_real_logits=disc_out.D_real_logits, D_gen_logits=disc_out.D_gen_logits) 51 | 52 | loss = G_loss if mode == 'generator_loss' else D_loss 53 | 54 | return loss 55 | -------------------------------------------------------------------------------- /_pdjscc/net/channel.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | import os 4 | import torch 5 | 6 | 7 | class Channel(nn.Module): 8 | def __init__(self, config): 9 | super(Channel, self).__init__() 10 | self.config = config 11 | self.chan_type = config.channel['type'] 12 | self.chan_param = config.channel['chan_param'] # SNR 13 | if config.logger: 14 | config.logger.info('【Channel】: Built {} channel, SNR {} dB.'.format( 15 | config.channel['type'], config.channel['chan_param'])) 16 | 17 | def gaussian_noise_layer(self, input_layer, std): 18 | noise_real = torch.normal(mean=0.0, std=std, size=np.shape(input_layer)) 19 | noise_imag = torch.normal(mean=0.0, std=std, size=np.shape(input_layer)) 20 | noise = noise_real + 1j * noise_imag 21 | if self.config.CUDA: 22 | noise = noise.to(input_layer.get_device()) 23 | return input_layer + noise 24 | 25 | def complex_normalize(self, x, power): 26 | pwr = torch.mean(x ** 2) * 2 27 | out = np.sqrt(power) * x / torch.sqrt(pwr) 28 | return out, pwr 29 | 30 | def complex_forward(self, channel_in): 31 | if self.chan_type == 0 or self.chan_type == 'none': 32 | return channel_in 33 | 34 | elif self.chan_type == 1 or self.chan_type == 'awgn': 35 | # power normalization 36 | channel_tx = channel_in 37 | sigma = np.sqrt(1.0 / (2 * 10 ** (self.chan_param / 10))) 38 | chan_output = self.gaussian_noise_layer(channel_tx, 39 | std=sigma) 40 | return chan_output 41 | 42 | def forward(self, input): 43 | # input \in R 44 | channel_tx, pwr = self.complex_normalize(input, power=1) 45 | input_shape = channel_tx.shape 46 | channel_in = channel_tx.reshape(-1) 47 | L = channel_in.shape[0] 48 | channel_in = channel_in[:L // 2] + channel_in[L // 2:] * 1j 49 | channel_output = self.complex_forward(channel_in) 50 | channel_output = torch.cat([torch.real(channel_output), torch.imag(channel_output)]) 51 | channel_output = channel_output.reshape(input_shape) 52 | 53 | noise = (channel_output - channel_tx).detach() 54 | noise.requires_grad = False 55 | channel_rx = channel_tx + noise 56 | return channel_rx 57 | -------------------------------------------------------------------------------- /website/src/logo.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_pdjscc/net/encoder.py: -------------------------------------------------------------------------------- 1 | from .normalization.GDN import * 2 | from .attention import SNRAttention 3 | import torch 4 | 
import torch.nn as nn 5 | 6 | 7 | class Encoder(nn.Module): 8 | def __init__(self, config): 9 | super(Encoder, self).__init__() 10 | self.C = config.C 11 | self.img_size = config.img_size 12 | activation = 'prelu' 13 | device = torch.device('cuda') 14 | activation_d = dict(relu='ReLU', elu='ELU', leaky_relu='LeakyReLU', prelu='PReLU') 15 | self.n_downsampling_layers = 2 16 | self.activation = getattr(nn, activation_d[activation]) # (leaky_relu, relu, elu, prelu) 17 | 18 | self.attention1 = SNRAttention(256) 19 | self.attention2 = SNRAttention(256) 20 | self.attention3 = SNRAttention(256) 21 | self.attention4 = SNRAttention(256) 22 | # self.attention5 = SNRAttention(self.C) 23 | 24 | # (3,32,32) -> (256,16,16), with implicit padding 25 | self.conv_block1 = nn.Sequential( 26 | nn.Conv2d(3, 256, kernel_size=(9, 9), stride=2, padding=(9 - 2) // 2 + 1), 27 | GDN(256, device, False), 28 | self.activation(), 29 | ) 30 | 31 | # (256,16,16) -> (256,8,8) 32 | self.conv_block2 = nn.Sequential( 33 | nn.Conv2d(256, 256, kernel_size=(5, 5), stride=2, padding=(5 - 2) // 2 + 1), 34 | GDN(256, device, False), 35 | self.activation(), 36 | ) 37 | 38 | # (256,8,8) -> (256,8,8) 39 | self.conv_block3 = nn.Sequential( 40 | nn.Conv2d(256, 256, kernel_size=(5, 5), stride=1, padding=(5 - 1) // 2), 41 | GDN(256, device, False), 42 | self.activation(), 43 | ) 44 | 45 | # (256,8,8) -> (256,8,8) 46 | self.conv_block4 = nn.Sequential( 47 | nn.Conv2d(256, 256, kernel_size=(5, 5), stride=1, padding=(5 - 1) // 2), 48 | GDN(256, device, False), 49 | self.activation(), 50 | ) 51 | 52 | # (256,8,8) -> (tcn,8,8) 53 | self.conv_block5 = nn.Sequential( 54 | nn.Conv2d(256, self.C, kernel_size=(5, 5), stride=1, padding=(5 - 1) // 2), 55 | GDN(self.C, device, False), 56 | ) 57 | 58 | def forward(self, x, SNR): 59 | x = self.conv_block1(x) 60 | x = self.attention1(x, SNR) 61 | x = self.conv_block2(x) 62 | x = self.attention2(x, SNR) 63 | x = self.conv_block3(x) 64 | x = self.attention3(x, SNR) 65 | x = self.conv_block4(x) 66 | x = self.attention4(x, SNR) 67 | out = self.conv_block5(x) 68 | # out = self.attention5(x, SNR) 69 | return out 70 | -------------------------------------------------------------------------------- /configs/diffcom.yaml: -------------------------------------------------------------------------------- 1 | conditioning_method: 'hifi_diffcom' # 'diffcom', 'hifi_diffcom', 'blind_diffcom' 2 | 3 | # config forward operator 4 | operator_name: 'djscc' # 'ntscc', 'swinjscc' 5 | 6 | djscc: 7 | # The djscc is implemented with SNR attention module, trained under AWGN channel with CSNR in [0, 14] dB. 8 | channel_num: 2 9 | jscc_model_path: '_djscc/ckpt/ADJSCC_C=2.pth.tar' 10 | # channel_num: 4 11 | # jscc_model_path: '_djscc/ckpt/ADJSCC_C=4.pth.tar' 12 | # channel_num: 6 13 | # jscc_model_path: '_djscc/ckpt/ADJSCC_C=6.pth.tar' 14 | 15 | #ntscc: 16 | # compatible: True 17 | # eta: 0.2 # 0.15~0.3, coarse adjustment, larger for higher bandwidth consumption 18 | # q_level: 3 # 0-99, fine adjustment, larger for higher bandwidth consumption 19 | 20 | #channel_type: 'awgn' # 'ofdm_tdl', 'awgn' 21 | channel_type: 'awgn' 22 | CSNR: 10 23 | ofdm_tdl: 24 | P: 1 25 | S: 16 26 | K: 12 27 | L: 8 28 | decay: 4 29 | N_pilot: 1 30 | is_clip: False 31 | clip_ratio: 0.8 32 | refine_channel_est: False 33 | blind: True # whether to use channel estimation algorithm, our blind_diffcom supports transmission without channel estimation. 
34 | channel_est: LMMSE 35 | equalization: MMSE 36 | 37 | # config diffusion model and testset 38 | 39 | # for imagenet dataset 40 | #model_name: 256x256_diffusion_uncond 41 | #testset_name: imagenet 42 | 43 | # for ffhq dataset 44 | model_name: ffhq_10m # diffcom (djscc) GPU memory ≈ 5404MB 45 | testset_name: ffhq_demo 46 | #testset_name: ffhq_test 47 | 48 | 49 | # config hyperparameters for posterior sampling 50 | 51 | CSNR_adapt_t_start: True # whether to accelerate sampling with adaptive initialization, only works for diffcom and hifi_diffcom 52 | N: 1.0 # scaling factor to adjust the total rounds of reverse sampling steps. More rounds will lead to better performance but slower speed. 53 | 54 | diffcom_series: 55 | num_train_timesteps: 1000 56 | iter_num: 1000 57 | save_recon_every: 20 58 | 59 | diffcom: 60 | lr_schedule: constant 61 | learning_rate: 1.0 62 | zeta: 1.0 63 | gamma: 0.0 64 | 65 | hifi_diffcom: 66 | lr_schedule: constant 67 | lr_min: 0.3 68 | learning_rate: 1.0 69 | zeta: 0.25 70 | gamma: 0.25 71 | 72 | blind_diffcom: 73 | lr_schedule: constant 74 | learning_rate: 1.0 75 | zeta: 1.0 76 | gamma: 0.0 77 | h_lr: 0.2 78 | 79 | seed: 22 80 | gpu_id: 0 81 | iter_num_U: 1 82 | batch_size: 1 83 | save_L: true 84 | save_E: true 85 | log_process: false 86 | 87 | # default config for diffusion sampling, should be consistent with the training condition of the diffusion model 88 | ddim_sample: false 89 | model_output_type: pred_xstart 90 | skip_type: uniform 91 | beta_start: 0.0001 92 | beta_end: 0.02 -------------------------------------------------------------------------------- /guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 
58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /_pdjscc/net/decoder.py: -------------------------------------------------------------------------------- 1 | from .normalization.GDN import * 2 | from .attention import SNRAttention 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | 8 | class Decoder(nn.Module): 9 | def __init__(self, config,number_residual = 7): 10 | super(Decoder, self).__init__() 11 | self.C = config.C 12 | self.img_size = config.img_size 13 | self.post_pad = nn.ReflectionPad2d(3) 14 | activation = 'prelu' 15 | device = torch.device('cuda') 16 | activation_d = dict(relu='ReLU', elu='ELU', leaky_relu='LeakyReLU', prelu='PReLU') 17 | self.activation = getattr(nn, activation_d[activation]) # (leaky_relu, relu, elu, prelu) 18 | self.sigmoid = nn.Sigmoid() 19 | 20 | self.attention1 = SNRAttention(256) 21 | self.attention2 = SNRAttention(256) 22 | self.attention3 = SNRAttention(256) 23 | self.attention4 = SNRAttention(256) 24 | self.attention5 = SNRAttention(256) 25 | 26 | 27 | 28 | # (256,8,8) -> (256,8,8) 29 | self.upconv_block1 = nn.Sequential( 30 | nn.ConvTranspose2d(self.C, 256, kernel_size=(5, 5), stride=1, padding=(5 - 1) // 2), 31 | GDN(256, device, True), 32 | self.activation(), 33 | ) 34 | 35 | # (256,8,8) -> (256,8,8) 36 | self.upconv_block2 = nn.Sequential( 37 | nn.ConvTranspose2d(256, 256, kernel_size=(5, 5), stride=1, padding=(5 - 1) // 2), 38 | GDN(256, device, True), 39 | self.activation(), 40 | ) 41 | 42 | # (256,8,8) -> (256,8,8) 43 | self.upconv_block3 = nn.Sequential( 44 | nn.ConvTranspose2d(256, 256, kernel_size=(5, 5), stride=1, padding=(5 - 1) // 2), 45 | GDN(256, device, True), 46 | self.activation(), 47 | ) 48 | 49 | self.upconv_block4 = nn.Sequential( 50 | nn.ConvTranspose2d(256, 256, kernel_size=(5, 5), stride=2, padding=2, output_padding=1), 51 | GDN(256, device, True), 52 | self.activation(), 53 | ) 54 | 55 | self.upconv_block5 = nn.Sequential( 56 | nn.ConvTranspose2d(256, 256, kernel_size=(9, 9), stride=2, padding=4, output_padding=1), 57 | GDN(256, device, True), 58 | self.activation(), 59 | ) 60 | 61 | self.conv_block_out = nn.Sequential( 62 | self.post_pad, 63 | nn.Conv2d(256, 3, kernel_size=(7, 7), stride=1), 64 | nn.Sigmoid(), 65 | ) 66 | 67 | def forward(self, x, SNR): 68 | x = self.upconv_block1(x) 69 | x = self.attention1(x, SNR) 70 | x = self.upconv_block2(x) 71 | x = self.attention2(x, SNR) 72 | x = self.upconv_block3(x) 73 | x = self.attention3(x, SNR) 74 | x = self.upconv_block4(x) 75 | x = self.attention4(x, SNR) 76 | x = self.upconv_block5(x) 77 | x = self.attention5(x, SNR) 78 | out = self.conv_block_out(x) 79 | return out 80 | 81 | 82 | if __name__ == '__main__': 83 | import 
torch 84 | import torch.nn.functional as F 85 | from ADJSCC.config import config 86 | 87 | input_Tensor = torch.ones([2, 16, 64, 64]).cuda() 88 | SNR = torch.ones([2, 1]).cuda() 89 | model = Decoder(config).cuda() 90 | out = model(input_Tensor, SNR) 91 | print(out.shape) 92 | -------------------------------------------------------------------------------- /guided_diffusion/noise_schedule.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from utils.util import Config 5 | from utils import utils_model 6 | 7 | 8 | class NoiseSchedule: 9 | def __init__(self, config, logger, device): 10 | 11 | self.cond_config = Config(config.diffcom_series) 12 | 13 | # 1. linear schedule 14 | betas = np.linspace(config.beta_start, config.beta_end, self.cond_config.num_train_timesteps, dtype=np.float32) 15 | 16 | # 2. cosine schedule 17 | # t = np.linspace(1, 0, config.num_train_timesteps + 1)[1:] 18 | # betas = np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2 19 | 20 | self.betas = torch.from_numpy(betas).to(device) 21 | self.alphas = 1. - self.betas 22 | alphas_cumprod = np.cumprod(self.alphas.cpu(), axis=0) 23 | self.alphas_cumprod = alphas_cumprod.to(device) 24 | self.log_SNRs = torch.log10(self.alphas_cumprod / (1 - self.alphas_cumprod)) 25 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) 26 | self.alphas_cumprod_prev = torch.from_numpy(alphas_cumprod_prev).to(device) 27 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) 28 | self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod) 29 | self.reduced_alpha_cumprod = torch.div(self.sqrt_1m_alphas_cumprod, 30 | self.sqrt_alphas_cumprod) # equivalent noise sigma on image 31 | self.posterior_mean_coef1 = ( 32 | self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 33 | ) 34 | self.posterior_mean_coef2 = ( 35 | (1.0 - self.alphas_cumprod_prev) 36 | * torch.sqrt(self.alphas) 37 | / (1.0 - self.alphas_cumprod) 38 | ) 39 | self.posterior_variance = ( 40 | self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 41 | ) 42 | self.num_train_timesteps = self.cond_config.num_train_timesteps 43 | self.sigma = config.sigma 44 | self.iter_num = self.cond_config.iter_num 45 | 46 | if not config.CSNR_adapt_t_start: 47 | self.t_start = self.num_train_timesteps - 1 48 | else: 49 | snr = 10 ** (config.CSNR / 10) 50 | DSNR = np.log10((1 + snr) ** (1 / 48)) * 5 51 | self.t_start = int((1000 - np.searchsorted(self.log_SNRs.cpu().numpy()[::-1], DSNR)) * config.N) 52 | 53 | logger.info(f'Start from timestep {self.t_start} with SNR from {self.log_SNRs[-1]} to {self.log_SNRs[0]}') 54 | 55 | # create sequence of timestep for sampling 56 | skip = self.num_train_timesteps // self.iter_num 57 | if config.skip_type == 'uniform': 58 | seq = [i * skip for i in np.arange(0, self.t_start // skip)] 59 | elif config.skip_type == "quad": 60 | seq = np.sqrt(np.linspace(0, self.num_train_timesteps ** 2, self.t_start)) 61 | seq = [int(s) for s in list(seq)] 62 | seq[-1] = seq[-1] - 1 63 | self.seq = seq[::-1] 64 | 65 | # plot log-SNR schedule 66 | # plt.plot(self.log_SNRs.cpu().numpy()) 67 | # plt.xlim(0, 1000) 68 | # plt.ylim(-15, 15) 69 | # plt.xlabel('timestep') 70 | # plt.ylabel('log-SNR') 71 | # plt.title('log-SNR noise schedule') 72 | # plt.savefig(os.path.join(config.E_path, 'log_SNR_schedule.png')) 73 | # plt.close() 74 | -------------------------------------------------------------------------------- 
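A quick, self-contained way to sanity-check the CSNR-adaptive start step computed by `NoiseSchedule` above — an illustrative sketch, not part of the repository, assuming the linear beta schedule and the `CSNR` / `N` values from `configs/diffcom.yaml`:

```python
# Recompute the CSNR-adaptive t_start used by NoiseSchedule (illustrative sketch).
import numpy as np

num_train_timesteps = 1000
betas = np.linspace(0.0001, 0.02, num_train_timesteps)        # beta_start / beta_end from diffcom.yaml
alphas_cumprod = np.cumprod(1.0 - betas)
log_SNRs = np.log10(alphas_cumprod / (1.0 - alphas_cumprod))  # decreases as the timestep grows

CSNR, N = 10.0, 1.0                                            # CSNR / N from diffcom.yaml
snr = 10 ** (CSNR / 10)
DSNR = np.log10((1 + snr) ** (1 / 48)) * 5                     # same heuristic as in NoiseSchedule
# log_SNRs is monotonically decreasing, so search the reversed (ascending) copy.
t_start = int((1000 - np.searchsorted(log_SNRs[::-1], DSNR)) * N)
print(t_start)  # reverse sampling starts from this timestep instead of 999
```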
/website/src/components/header/Header.js: -------------------------------------------------------------------------------- 1 | import {Button} from "@mui/material"; 2 | import React, {Component} from "react"; 3 | import {VscGithub} from "react-icons/vsc" 4 | import {FaFilePdf} from "react-icons/fa" 5 | import {SiArxiv} from "react-icons/si" 6 | 7 | const AuthorBlock = (props) => ( 8 | 9 | {props.name} 10 | {props.number}, 11 | 12 | ) 13 | 14 | const LinkButton = (props) => ( 15 | 26 | ); 27 | 28 | export default class Header extends Component { 29 | render() { 30 | return ( 31 |
32 |
33 |
34 |
35 |

36 | DiffCom: Channel Received Signal is a Natural Condition to Guide Diffusion Posterior 37 | Sampling 38 |

39 |
40 | Kailin Tan, 45 | Ping Zhang, 50 |
51 |
52 | Beijing University of Posts and Telecommunications (BUPT), Beijing, China 53 |
54 | {/*Publication links*/} 55 |
56 | {/*} text="Paper"/>*/} 57 | } text="arXiv"/> 58 | } 59 | text="Code"/> 60 |
61 |
62 |
63 |
64 |
65 | ); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /_pdjscc/net/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | class Discriminator(nn.Module): 4 | def __init__(self, image_dims,C, spectral_norm=True): 5 | """ 6 | Convolutional patchGAN discriminator used in [1]. 7 | Accepts as input generator output G(z) or x ~ p*(x) where 8 | p*(x) is the true data distribution. 9 | Contextual information provided is encoder output y = E(x) 10 | ======== 11 | Arguments: 12 | image_dims: Dimensions of input image, (C_in,H,W) 13 | context_dims: Dimensions of contextual information, (C_in', H', W') 14 | C: Bottleneck depth, controls bits-per-pixel 15 | C = 220 used in [1], C = C_in' if encoder output used 16 | as context. 17 | 18 | [1] Mentzer et. al., "High-Fidelity Generative Image Compression", 19 | arXiv:2006.09965 (2020). 20 | """ 21 | super(Discriminator, self).__init__() 22 | 23 | self.image_dims = image_dims 24 | im_channels = self.image_dims[0] 25 | kernel_dim = 4 26 | context_C_out = 12 27 | filters = (64, 128, 256, 512) 28 | 29 | # Upscale encoder output - (C, 16, 16) -> (12, 256, 256) 30 | self.latent = nn.Conv2d(C, context_C_out, kernel_size=3,padding = 1 ) 31 | self.context_upsample = nn.Upsample(scale_factor=4, mode='nearest') 32 | 33 | # Images downscaled to 500 x 1000 + randomly cropped to 256 x 256 34 | # assert image_dims == (im_channels, 256, 256), 'Crop image to 256 x 256!' 35 | 36 | # Layer / normalization options 37 | # TODO: calculate padding properly 38 | cnn_kwargs = dict(stride=2,padding = 1) 39 | self.activation = nn.LeakyReLU(negative_slope=0.2) 40 | 41 | if spectral_norm is True: 42 | norm = nn.utils.spectral_norm 43 | else: 44 | norm = nn.utils.weight_norm 45 | 46 | # (C_in + C_in', 256,256) -> (64,128,128), with implicit padding 47 | # TODO: Check if removing spectral norm in first layer works 48 | self.d_conv_head = norm(nn.Conv2d(im_channels + context_C_out, filters[0], kernel_dim, **cnn_kwargs)) 49 | 50 | # (128,128) -> (64,64) 51 | self.d_conv_0 = norm(nn.Conv2d(filters[0], filters[1], kernel_dim, **cnn_kwargs)) 52 | 53 | # (64,64) -> (32,32) 54 | self.d_conv_1 = norm(nn.Conv2d(filters[1], filters[2], kernel_dim, **cnn_kwargs)) 55 | 56 | # (32,32) -> (16,16) 57 | self.d_conv_a = norm(nn.Conv2d(filters[2], filters[3], kernel_dim, stride = 1,padding = 1)) 58 | 59 | self.d_conv_b = nn.Conv2d(filters[3], 1, kernel_dim, stride=1,padding = 1) 60 | 61 | for m in self.modules(): 62 | if isinstance(m, nn.Conv2d): 63 | torch.nn.init.normal_(m.weight.data,std=0.02) 64 | torch.nn.init.constant(m.bias.data, 0.0) 65 | 66 | def forward(self, x, y): 67 | """ 68 | x: Concatenated real/gen images 69 | y: Quantized latents 70 | """ 71 | batch_size = x.size()[0] 72 | 73 | # Concatenate upscaled encoder output y as contextual information 74 | y = self.activation(self.latent(y)) 75 | y = self.context_upsample(y) 76 | 77 | x = torch.cat((x, y), dim=1) 78 | x = self.activation(self.d_conv_head(x)) 79 | x = self.activation(self.d_conv_0(x)) 80 | x = self.activation(self.d_conv_1(x)) 81 | x = self.activation(self.d_conv_a(x)) 82 | out_logits = self.d_conv_b(x).view(-1, 1) 83 | out = torch.sigmoid(out_logits) 84 | 85 | return out, out_logits -------------------------------------------------------------------------------- /_pdjscc/net/normalization/GDN.py: -------------------------------------------------------------------------------- 1 | # This 
code is from https://github.com/jorge-pessoa/pytorch-gdn 2 | 3 | import torch 4 | import torch.utils.data 5 | from torch import nn, optim 6 | from torch.nn import functional as F 7 | #from torchvision import datasets, transforms 8 | #from torchvision.utils import save_image 9 | from torch.autograd import Function 10 | 11 | 12 | class LowerBound(Function): 13 | @staticmethod 14 | def forward(ctx, inputs, bound): 15 | b = torch.ones(inputs.size())*bound 16 | b = b.to(inputs.device) 17 | ctx.save_for_backward(inputs, b) 18 | return torch.max(inputs, b) 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | inputs, b = ctx.saved_tensors 23 | 24 | pass_through_1 = inputs >= b 25 | pass_through_2 = grad_output < 0 26 | 27 | pass_through = pass_through_1 | pass_through_2 28 | return pass_through.type(grad_output.dtype) * grad_output, None 29 | 30 | 31 | class GDN(nn.Module): 32 | """Generalized divisive normalization layer. 33 | y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j])) 34 | """ 35 | 36 | def __init__(self, 37 | ch, 38 | device, 39 | inverse, 40 | beta_min=1e-6, 41 | gamma_init=.1, 42 | reparam_offset=2**-18): 43 | super(GDN, self).__init__() 44 | self.inverse = inverse 45 | self.beta_min = beta_min 46 | self.gamma_init = gamma_init 47 | self.reparam_offset = torch.FloatTensor([reparam_offset]) 48 | 49 | self.build(ch, torch.device(device)) 50 | 51 | def build(self, ch, device): 52 | self.pedestal = self.reparam_offset**2 53 | self.beta_bound = (self.beta_min + self.reparam_offset**2)**.5 54 | self.gamma_bound = self.reparam_offset 55 | 56 | # Create beta param 57 | beta = torch.sqrt(torch.ones(ch)+self.pedestal) 58 | self.beta = nn.Parameter(beta.to(device)) 59 | 60 | # Create gamma param 61 | eye = torch.eye(ch) 62 | g = self.gamma_init*eye 63 | g = g + self.pedestal 64 | gamma = torch.sqrt(g) 65 | 66 | self.gamma = nn.Parameter(gamma.to(device)) 67 | self.pedestal = self.pedestal.to(device) 68 | 69 | def forward(self, inputs): 70 | unfold = False 71 | if inputs.dim() == 5: 72 | unfold = True 73 | bs, ch, d, w, h = inputs.size() 74 | inputs = inputs.view(bs, ch, d*w, h) 75 | 76 | _, ch, _, _ = inputs.size() 77 | 78 | # Beta bound and reparam 79 | lowerbound_beta = LowerBound.apply 80 | #beta = LowerBound()(self.beta, self.beta_bound) 81 | #print('aaaa') 82 | beta = lowerbound_beta(self.beta, self.beta_bound) 83 | beta = beta**2 - self.pedestal 84 | 85 | # Gamma bound and reparam 86 | lowerbound_gamma = LowerBound.apply 87 | #gamma = LowerBound()(self.gamma, self.gamma_bound) 88 | gamma = lowerbound_gamma(self.gamma, self.gamma_bound) 89 | gamma = gamma**2 - self.pedestal 90 | gamma = gamma.view(ch, ch, 1, 1) 91 | 92 | # Norm pool calc 93 | norm_ = nn.functional.conv2d(inputs**2, gamma, beta) 94 | norm_ = torch.sqrt(norm_) 95 | 96 | # Apply norm 97 | if self.inverse: 98 | outputs = inputs * norm_ 99 | else: 100 | outputs = inputs / norm_ 101 | 102 | if unfold: 103 | outputs = outputs.view(bs, ch, d, w, h) 104 | return outputs 105 | -------------------------------------------------------------------------------- /_pdjscc/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | 4 | matplotlib.use('Agg') 5 | import matplotlib.pyplot as plt 6 | import math 7 | import torch 8 | import random 9 | import os 10 | from torch.autograd import Variable 11 | import logging 12 | 13 | 14 | def CalcuPSNR(img1, img2, max_val=255.): 15 | """ 16 | Based on `tf.image.psnr` 17 | 
https://www.tensorflow.org/api_docs/python/tf/image/psnr 18 | """ 19 | float_type = 'float64' 20 | # img1 = (torch.clamp(img1,-1,1).cpu().numpy() + 1) / 2 * 255 21 | # img2 = (torch.clamp(img2,-1,1).cpu().numpy() + 1) / 2 * 255 22 | img1 = torch.clamp(img1, 0, 1).cpu().numpy() * 255 23 | img2 = torch.clamp(img2, 0, 1).cpu().numpy() * 255 24 | img1 = img1.astype(float_type) 25 | img2 = img2.astype(float_type) 26 | mse = np.mean(np.square(img1 - img2), axis=(1, 2, 3)) 27 | psnr = 20 * np.log10(max_val) - 10 * np.log10(mse) 28 | return psnr 29 | 30 | 31 | def MSE2PSNR(MSE): 32 | return 10 * math.log10(255 ** 2 / (MSE)) 33 | 34 | 35 | def logger_configuration(config, save_log=False, test_mode=False): 36 | # configure the logger 37 | logger = logging.getLogger("Deep joint source channel coder") 38 | if save_log: 39 | makedirs(config.workdir) 40 | makedirs(config.samples) 41 | makedirs(config.models) 42 | formatter = logging.Formatter('[%(asctime)s - %(levelname)s] %(message)s') 43 | stdhandler = logging.StreamHandler() 44 | stdhandler.setLevel(logging.INFO) 45 | stdhandler.setFormatter(formatter) 46 | logger.addHandler(stdhandler) 47 | if save_log: 48 | filehandler = logging.FileHandler(config.log) 49 | filehandler.setLevel(logging.INFO) 50 | filehandler.setFormatter(formatter) 51 | logger.addHandler(filehandler) 52 | logger.setLevel(logging.INFO) 53 | config.logger = logger 54 | return config.logger 55 | 56 | 57 | def Var(x, device): 58 | return Variable(x.to(device)) 59 | 60 | 61 | def single_plot(epoch, global_step, real, gen, config, number, single_compress=False): 62 | real = real.transpose([1, 2, 0]) 63 | gen = gen.transpose([1, 2, 0]) 64 | images = list() 65 | 66 | for im, imtype in zip([real, gen], ['real', 'gen']): 67 | # im = ((im + 1.0)) / 2 # [-1,1] -> [0,1] 68 | im = np.squeeze(im) 69 | if len(im.shape) == 3: 70 | im = im[:, :, :3] 71 | if len(im.shape) == 4: 72 | im = im[0, :, :, :3] 73 | images.append(im) 74 | 75 | comparison = np.hstack(images) 76 | 77 | f = plt.figure() 78 | plt.imshow(comparison) 79 | plt.axis('off') 80 | if single_compress: 81 | f.savefig(config.name, format='png', dpi=720, bbox_inches='tight', pad_inches=0) 82 | else: 83 | f.savefig( 84 | "{}/JSCCModel_{}_epoch{}_step{}_{}.png".format(config.samples, config.trainset, epoch, global_step, number), 85 | format='png', dpi=720, bbox_inches='tight', pad_inches=0) 86 | plt.gcf().clear() 87 | plt.close(f) 88 | 89 | 90 | def makedirs(directory): 91 | if not os.path.exists(directory): 92 | os.makedirs(directory) 93 | 94 | 95 | def save_model(model, save_path): 96 | torch.save(model.state_dict(), save_path) 97 | 98 | 99 | def seed_torch(seed=1029): 100 | random.seed(seed) 101 | os.environ['PYTHONHASHSEED'] = str(seed) # disable hash randomization so that experiments are reproducible 102 | np.random.seed(seed) 103 | torch.manual_seed(seed) 104 | torch.cuda.manual_seed(seed) 105 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
106 | torch.backends.cudnn.benchmark = False 107 | torch.backends.cudnn.deterministic = True 108 | -------------------------------------------------------------------------------- /_pdjscc/test.py: -------------------------------------------------------------------------------- 1 | import pyiqa 2 | 3 | from loss_utils.utils import * 4 | from config import config 5 | from data.dataset import get_test_loader 6 | from net.network import ADJSCC 7 | 8 | 9 | def CalcuPSNR(img1, img2, max_val=255.): 10 | """ 11 | Based on `tf.image.psnr` 12 | https://www.tensorflow.org/api_docs/python/tf/image/psnr 13 | """ 14 | float_type = 'float64' 15 | # img1 = (torch.clamp(img1,-1,1).cpu().numpy() + 1) / 2 * 255 16 | # img2 = (torch.clamp(img2,-1,1).cpu().numpy() + 1) / 2 * 255 17 | img1 = torch.clamp(img1, 0, 1).cpu().numpy() * 255 18 | img2 = torch.clamp(img2, 0, 1).cpu().numpy() * 255 19 | img1 = img1.astype(float_type) 20 | img2 = img2.astype(float_type) 21 | mse = np.mean(np.square(img1 - img2), axis=(1, 2, 3)) 22 | psnr = 20 * np.log10(max_val) - 10 * np.log10(mse) 23 | return psnr 24 | 25 | 26 | config.device = torch.device("cuda:0") 27 | logger = config.logger = logger_configuration(config, save_log=False, test_mode=True) 28 | 29 | lpips_metric = pyiqa.create_metric('lpips').to(config.device) 30 | dists_metric = pyiqa.create_metric('dists').to(config.device) 31 | niqe_metric = pyiqa.create_metric('niqe').to(config.device) 32 | 33 | # initialize model 34 | config.C = 8 35 | config.multiple_snr = [0] 36 | config.test_data_dir = ["./demo"] 37 | save_dir = './recon/demo/{}dB/C={}/recon'.format(config.multiple_snr[0], config.C) 38 | 39 | net = ADJSCC(config) 40 | net = net.cuda() 41 | num_params = 0 42 | for param in net.parameters(): 43 | num_params += param.numel() 44 | print("TOTAL Params {}M".format(num_params / 10 ** 6)) 45 | 46 | pre_dict = torch.load("./pretrained_model/CBR=0.08.model", 47 | map_location=config.device) 48 | net.load_state_dict(pre_dict, strict=False) 49 | 50 | # initialize dataset 51 | test_loader = get_test_loader(config) 52 | from loss_utils.utils import * 53 | 54 | avg_time = 0 55 | net.eval() 56 | global_step = -1 57 | with torch.no_grad(): 58 | elapsed, ms_ssim_losses, mse_losses, lpips_losses, psnrs, dists_losses, niqe = [AverageMeter() for _ in range(7)] 59 | metrics = [elapsed, mse_losses, ms_ssim_losses, lpips_losses, psnrs, dists_losses, niqe] 60 | for batch_idx, input_image in enumerate(test_loader): 61 | global_step += 1 62 | start_time = time.time() 63 | input_image = input_image.cuda() 64 | ms_ssim_loss, mse_loss, lpips_loss, x_hat = net.forward(input_image) 65 | niqe_loss = niqe_metric(x_hat) 66 | niqe.update(niqe_loss.item()) 67 | elapsed.update(time.time() - start_time) 68 | ms_ssim_losses.update(ms_ssim_loss.item()) 69 | mse_losses.update(mse_loss.item()) 70 | lpips_losses.update(lpips_loss.item()) 71 | dists_loss = dists_metric(input_image, x_hat) 72 | dists_losses.update(dists_loss.item()) 73 | if mse_loss.item() > 0: 74 | psnr = 10 * (torch.log(255. * 255. 
/ mse_loss) / np.log(10)) 75 | psnrs.update(psnr.item()) 76 | else: 77 | psnrs.update(100) 78 | log = (' | '.join([ 79 | f'test_Time {elapsed.avg:.2f}', 80 | f'test_NIQE {niqe.val:.3f} ({niqe.avg:.3f})', 81 | f'test_PSNR {psnrs.val:.3f} ({psnrs.avg:.3f})', 82 | f'test_LPIPS {lpips_losses.val:.3f} ({lpips_losses.avg:.3f})', 83 | f'test_MS-SSIM {ms_ssim_losses.val:.3f} ({ms_ssim_losses.avg:.3f})', 84 | f'test_dists {dists_losses.val:.3f} ({dists_losses.avg:.3f})', 85 | ])) 86 | 87 | # makedirs 88 | if not os.path.exists(save_dir): 89 | os.makedirs(save_dir) 90 | fname = os.path.join(save_dir, 91 | "{}.png".format(global_step.__str__().zfill(5))) 92 | torchvision.utils.save_image(x_hat, fname, normalize=True) 93 | logger.info(log) 94 | for i in metrics: 95 | i.clear() 96 | -------------------------------------------------------------------------------- /channel/channel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class Channel(nn.Module): 7 | def __init__(self, channel_type, SNR, logger, device, rescale=True): 8 | super(Channel, self).__init__() 9 | self.chan_type = channel_type 10 | self.chan_param = SNR 11 | self.device = device 12 | self.logger = logger 13 | self.rescale = rescale 14 | if self.logger: 15 | self.logger.info('【Channel】: Built {} channel, SNR {} dB'.format( 16 | channel_type, SNR)) 17 | 18 | def gaussian_noise_layer(self, input_layer, std): 19 | device = input_layer.get_device() 20 | noise_real = torch.normal(mean=0.0, std=std, size=np.shape(input_layer), device=device) 21 | noise_imag = torch.normal(mean=0.0, std=std, size=np.shape(input_layer), device=device) 22 | noise = noise_real + 1j * noise_imag 23 | return input_layer + noise 24 | 25 | def rayleigh_noise_layer(self, input_layer, std): 26 | device = input_layer.get_device() 27 | # fast rayleigh channel 28 | noise_real = torch.normal(mean=0.0, std=std, size=np.shape(input_layer), device=device) 29 | noise_imag = torch.normal(mean=0.0, std=std, size=np.shape(input_layer), device=device) 30 | noise = noise_real + 1j * noise_imag 31 | h = torch.sqrt(torch.normal(mean=0.0, std=1, size=np.shape(input_layer), device=device) ** 2 32 | + torch.normal(mean=0.0, std=1, size=np.shape(input_layer), device=device) ** 2) / np.sqrt(2) 33 | return input_layer * h + noise, h 34 | 35 | def forward(self, input, avg_pwr=None, power=1): 36 | B = input.size()[0] 37 | if avg_pwr is None: 38 | avg_pwr = torch.mean(input ** 2) 39 | channel_tx = np.sqrt(power) * input / torch.sqrt(avg_pwr * 2) 40 | else: 41 | channel_tx = np.sqrt(power) * input / torch.sqrt(avg_pwr * 2) 42 | input_shape = channel_tx.shape 43 | channel_in = channel_tx.reshape(B, -1) 44 | channel_in = channel_in[:, ::2] + channel_in[:, 1::2] * 1j 45 | channel_usage = channel_in.numel() 46 | if self.chan_type == 'awgn': 47 | channel_output = self.channel_forward(channel_in) 48 | channel_rx = torch.zeros_like(channel_tx.reshape(B, -1)) 49 | channel_rx[:, ::2] = torch.real(channel_output) 50 | channel_rx[:, 1::2] = torch.imag(channel_output) 51 | channel_rx = channel_rx.reshape(input_shape) 52 | if self.rescale: 53 | return channel_rx * torch.sqrt(avg_pwr * 2), channel_usage 54 | else: 55 | return channel_rx, channel_usage 56 | elif self.chan_type == 'rayleigh': 57 | channel_output, channel_response = self.channel_forward(channel_in) 58 | # h = torch.zeros_like(channel_tx.reshape(B, -1)) 59 | # h[:, ::2] = channel_response 60 | # h[:, 1::2] = channel_response 61 | # h 
= h.reshape(input_shape) 62 | channel_rx = torch.zeros_like(channel_tx.reshape(B, -1)) 63 | channel_rx[:, ::2] = torch.real(channel_output) 64 | channel_rx[:, 1::2] = torch.imag(channel_output) 65 | channel_rx = channel_rx.reshape(input_shape) 66 | if self.rescale: 67 | return channel_rx * torch.sqrt(avg_pwr * 2), channel_response, channel_usage 68 | else: 69 | return channel_rx, channel_response, channel_usage 70 | 71 | def channel_forward(self, channel_in): 72 | if self.chan_type == 0 or self.chan_type == 'noiseless': 73 | return channel_in 74 | 75 | elif self.chan_type == 1 or self.chan_type == 'awgn': 76 | channel_tx = channel_in 77 | sigma = np.sqrt(1.0 / (2 * 10 ** (self.chan_param / 10))) 78 | chan_output = self.gaussian_noise_layer(channel_tx, 79 | std=sigma) 80 | return chan_output 81 | 82 | elif self.chan_type == 2 or self.chan_type == 'rayleigh': 83 | channel_tx = channel_in 84 | sigma = np.sqrt(1.0 / (2 * 10 ** (self.chan_param / 10))) 85 | chan_output, h = self.rayleigh_noise_layer(channel_tx, 86 | std=sigma) 87 | return chan_output, h 88 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Python >=3.7](https://img.shields.io/badge/Python->=3.7-yellow.svg) 2 | ![PyTorch >=1.7](https://img.shields.io/badge/PyTorch->=1.7-blue.svg) 3 | 4 | # DiffCom: Channel Received Signal is a Natural Condition to Guide Diffusion Posterior Sampling [[pdf]](https://arxiv.org/abs/2406.07390) 5 | 6 | Here is the implementation of the 7 | paper "[DiffCom: Channel Received Signal is a Natural Condition to Guide Diffusion Posterior Sampling](https://semcomm.github.io/DiffCom/)". 8 | 9 | Project website: [https://semcomm.github.io/DiffCom/](https://semcomm.github.io/DiffCom/) 10 | 11 | ## Abstract 12 | 13 | End-to-end visual communication systems typically optimize a trade-off between channel bandwidth costs and signal-level 14 | distortion metrics. However, under challenging physical conditions, this traditional coding and transmission paradigm 15 | often results in unrealistic reconstructions with perceptible blurring and aliasing artifacts, despite the inclusion of 16 | perceptual or adversarial losses during optimization. This issue primarily stems from the receiver’s limited knowledge about 17 | the underlying data manifold and the use of deterministic decoding mechanisms. 18 | To address these limitations, this paper 19 | introduces DiffCom, a novel end-to-end generative communication paradigm that utilizes off-the-shelf generative priors 20 | and probabilistic diffusion models for decoding, thereby improving perceptual quality without heavily relying on 21 | bandwidth costs and received signal quality. Unlike traditional systems that rely on deterministic decoders optimized 22 | solely for distortion metrics, our DiffCom leverages the raw channel-received signal as a fine-grained condition to guide 23 | stochastic posterior sampling. Our approach ensures that reconstructions remain on the manifold of real data with a 24 | novel confirming constraint, enhancing the robustness and reliability of the generated outcomes. 
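In code terms, the guidance idea is roughly the following — a minimal, hedged sketch of received-signal-guided posterior sampling; `diffusion.p_sample`, `jscc_encoder`, `channel`, and the single step size `zeta` are illustrative placeholders rather than this repo's exact API (the actual update rules are implemented in `conditioning_method/diffcom.py`):

```python
import torch

def guided_reverse_step(x_t, t, y, diffusion, jscc_encoder, channel, zeta=1.0):
    """One reverse-diffusion step guided by the channel-received signal y (illustrative sketch)."""
    x_t = x_t.detach().requires_grad_(True)
    # Unconditional denoising step: estimate x0 and the next, less noisy, sample (placeholder API).
    x0_hat, x_prev = diffusion.p_sample(x_t, t)
    # Re-transmit the current estimate through the same JSCC encoder and channel model.
    y_hat = channel(jscc_encoder(x0_hat))
    # Measurement mismatch between the re-generated and the actually received signal.
    residual = torch.linalg.norm(y - y_hat)
    grad = torch.autograd.grad(residual, x_t)[0]
    # Nudge the sample toward channel-consistency while the diffusion prior keeps it on the data manifold.
    return x_prev - zeta * grad
```

Conceptually, `zeta` (see `configs/diffcom.yaml`) balances fidelity to the received signal against the diffusion prior; the `hifi_diffcom` and `blind_diffcom` variants extend this update with additional terms.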
25 | 26 | ## Overview of the DiffCom system architecture 27 | 28 | 29 | 30 | ## RDP curves on [FFHQ](https://github.com/NVlabs/ffhq-dataset) testset 31 | 32 | 33 | 34 | ## Generalization to unseen wireless conditions 35 | 36 | 37 | 38 | ## Blind-DiffCom Achieves Pilot-Free Transmission 39 | 40 | 41 | 42 | ## Requirements 43 | 44 | Clone the repo and create a conda environment (we use PyTorch 1.9, CUDA 11.1). 45 | 46 | TODO: check the dependencies 47 | 48 | ## Model Download 49 | 50 | We provide 3 pre-trained ADJSCC models in [this link](https://drive.google.com/drive/folders/1N0EzzxCv1wh6JeFr0g8vkmB0Qj23ozZJ?usp=sharing), please download them and put them in the `_djscc/ckpt` folder. 51 | 52 | The pre-trained Diffusion models are available at the following links, please download them and put them in the `model_zoo` folder. 53 | 54 | | Model | Download link | 55 | |-------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------:| 56 | | 256x256_diffusion_uncond.pt(ILSVRC 2012 subset of ImageNet) | [download link](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt) | 57 | | ffhq_10m.pt | [download link](https://drive.google.com/drive/folders/1jElnRoFv7b31fG0v6pTSQkelbSX3xGZh?usp=sharing) | 58 | 59 | TODO: provide implementations of DiffCom based on \`SwinJSCC\` and \`NTSCC+\`. 60 | 61 | ## Inference Code 62 | 63 | ```bash 64 | python main_diffcom.py --opt ./configs/diffcom.yaml 65 | ``` 66 | 67 | Please check the `diffcom.yaml` file for more details. 68 | The codebase now supports `diffcom`, `hifi_diffcom`, and `blind_diffcom`. 69 | 70 | You can change the hyperparameters in the YAML file to run the corresponding method. 71 | 72 | ## Acknowledgement 73 | 74 | The implementation is based on [DPS](https://github.com/DPS2022/diffusion-posterior-sampling), [PSLD](https://github.com/LituRout/PSLD). 75 | We thank the authors for sharing their code. 76 | 77 | ## Citation 78 | 79 | If you find this code useful for your research, please cite our paper 80 | 81 | ``` 82 | @article{wang2024diffcom, 83 | title={DiffCom: Channel Received Signal is a Natural Condition to Guide Diffusion Posterior Sampling}, 84 | author={Wang, Sixian and Dai, Jincheng and Tan, Kailin and Qin, Xiaoqi and Niu, Kai and Zhang, Ping}, 85 | journal={arXiv preprint arXiv:2406.07390}, 86 | year={2024} 87 | } 88 | ``` -------------------------------------------------------------------------------- /website/src/components/section1/Section1.js: -------------------------------------------------------------------------------- 1 | import React, {Fragment} from "react"; 2 | 3 | const OverviewBlock = () => ( 4 |
5 |
6 | {"overview"} 10 | 11 |

12 | Overview of the DiffCom system architecture, and overall concept of 13 | the proposed method. 14 |

15 |
16 |
17 | ) 18 | 19 | const AbstactBlock = () => ( 20 |
21 |
22 |
23 |
24 |

Abstract

25 |
26 |

27 | End-to-end visual communication systems typically optimize a trade-off between channel 28 | bandwidth 29 | costs and signal-level distortion metrics. 30 | However, under challenging physical conditions, such a discriminative 31 | communication paradigm often results in unrealistic reconstructions with perceptible 32 | blurring and aliasing artifacts, despite the inclusion of perceptual or adversarial losses 33 | during training. This issue primarily stems from the receiver's limited knowledge about the 34 | underlying data 35 | manifold and the use of deterministic decoding mechanisms. 36 |

37 |

38 | We propose DiffCom, a novel end-to-end generative communication paradigm that utilizes off-the-shelf 40 | generative priors from diffusion models for decoding , thereby improving perceptual 41 | quality 42 | without heavily relying on bandwidth costs and received signal quality. 43 | Unlike traditional systems that rely on deterministic decoders optimized solely for 44 | distortion 45 | metrics, our DiffCom leverages raw channel-received signal as a fine-grained condition to guide stochastic posterior 46 | sampling. Our approach ensures that reconstructions remain on the manifold of real data 47 | with a 48 | novel confirming constraint, enhancing the robustness and reliability of the generated 49 | outcomes. 50 | Furthermore, DiffCom incorporates a blind posterior 51 | sampling 52 | technique to address 53 | scenarios with unknown forward transmission characteristics. 54 |

55 |

56 | Experimental results demonstrate that: 57 |

    58 |
  • DiffCom achieves SOTA transmission performance in 59 | terms of 60 | multiple perceptual quality metrics, such as LPIPS, DISTS, FID, and so on. 61 |
  • 62 |
  • DiffCom significantly enhances the robustness of 63 | current 64 | methods against various transmission-related degradations, including mismatched SNR, 65 | unseen 66 | fading, blind channel estimation, PAPR reduction, and inter-symbol interference. 67 |
  • 68 |
69 |

70 |
71 |
72 |
73 |
74 |
75 | ) 76 | 77 | const Section1 = () => { 78 | return ( 79 | 80 |
81 | 82 | 83 |
84 | ); 85 | } 86 | 87 | export default Section1; 88 | -------------------------------------------------------------------------------- /website/src/App.css: -------------------------------------------------------------------------------- 1 | .App { 2 | text-align: center; 3 | } 4 | 5 | .App-logo { 6 | height: 40vmin; 7 | pointer-events: none; 8 | } 9 | 10 | @media (prefers-reduced-motion: no-preference) { 11 | .App-logo { 12 | animation: App-logo-spin infinite 20s linear; 13 | } 14 | } 15 | 16 | .App-header { 17 | background-color: #282c34; 18 | min-height: 100vh; 19 | display: flex; 20 | flex-direction: column; 21 | align-items: center; 22 | justify-content: center; 23 | font-size: calc(10px + 2vmin); 24 | color: white; 25 | } 26 | 27 | .App-link { 28 | color: #61dafb; 29 | } 30 | 31 | @keyframes App-logo-spin { 32 | from { 33 | transform: rotate(0deg); 34 | } 35 | to { 36 | transform: rotate(360deg); 37 | } 38 | } 39 | 40 | html { 41 | background-color: #fff; 42 | font-size: 16px; 43 | -moz-osx-font-smoothing: grayscale; 44 | -webkit-font-smoothing: antialiased; 45 | min-width: 300px; 46 | overflow-x: hidden; 47 | overflow-y: scroll; 48 | text-rendering: optimizeLegibility; 49 | -webkit-text-size-adjust: 100%; 50 | -moz-text-size-adjust: 100%; 51 | -ms-text-size-adjust: 100%; 52 | text-size-adjust: 100%; 53 | } 54 | 55 | a { 56 | color: #3273dc; 57 | cursor: pointer; 58 | text-decoration: none; 59 | } 60 | 61 | p { 62 | display: block; 63 | margin-block-start: 1em; 64 | margin-block-end: 1em; 65 | margin-inline-start: 0px; 66 | margin-inline-end: 0px; 67 | } 68 | 69 | img { 70 | height: auto; 71 | max-width: 100%; 72 | } 73 | 74 | body { 75 | color: #4a4a4a; 76 | font-size: 1em; 77 | font-weight: 400; 78 | line-height: 1.5; 79 | } 80 | 81 | pre { 82 | -webkit-overflow-scrolling: touch; 83 | background-color: #f5f5f5; 84 | color: #4a4a4a; 85 | font-size: .875em; 86 | overflow-x: auto; 87 | padding: 1.25rem 1.5rem; 88 | white-space: pre; 89 | word-wrap: normal; 90 | } 91 | 92 | code, pre { 93 | -moz-osx-font-smoothing: auto; 94 | -webkit-font-smoothing: auto; 95 | font-family: monospace; 96 | } 97 | 98 | .footer { 99 | background-color: #fafafa; 100 | padding: 3rem 1.5rem 6rem; 101 | } 102 | 103 | .title { 104 | color: #363636; 105 | font-size: 2rem; 106 | font-weight: 600; 107 | line-height: 1.125; 108 | } 109 | 110 | .subtitle { 111 | color: #4a4a4a; 112 | font-size: 1.25rem; 113 | font-weight: 400; 114 | line-height: 1.25; 115 | } 116 | 117 | .subtitle, .title { 118 | word-break: break-word; 119 | } 120 | 121 | .title.is-1{ 122 | font-size: 3rem; 123 | } 124 | 125 | .title.is-2 { 126 | margin-top: -0.5rem; 127 | font-size: 2.0rem; 128 | } 129 | 130 | .title.is-3 { 131 | margin-top: -0.5rem; 132 | font-size: 1.7rem; 133 | } 134 | 135 | .is-size-5 { 136 | font-size: 1.25rem!important; 137 | } 138 | 139 | .publication-title { 140 | font-family: 'Google Sans', sans-serif; 141 | } 142 | 143 | .publication-authors { 144 | font-family: 'Google Sans', sans-serif; 145 | } 146 | 147 | .publication-authors a { 148 | color: hsl(204, 86%, 53%) !important; 149 | } 150 | 151 | .author-block { 152 | display: inline-block; 153 | } 154 | 155 | .author-block-small { 156 | display: inline-block; 157 | font-size: 1rem; 158 | } 159 | 160 | @media screen and (min-width: 1024px){ 161 | .container { 162 | max-width: 960px; 163 | }} 164 | 165 | .hero { 166 | align-items: stretch; 167 | display: flex; 168 | flex-direction: column; 169 | justify-content: space-between; 170 | } 171 | 172 | .hero-body { 173 | 
flex-grow: 1; 174 | flex-shrink: 0; 175 | padding: 1.5rem 1.5rem; 176 | } 177 | 178 | .hero.is-light { 179 | background-color: #f5f5f5; 180 | color: rgba(0,0,0,.7); 181 | } 182 | 183 | .container { 184 | flex-grow: 1; 185 | margin: 0 auto; 186 | position: relative; 187 | width: auto; 188 | } 189 | .section { 190 | padding: 1.5rem 1.5rem; 191 | } 192 | 193 | .column{ 194 | display: block; 195 | flex-basis: 0; 196 | flex-grow: 1; 197 | flex-shrink: 1; 198 | padding: 0.75rem; 199 | } 200 | 201 | @media screen and (min-width: 769px), print{ 202 | .column.is-four-fifths, .column.is-four-fifths-tablet { 203 | flex: none; 204 | width: 80%; 205 | }} 206 | 207 | .columns { 208 | margin-left: -.75rem; 209 | margin-right: -.75rem; 210 | margin-top: -.75rem; 211 | } 212 | 213 | @media screen and (min-width: 769px), print{ 214 | .columns:not(.is-desktop) { 215 | display: flex; 216 | }} 217 | 218 | .columns.is-centered { 219 | justify-content: center; 220 | } 221 | .columns:last-child { 222 | margin-bottom: -.75rem; 223 | } 224 | 225 | .has-text-centered { 226 | text-align: center!important; 227 | } 228 | 229 | .has-text-justified { 230 | text-align: justify!important; 231 | } 232 | 233 | 234 | .task-description { 235 | text-align: center; 236 | padding: 2rem 0; 237 | } 238 | 239 | .result-display { 240 | padding: 0.3rem 0; 241 | } 242 | 243 | /* button style */ 244 | .button { 245 | background-color: #fff; 246 | border-color: #dbdbdb; 247 | border-width: 1px; 248 | color: #363636; 249 | cursor: pointer; 250 | justify-content: center; 251 | padding-bottom: calc(.5em - 1px); 252 | padding-left: 1em; 253 | padding-right: 1em; 254 | padding-top: calc(.5em - 1px); 255 | text-align: center; 256 | white-space: nowrap; 257 | } 258 | 259 | 260 | .task-btns { 261 | display: flex; 262 | justify-content: center; 263 | align-items: center; 264 | padding: 1.0rem 0; 265 | } 266 | 267 | .generate-progress { 268 | justify-content: center; 269 | } -------------------------------------------------------------------------------- /guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 
28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 
113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /guided_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 
106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/dists_loss/DISTS_pt.py: -------------------------------------------------------------------------------- 1 | # This is a pytoch implementation of DISTS metric. 
2 | # Requirements: python >= 3.6, pytorch >= 1.0 3 | 4 | import numpy as np 5 | import os, sys 6 | import torch 7 | from torchvision import models, transforms 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class L2pooling(nn.Module): 13 | def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0): 14 | super(L2pooling, self).__init__() 15 | self.padding = (filter_size - 2) // 2 16 | self.stride = stride 17 | self.channels = channels 18 | a = np.hanning(filter_size)[1:-1] 19 | g = torch.Tensor(a[:, None] * a[None, :]) 20 | g = g / torch.sum(g) 21 | self.register_buffer('filter', g[None, None, :, :].repeat((self.channels, 1, 1, 1))) 22 | 23 | def forward(self, input): 24 | input = input ** 2 25 | out = F.conv2d(input, self.filter, stride=self.stride, padding=self.padding, groups=input.shape[1]) 26 | return (out + 1e-12).sqrt() 27 | 28 | 29 | class DISTS(torch.nn.Module): 30 | def __init__(self, load_weights=True): 31 | super(DISTS, self).__init__() 32 | vgg_pretrained_features = models.vgg16(pretrained=True).features 33 | self.stage1 = torch.nn.Sequential() 34 | self.stage2 = torch.nn.Sequential() 35 | self.stage3 = torch.nn.Sequential() 36 | self.stage4 = torch.nn.Sequential() 37 | self.stage5 = torch.nn.Sequential() 38 | for x in range(0, 4): 39 | self.stage1.add_module(str(x), vgg_pretrained_features[x]) 40 | self.stage2.add_module(str(4), L2pooling(channels=64)) 41 | for x in range(5, 9): 42 | self.stage2.add_module(str(x), vgg_pretrained_features[x]) 43 | self.stage3.add_module(str(9), L2pooling(channels=128)) 44 | for x in range(10, 16): 45 | self.stage3.add_module(str(x), vgg_pretrained_features[x]) 46 | self.stage4.add_module(str(16), L2pooling(channels=256)) 47 | for x in range(17, 23): 48 | self.stage4.add_module(str(x), vgg_pretrained_features[x]) 49 | self.stage5.add_module(str(23), L2pooling(channels=512)) 50 | for x in range(24, 30): 51 | self.stage5.add_module(str(x), vgg_pretrained_features[x]) 52 | 53 | for param in self.parameters(): 54 | param.requires_grad = False 55 | 56 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)) 57 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)) 58 | 59 | self.chns = [3, 64, 128, 256, 512, 512] 60 | self.register_parameter("alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))) 61 | self.register_parameter("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))) 62 | self.alpha.data.normal_(0.1, 0.01) 63 | self.beta.data.normal_(0.1, 0.01) 64 | if load_weights: 65 | weights = torch.load(os.path.join(sys.prefix, '/media/D/wangjun/DEEPJSCC/DISTS/weights.pt')) 66 | self.alpha.data = weights['alpha'] 67 | self.beta.data = weights['beta'] 68 | 69 | def forward_once(self, x): 70 | h = (x - self.mean) / self.std 71 | h = self.stage1(h) 72 | h_relu1_2 = h 73 | h = self.stage2(h) 74 | h_relu2_2 = h 75 | h = self.stage3(h) 76 | h_relu3_3 = h 77 | h = self.stage4(h) 78 | h_relu4_3 = h 79 | h = self.stage5(h) 80 | h_relu5_3 = h 81 | return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3] 82 | 83 | def forward(self, x, y, require_grad=False, batch_average=False): 84 | if require_grad: 85 | feats0 = self.forward_once(x) 86 | feats1 = self.forward_once(y) 87 | else: 88 | with torch.no_grad(): 89 | feats0 = self.forward_once(x) 90 | feats1 = self.forward_once(y) 91 | dist1 = 0 92 | dist2 = 0 93 | c1 = 1e-6 94 | c2 = 1e-6 95 | w_sum = self.alpha.sum() + self.beta.sum() 96 | alpha = torch.split(self.alpha / w_sum, self.chns, dim=1) 
97 | beta = torch.split(self.beta / w_sum, self.chns, dim=1) 98 | for k in range(len(self.chns)): 99 | x_mean = feats0[k].mean([2, 3], keepdim=True) 100 | y_mean = feats1[k].mean([2, 3], keepdim=True) 101 | S1 = (2 * x_mean * y_mean + c1) / (x_mean ** 2 + y_mean ** 2 + c1) 102 | dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True) 103 | 104 | x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True) 105 | y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True) 106 | xy_cov = (feats0[k] * feats1[k]).mean([2, 3], keepdim=True) - x_mean * y_mean 107 | S2 = (2 * xy_cov + c2) / (x_var + y_var + c2) 108 | dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True) 109 | 110 | score = 1 - (dist1 + dist2).squeeze() 111 | if batch_average: 112 | return score.mean() 113 | else: 114 | return score 115 | 116 | 117 | def prepare_image(image, resize=True): 118 | if resize and min(image.size) > 256: 119 | image = transforms.functional.resize(image, 256) 120 | image = transforms.ToTensor()(image) 121 | return image.unsqueeze(0) 122 | 123 | 124 | if __name__ == '__main__': 125 | from PIL import Image 126 | import argparse 127 | 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--ref', type=str, default='/media/D/wangjun/kodak1/kodim01.png') 130 | parser.add_argument('--dist', type=str, default='/media/D/wangjun/kodak1/test.png') 131 | args = parser.parse_args() 132 | 133 | ref = prepare_image(Image.open(args.ref).convert("RGB")) 134 | dist = prepare_image(Image.open(args.dist).convert("RGB")) 135 | assert ref.shape == dist.shape 136 | 137 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 138 | model = DISTS().to(device) 139 | ref = ref.to(device) 140 | dist = dist.to(device) 141 | score = model(ref, dist) 142 | print(score.item()) -------------------------------------------------------------------------------- /_pdjscc/net/network.py: -------------------------------------------------------------------------------- 1 | from .encoder import * 2 | from .decoder import * 3 | from .discriminator import * 4 | from CommonModules.loss.distortion import Distortion 5 | from .channel import Channel 6 | from random import choice 7 | from CommonModules.loss.distortion import MS_SSIM 8 | from CommonModules.perceptual_similarity.perceptual_loss import PerceptualLoss 9 | from CommonModules.loss import gan_loss 10 | from collections import namedtuple 11 | from functools import partial 12 | 13 | 14 | def pad_factor(input_image, spatial_dims, factor): 15 | """Pad `input_image` (N,C,H,W) such that H and W are divisible by `factor`.""" 16 | 17 | if isinstance(factor, int) is True: 18 | factor_H = factor 19 | factor_W = factor_H 20 | else: 21 | factor_H, factor_W = factor 22 | 23 | H, W = spatial_dims[0], spatial_dims[1] 24 | pad_H = (factor_H - (H % factor_H)) % factor_H 25 | pad_W = (factor_W - (W % factor_W)) % factor_W 26 | return F.pad(input_image, pad=(0, pad_W, 0, pad_H), mode='reflect') 27 | 28 | 29 | class ADJSCC(nn.Module): 30 | def __init__(self, config): 31 | super(ADJSCC, self).__init__() 32 | if config.logger: 33 | config.logger.info("【Network】: Built Distributed JSCC model, C={}, k/n={}".format(config.C, config.kdivn)) 34 | 35 | self.config = config 36 | self.Encoder = Encoder(config) 37 | self.Decoder = Decoder(config) 38 | if config.use_discriminator: 39 | self.Discriminator = Discriminator(image_dims = config.image_dims, C=config.C) 40 | self.use_discriminator = config.use_discriminator 41 | self.channel = Channel(config) 42 | self.pass_channel = config.pass_channel 
43 | self.MS_SSIM = MS_SSIM(data_range=1., levels=4, channel=3).cuda() 44 | self._lpips = PerceptualLoss(model='net-lin', net='alex', use_gpu=torch.cuda.is_available(), gpu_ids=[torch.device("cuda:0")]) 45 | self.gan_loss = partial(gan_loss.gan_loss, config.gan_loss_type) 46 | self.distortion_loss = Distortion(config) 47 | def feature_pass_channel(self, feature): 48 | noisy_feature = self.channel(feature) 49 | return noisy_feature 50 | def discriminator_forward(self, reconstruction, input_image, latents_quantized, train_generator): 51 | """ Train on gen/real batches simultaneously. """ 52 | x_gen = reconstruction 53 | x_real = input_image 54 | Disc_out = namedtuple("disc_out", 55 | ["D_real", "D_gen", "D_real_logits", "D_gen_logits"]) 56 | 57 | # Alternate between training discriminator and compression models 58 | if train_generator is False: 59 | x_gen = x_gen.detach() 60 | 61 | D_in = torch.cat([x_real, x_gen], dim=0) 62 | 63 | latents = latents_quantized.detach() 64 | latents = torch.repeat_interleave(latents, 2, dim=0) 65 | 66 | D_out, D_out_logits = self.Discriminator(D_in, latents) 67 | D_out = torch.squeeze(D_out) 68 | D_out_logits = torch.squeeze(D_out_logits) 69 | 70 | D_real, D_gen = torch.chunk(D_out, 2, dim=0) 71 | D_real_logits, D_gen_logits = torch.chunk(D_out_logits, 2, dim=0) 72 | 73 | return Disc_out(D_real, D_gen, D_real_logits, D_gen_logits) 74 | 75 | def GAN_loss(self,reconstruction, input_image,latents_quantized,train_generator=False): 76 | """ 77 | train_generator: Flag to send gradients to generator 78 | """ 79 | disc_out = self.discriminator_forward(reconstruction, input_image,latents_quantized,train_generator) 80 | D_loss = self.gan_loss(disc_out, mode='discriminator_loss') 81 | G_loss = self.gan_loss(disc_out, mode='generator_loss') 82 | D_gen = torch.mean(disc_out.D_gen).item() 83 | D_real = torch.mean(disc_out.D_real).item() 84 | 85 | 86 | return D_loss, G_loss,D_gen,D_real 87 | 88 | 89 | 90 | def forward(self, input_sequence,train_generator = True,given_SNR=None): 91 | B, C, H, W = input_sequence.shape 92 | if self.training == False: 93 | n_encoder_downsamples = self.Encoder.n_downsampling_layers 94 | factor = 2 ** n_encoder_downsamples 95 | x = pad_factor(input_sequence, input_sequence.size()[2:], factor) 96 | else: 97 | x = input_sequence 98 | 99 | if given_SNR is not None: 100 | self.channel.chan_param = given_SNR 101 | else: 102 | random_SNR = choice(self.config.multiple_snr) 103 | self.channel.chan_param = random_SNR 104 | 105 | SNR = torch.ones([B, 1]).to(x.device) * self.channel.chan_param 106 | feature = self.Encoder(x, SNR) 107 | if self.pass_channel: 108 | noisy_feature = self.feature_pass_channel(feature) 109 | else: 110 | noisy_feature = feature 111 | x_hat = self.Decoder(noisy_feature, SNR) 112 | if self.training == False: 113 | x_hat = x_hat[:, :, :H, :W] 114 | mse_loss = self.distortion_loss(input_sequence, x_hat) 115 | lpips_loss = self._lpips(input_sequence, x_hat, normalize=True).mean() 116 | ms_ssim_loss = self.MS_SSIM(input_sequence, x_hat).mean() 117 | return ms_ssim_loss,mse_loss, lpips_loss, x_hat 118 | else: 119 | mse_loss = self.distortion_loss(input_sequence, x_hat) 120 | lpips_loss = self._lpips(input_sequence,x_hat,normalize=True).mean() 121 | ms_ssim_loss = self.MS_SSIM(input_sequence, x_hat).mean() 122 | if self.use_discriminator: 123 | D_loss, G_loss,D_gen,D_real = self.GAN_loss(x_hat,x,feature,train_generator) 124 | return ms_ssim_loss,mse_loss,lpips_loss,x_hat,D_loss, G_loss,D_gen,D_real 125 | else: 126 | return ms_ssim_loss, 
mse_loss, lpips_loss, x_hat 127 | 128 | 129 | if __name__ == '__main__': 130 | import torch 131 | import torch.nn.functional as F 132 | from ADJSCC.config import config 133 | 134 | input_Tensor = torch.ones([2, 3, 256, 256]).cuda() 135 | model = ADJSCC(config).cuda() 136 | recon_image, distortion_loss = model(input_Tensor) 137 | print(recon_image.shape) 138 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/dists_loss/DISTS/DISTS_pt.py: -------------------------------------------------------------------------------- 1 | # This is a pytoch implementation of DISTS metric. 2 | # Requirements: python >= 3.6, pytorch >= 1.0 3 | 4 | import numpy as np 5 | import os, sys 6 | import torch 7 | from torchvision import models, transforms 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import os 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 12 | 13 | class L2pooling(nn.Module): 14 | def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0): 15 | super(L2pooling, self).__init__() 16 | self.padding = (filter_size - 2) // 2 17 | self.stride = stride 18 | self.channels = channels 19 | a = np.hanning(filter_size)[1:-1] 20 | g = torch.Tensor(a[:, None] * a[None, :]) 21 | g = g / torch.sum(g) 22 | self.register_buffer('filter', g[None, None, :, :].repeat((self.channels, 1, 1, 1))) 23 | 24 | def forward(self, input): 25 | input = input ** 2 26 | out = F.conv2d(input, self.filter, stride=self.stride, padding=self.padding, groups=input.shape[1]) 27 | return (out + 1e-12).sqrt() 28 | 29 | 30 | class DISTS(torch.nn.Module): 31 | def __init__(self, load_weights=True): 32 | super(DISTS, self).__init__() 33 | vgg_pretrained_features = models.vgg16(pretrained=True).features 34 | self.stage1 = torch.nn.Sequential() 35 | self.stage2 = torch.nn.Sequential() 36 | self.stage3 = torch.nn.Sequential() 37 | self.stage4 = torch.nn.Sequential() 38 | self.stage5 = torch.nn.Sequential() 39 | for x in range(0, 4): 40 | self.stage1.add_module(str(x), vgg_pretrained_features[x]) 41 | self.stage2.add_module(str(4), L2pooling(channels=64)) 42 | for x in range(5, 9): 43 | self.stage2.add_module(str(x), vgg_pretrained_features[x]) 44 | self.stage3.add_module(str(9), L2pooling(channels=128)) 45 | for x in range(10, 16): 46 | self.stage3.add_module(str(x), vgg_pretrained_features[x]) 47 | self.stage4.add_module(str(16), L2pooling(channels=256)) 48 | for x in range(17, 23): 49 | self.stage4.add_module(str(x), vgg_pretrained_features[x]) 50 | self.stage5.add_module(str(23), L2pooling(channels=512)) 51 | for x in range(24, 30): 52 | self.stage5.add_module(str(x), vgg_pretrained_features[x]) 53 | 54 | for param in self.parameters(): 55 | param.requires_grad = False 56 | 57 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)) 58 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)) 59 | 60 | self.chns = [3, 64, 128, 256, 512, 512] 61 | self.register_parameter("alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))) 62 | self.register_parameter("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1))) 63 | self.alpha.data.normal_(0.1, 0.01) 64 | self.beta.data.normal_(0.1, 0.01) 65 | if load_weights: 66 | weights = torch.load(os.path.join(sys.prefix,'/media/D/wangjun/DEEPJSCC/DISTS/weights.pt')) 67 | self.alpha.data = weights['alpha'] 68 | self.beta.data = weights['beta'] 69 | 70 | def forward_once(self, x): 71 | h = (x - self.mean) / self.std 72 | h = self.stage1(h) 73 | 
h_relu1_2 = h 74 | h = self.stage2(h) 75 | h_relu2_2 = h 76 | h = self.stage3(h) 77 | h_relu3_3 = h 78 | h = self.stage4(h) 79 | h_relu4_3 = h 80 | h = self.stage5(h) 81 | h_relu5_3 = h 82 | return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3] 83 | 84 | def forward(self, x, y, require_grad=False, batch_average=False): 85 | if require_grad: 86 | feats0 = self.forward_once(x) 87 | feats1 = self.forward_once(y) 88 | else: 89 | with torch.no_grad(): 90 | feats0 = self.forward_once(x) 91 | feats1 = self.forward_once(y) 92 | dist1 = 0 93 | dist2 = 0 94 | c1 = 1e-6 95 | c2 = 1e-6 96 | w_sum = self.alpha.sum() + self.beta.sum() 97 | alpha = torch.split(self.alpha / w_sum, self.chns, dim=1) 98 | beta = torch.split(self.beta / w_sum, self.chns, dim=1) 99 | for k in range(len(self.chns)): 100 | x_mean = feats0[k].mean([2, 3], keepdim=True) 101 | y_mean = feats1[k].mean([2, 3], keepdim=True) 102 | S1 = (2 * x_mean * y_mean + c1) / (x_mean ** 2 + y_mean ** 2 + c1) 103 | dist1 = dist1 + (alpha[k] * S1).sum(1, keepdim=True) 104 | 105 | x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True) 106 | y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True) 107 | xy_cov = (feats0[k] * feats1[k]).mean([2, 3], keepdim=True) - x_mean * y_mean 108 | S2 = (2 * xy_cov + c2) / (x_var + y_var + c2) 109 | dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True) 110 | 111 | score = 1 - (dist1 + dist2).squeeze() 112 | if batch_average: 113 | return score.mean() 114 | else: 115 | return score 116 | 117 | 118 | def prepare_image(image, resize=True): 119 | if resize and min(image.size) > 256: 120 | image = transforms.functional.resize(image, 256) 121 | image = transforms.ToTensor()(image) 122 | return image.unsqueeze(0) 123 | 124 | 125 | if __name__ == '__main__': 126 | from PIL import Image 127 | import argparse 128 | 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument('--ref', type=str, default='/media/D/wangjun/kodak1/kodim01.png') 131 | parser.add_argument('--dist', type=str, default='/media/D/wangjun/kodak1/test.png') 132 | args = parser.parse_args() 133 | 134 | ref = prepare_image(Image.open(args.ref).convert("RGB")) 135 | dist = prepare_image(Image.open(args.dist).convert("RGB")) 136 | assert ref.shape == dist.shape 137 | 138 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 139 | model = DISTS().to(device) 140 | ref = ref.to(device) 141 | dist = dist.to(device) 142 | score = model(ref, dist) 143 | print(score.item()) 144 | # score: 0.3347 -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | # from skimage.measure import compare_ssim 7 | import torch 8 | 9 | from . 
import dist_model 10 | 11 | 12 | class PerceptualLoss(torch.nn.Module): 13 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], 14 | version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | super(PerceptualLoss, self).__init__() 17 | print('Setting up Perceptual loss...') 18 | self.use_gpu = use_gpu 19 | self.spatial = spatial 20 | self.gpu_ids = gpu_ids 21 | self.model = dist_model.DistModel() 22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, 23 | gpu_ids=gpu_ids, version=version) 24 | print('...[%s] initialized' % self.model.name()) 25 | print('...Done') 26 | 27 | def forward(self, pred, target, normalize=False): 28 | """ 29 | Pred and target are Variables. 30 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 31 | If normalize is False, assumes the images are already between [-1,+1] 32 | 33 | Inputs pred and target are Nx3xHxW 34 | Output pytorch Variable N long 35 | """ 36 | 37 | if normalize: 38 | target = 2 * target - 1 39 | pred = 2 * pred - 1 40 | 41 | return self.model.forward(target, pred) 42 | 43 | 44 | def normalize_tensor(in_feat, eps=1e-10): 45 | l2_norm = torch.sum(in_feat ** 2, dim=1, keepdim=True) 46 | norm_factor = torch.sqrt(l2_norm + eps) 47 | # return in_feat/(norm_factor+eps) 48 | return in_feat / (norm_factor) 49 | 50 | 51 | def l2(p0, p1, range=255.): 52 | return .5 * np.mean((p0 / range - p1 / range) ** 2) 53 | 54 | 55 | def psnr(p0, p1, peak=255.): 56 | return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2)) 57 | 58 | 59 | # def dssim(p0, p1, range=255.): 60 | # return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 61 | 62 | def rgb2lab(in_img, mean_cent=False): 63 | from skimage import color 64 | img_lab = color.rgb2lab(in_img) 65 | if (mean_cent): 66 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 67 | return img_lab 68 | 69 | 70 | def tensor2np(tensor_obj): 71 | # change dimension of a tensor object into a numpy array 72 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 73 | 74 | 75 | def np2tensor(np_obj): 76 | # change dimenion of np array into tensor array 77 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 78 | 79 | 80 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 81 | # image tensor to lab tensor 82 | from skimage import color 83 | 84 | img = tensor2im(image_tensor) 85 | img_lab = color.rgb2lab(img) 86 | if (mc_only): 87 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 88 | if (to_norm and not mc_only): 89 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 90 | img_lab = img_lab / 100. 91 | 92 | return np2tensor(img_lab) 93 | 94 | 95 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 96 | from skimage import color 97 | import warnings 98 | warnings.filterwarnings("ignore") 99 | 100 | lab = tensor2np(lab_tensor) * 100. 101 | lab[:, :, 0] = lab[:, :, 0] + 50 102 | 103 | rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1) 104 | if (return_inbnd): 105 | # convert back to lab, see if we match 106 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 107 | mask = 1. * np.isclose(lab_back, lab, atol=2.) 
108 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 109 | return (im2tensor(rgb_back), mask) 110 | else: 111 | return im2tensor(rgb_back) 112 | 113 | 114 | def rgb2lab(input): 115 | from skimage import color 116 | return color.rgb2lab(input / 255.) 117 | 118 | 119 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): 120 | image_numpy = image_tensor[0].cpu().float().numpy() 121 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 122 | return image_numpy.astype(imtype) 123 | 124 | 125 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): 126 | return torch.Tensor((image / factor - cent) 127 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 128 | 129 | 130 | def tensor2vec(vector_tensor): 131 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 132 | 133 | 134 | def voc_ap(rec, prec, use_07_metric=False): 135 | """ ap = voc_ap(rec, prec, [use_07_metric]) 136 | Compute VOC AP given precision and recall. 137 | If use_07_metric is true, uses the 138 | VOC 07 11 point method (default:False). 139 | """ 140 | if use_07_metric: 141 | # 11 point metric 142 | ap = 0. 143 | for t in np.arange(0., 1.1, 0.1): 144 | if np.sum(rec >= t) == 0: 145 | p = 0 146 | else: 147 | p = np.max(prec[rec >= t]) 148 | ap = ap + p / 11. 149 | else: 150 | # correct AP calculation 151 | # first append sentinel values at the end 152 | mrec = np.concatenate(([0.], rec, [1.])) 153 | mpre = np.concatenate(([0.], prec, [0.])) 154 | 155 | # compute the precision envelope 156 | for i in range(mpre.size - 1, 0, -1): 157 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 158 | 159 | # to calculate area under PR curve, look for points 160 | # where X axis (recall) changes value 161 | i = np.where(mrec[1:] != mrec[:-1])[0] 162 | 163 | # and sum (\Delta recall) * prec 164 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 165 | return ap 166 | 167 | 168 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): 169 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 170 | image_numpy = image_tensor[0].cpu().float().numpy() 171 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 172 | return image_numpy.astype(imtype) 173 | 174 | 175 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. 
/ 2.): 176 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 177 | return torch.Tensor((image / factor - cent) 178 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 179 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | class squeezenet(torch.nn.Module): 6 | def __init__(self, requires_grad=False, pretrained=True): 7 | super(squeezenet, self).__init__() 8 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 9 | self.slice1 = torch.nn.Sequential() 10 | self.slice2 = torch.nn.Sequential() 11 | self.slice3 = torch.nn.Sequential() 12 | self.slice4 = torch.nn.Sequential() 13 | self.slice5 = torch.nn.Sequential() 14 | self.slice6 = torch.nn.Sequential() 15 | self.slice7 = torch.nn.Sequential() 16 | self.N_slices = 7 17 | for x in range(2): 18 | self.slice1.add_module(str(x), pretrained_features[x]) 19 | for x in range(2,5): 20 | self.slice2.add_module(str(x), pretrained_features[x]) 21 | for x in range(5, 8): 22 | self.slice3.add_module(str(x), pretrained_features[x]) 23 | for x in range(8, 10): 24 | self.slice4.add_module(str(x), pretrained_features[x]) 25 | for x in range(10, 11): 26 | self.slice5.add_module(str(x), pretrained_features[x]) 27 | for x in range(11, 12): 28 | self.slice6.add_module(str(x), pretrained_features[x]) 29 | for x in range(12, 13): 30 | self.slice7.add_module(str(x), pretrained_features[x]) 31 | if not requires_grad: 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | def forward(self, X): 36 | h = self.slice1(X) 37 | h_relu1 = h 38 | h = self.slice2(h) 39 | h_relu2 = h 40 | h = self.slice3(h) 41 | h_relu3 = h 42 | h = self.slice4(h) 43 | h_relu4 = h 44 | h = self.slice5(h) 45 | h_relu5 = h 46 | h = self.slice6(h) 47 | h_relu6 = h 48 | h = self.slice7(h) 49 | h_relu7 = h 50 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 51 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 52 | 53 | return out 54 | 55 | 56 | class alexnet(torch.nn.Module): 57 | def __init__(self, requires_grad=False, pretrained=True): 58 | super(alexnet, self).__init__() 59 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 60 | self.slice1 = torch.nn.Sequential() 61 | self.slice2 = torch.nn.Sequential() 62 | self.slice3 = torch.nn.Sequential() 63 | self.slice4 = torch.nn.Sequential() 64 | self.slice5 = torch.nn.Sequential() 65 | self.N_slices = 5 66 | for x in range(2): 67 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 68 | for x in range(2, 5): 69 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 70 | for x in range(5, 8): 71 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 72 | for x in range(8, 10): 73 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 74 | for x in range(10, 12): 75 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 76 | if not requires_grad: 77 | for param in self.parameters(): 78 | param.requires_grad = False 79 | 80 | def forward(self, X): 81 | h = self.slice1(X) 82 | h_relu1 = h 83 | h = self.slice2(h) 84 | h_relu2 = h 85 | h = self.slice3(h) 86 | h_relu3 = h 87 | h = self.slice4(h) 88 | h_relu4 = h 89 | h = self.slice5(h) 90 | h_relu5 = h 91 | alexnet_outputs 
= namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 92 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 93 | 94 | return out 95 | 96 | class vgg16(torch.nn.Module): 97 | def __init__(self, requires_grad=False, pretrained=True): 98 | super(vgg16, self).__init__() 99 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 100 | self.slice1 = torch.nn.Sequential() 101 | self.slice2 = torch.nn.Sequential() 102 | self.slice3 = torch.nn.Sequential() 103 | self.slice4 = torch.nn.Sequential() 104 | self.slice5 = torch.nn.Sequential() 105 | self.N_slices = 5 106 | for x in range(4): 107 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 108 | for x in range(4, 9): 109 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(9, 16): 111 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(16, 23): 113 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(23, 30): 115 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 116 | if not requires_grad: 117 | for param in self.parameters(): 118 | param.requires_grad = False 119 | 120 | def forward(self, X): 121 | h = self.slice1(X) 122 | h_relu1_2 = h 123 | h = self.slice2(h) 124 | h_relu2_2 = h 125 | h = self.slice3(h) 126 | h_relu3_3 = h 127 | h = self.slice4(h) 128 | h_relu4_3 = h 129 | h = self.slice5(h) 130 | h_relu5_3 = h 131 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 132 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 133 | 134 | return out 135 | 136 | 137 | 138 | class resnet(torch.nn.Module): 139 | def __init__(self, requires_grad=False, pretrained=True, num=18): 140 | super(resnet, self).__init__() 141 | if(num==18): 142 | self.net = tv.resnet18(pretrained=pretrained) 143 | elif(num==34): 144 | self.net = tv.resnet34(pretrained=pretrained) 145 | elif(num==50): 146 | self.net = tv.resnet50(pretrained=pretrained) 147 | elif(num==101): 148 | self.net = tv.resnet101(pretrained=pretrained) 149 | elif(num==152): 150 | self.net = tv.resnet152(pretrained=pretrained) 151 | self.N_slices = 5 152 | 153 | self.conv1 = self.net.conv1 154 | self.bn1 = self.net.bn1 155 | self.relu = self.net.relu 156 | self.maxpool = self.net.maxpool 157 | self.layer1 = self.net.layer1 158 | self.layer2 = self.net.layer2 159 | self.layer3 = self.net.layer3 160 | self.layer4 = self.net.layer4 161 | 162 | def forward(self, X): 163 | h = self.conv1(X) 164 | h = self.bn1(h) 165 | h = self.relu(h) 166 | h_relu1 = h 167 | h = self.maxpool(h) 168 | h = self.layer1(h) 169 | h_conv2 = h 170 | h = self.layer2(h) 171 | h_conv3 = h 172 | h = self.layer3(h) 173 | h_conv4 = h 174 | h = self.layer4(h) 175 | h_conv5 = h 176 | 177 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 178 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 179 | 180 | return out 181 | -------------------------------------------------------------------------------- /_ntsccp/net/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | import torch 31 | import torch.nn as nn 32 | from torch.autograd import Function 33 | 34 | 35 | def find_named_module(module, query): 36 | """Helper function to find a named module. Returns a `nn.Module` or `None` 37 | 38 | Args: 39 | module (nn.Module): the root module 40 | query (str): the module name to find 41 | 42 | Returns: 43 | nn.Module or None 44 | """ 45 | 46 | return next((m for n, m in module.named_modules() if n == query), None) 47 | 48 | 49 | def find_named_buffer(module, query): 50 | """Helper function to find a named buffer. 
Returns a `torch.Tensor` or `None` 51 | 52 | Args: 53 | module (nn.Module): the root module 54 | query (str): the buffer name to find 55 | 56 | Returns: 57 | torch.Tensor or None 58 | """ 59 | return next((b for n, b in module.named_buffers() if n == query), None) 60 | 61 | 62 | def _update_registered_buffer( 63 | module, 64 | buffer_name, 65 | state_dict_key, 66 | state_dict, 67 | policy="resize_if_empty", 68 | dtype=torch.int, 69 | ): 70 | new_size = state_dict[state_dict_key].size() 71 | registered_buf = find_named_buffer(module, buffer_name) 72 | 73 | if policy in ("resize_if_empty", "resize"): 74 | if registered_buf is None: 75 | raise RuntimeError(f'buffer "{buffer_name}" was not registered') 76 | 77 | if policy == "resize" or registered_buf.numel() == 0: 78 | registered_buf.resize_(new_size) 79 | 80 | elif policy == "register": 81 | if registered_buf is not None: 82 | raise RuntimeError(f'buffer "{buffer_name}" was already registered') 83 | 84 | module.register_buffer(buffer_name, torch.empty(new_size, dtype=dtype).fill_(0)) 85 | 86 | else: 87 | raise ValueError(f'Invalid policy "{policy}"') 88 | 89 | 90 | def update_registered_buffers( 91 | module, 92 | module_name, 93 | buffer_names, 94 | state_dict, 95 | policy="resize_if_empty", 96 | dtype=torch.int, 97 | ): 98 | """Update the registered buffers in a module according to the tensors sized 99 | in a state_dict. 100 | 101 | (There's no way in torch to directly load a buffer with a dynamic size) 102 | 103 | Args: 104 | module (nn.Module): the module 105 | module_name (str): module name in the state dict 106 | buffer_names (list(str)): list of the buffer names to resize in the module 107 | state_dict (dict): the state dict 108 | policy (str): Update policy, choose from 109 | ('resize_if_empty', 'resize', 'register') 110 | dtype (dtype): Type of buffer to be registered (when policy is 'register') 111 | """ 112 | valid_buffer_names = [n for n, _ in module.named_buffers()] 113 | for buffer_name in buffer_names: 114 | if buffer_name not in valid_buffer_names: 115 | raise ValueError(f'Invalid buffer name "{buffer_name}"') 116 | 117 | for buffer_name in buffer_names: 118 | _update_registered_buffer( 119 | module, 120 | buffer_name, 121 | f"{module_name}.{buffer_name}", 122 | state_dict, 123 | policy, 124 | dtype, 125 | ) 126 | 127 | 128 | def conv(in_channels, out_channels, kernel_size=5, stride=2): 129 | return nn.Conv2d( 130 | in_channels, 131 | out_channels, 132 | kernel_size=kernel_size, 133 | stride=stride, 134 | padding=kernel_size // 2, 135 | ) 136 | 137 | 138 | def deconv(in_channels, out_channels, kernel_size=5, stride=2): 139 | return nn.ConvTranspose2d( 140 | in_channels, 141 | out_channels, 142 | kernel_size=kernel_size, 143 | stride=stride, 144 | output_padding=stride - 1, 145 | padding=kernel_size // 2, 146 | ) 147 | 148 | 149 | def quantize_ste(x): 150 | """Differentiable quantization via the Straight-Through-Estimator.""" 151 | # STE (straight-through estimator) trick: x_hard - x_soft.detach() + x_soft 152 | return (torch.round(x) - x).detach() + x 153 | 154 | 155 | def DEMUX(x): 156 | B, C, H, W = x.shape 157 | x_part1 = torch.ones_like(x)[:, :, :H // 2, :W] 158 | x_part1[:, :, :, 0::2] = x[:, :, 0::2, 0::2] 159 | x_part1[:, :, :, 1::2] = x[:, :, 1::2, 1::2] 160 | 161 | x_part2 = torch.ones_like(x)[:, :, :H // 2, :W] 162 | x_part2[:, :, :, 0::2] = x[:, :, 1::2, 0::2] 163 | x_part2[:, :, :, 1::2] = x[:, :, 0::2, 1::2] 164 | return x_part1, x_part2 165 | 166 | 167 | def MUX(x_part1, x_part2): 168 | # B, C, H_half, W = 
x_anchor.shape 169 | # H = H_half * 2 170 | x = torch.cat([torch.ones_like(x_part1), torch.ones_like(x_part1)], dim=2) 171 | x[:, :, 0::2, 0::2] = x_part1[:, :, :, 0::2] 172 | x[:, :, 1::2, 1::2] = x_part1[:, :, :, 1::2] 173 | x[:, :, 1::2, 0::2] = x_part2[:, :, :, 0::2] 174 | x[:, :, 0::2, 1::2] = x_part2[:, :, :, 1::2] 175 | return x 176 | 177 | 178 | # pylint: disable=W0221 179 | class LowerBound(Function): 180 | @staticmethod 181 | def forward(ctx, inputs, bound): 182 | b = torch.ones_like(inputs) * bound 183 | ctx.save_for_backward(inputs, b) 184 | return torch.max(inputs, b) 185 | 186 | @staticmethod 187 | def backward(ctx, grad_output): 188 | inputs, b = ctx.saved_tensors 189 | pass_through_1 = inputs >= b 190 | pass_through_2 = grad_output < 0 191 | 192 | pass_through = pass_through_1 | pass_through_2 193 | return pass_through.type(grad_output.dtype) * grad_output, None 194 | -------------------------------------------------------------------------------- /conditioning_method/diffcom.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from utils import utils_model 6 | 7 | __CONDITIONING_METHOD__ = {} 8 | 9 | 10 | def register_conditioning_method(name: str): 11 | def wrapper(cls): 12 | if __CONDITIONING_METHOD__.get(name, None): 13 | raise NameError(f"Name {name} is already registered!") 14 | __CONDITIONING_METHOD__[name] = cls 15 | return cls 16 | 17 | return wrapper 18 | 19 | 20 | def get_conditioning_method(name: str, **kwargs): 21 | if __CONDITIONING_METHOD__.get(name, None) is None: 22 | raise NameError(f"Name {name} is not defined!") 23 | return __CONDITIONING_METHOD__[name](**kwargs) 24 | 25 | 26 | class ConsistencyLoss(nn.Module): 27 | def __init__(self, config, device): 28 | super().__init__() 29 | self.config = config 30 | zeta = config.diffcom_series[config.conditioning_method]['zeta'] 31 | gamma = config.diffcom_series[config.conditioning_method]['gamma'] 32 | self.weight = { 33 | 'x_mse': gamma, 34 | 'ofdm_sig': zeta, 35 | } 36 | 37 | def forward(self, measurement, x_0_hat, cof, operator, operation_mode): 38 | x_0_hat = (x_0_hat / 2 + 0.5) # .clip(0, 1) 39 | s = operator.encode(x_0_hat) 40 | if operation_mode == 'latent': 41 | recon_measurement = { 42 | 'ofdm_sig': operator.forward(s, cof) 43 | } 44 | elif operation_mode == 'pixel': 45 | recon_measurement = { 46 | 'x_mse': x_0_hat 47 | } 48 | elif operation_mode == 'joint': 49 | ofdm_sig = operator.forward(s, cof) 50 | s_hat = operator.transpose(ofdm_sig, cof) 51 | x_confirming = operator.decode(s_hat) 52 | recon_measurement = { 53 | 'ofdm_sig': ofdm_sig, 54 | 'x_mse': x_confirming 55 | } 56 | loss = {} 57 | for key in recon_measurement.keys(): 58 | loss[key] = self.weight[key] * torch.linalg.norm(measurement[key] - recon_measurement[key]) 59 | return loss 60 | 61 | 62 | def get_lr(config, t, T): 63 | lr_base = config['learning_rate'] 64 | # exponential decay to 0 65 | if config['lr_schedule'] == 'exp': 66 | lr_min = config['lr_min'] 67 | lr = lr_min + (lr_base - lr_min) * np.exp(-t / T) 68 | # linear decay 69 | elif config['lr_schedule'] == 'linear': 70 | lr_min = config['lr_min'] 71 | lr = lr_min + (lr_base - lr_min) * (t / T) 72 | # constant 73 | else: 74 | lr = lr_base 75 | return lr 76 | 77 | 78 | @register_conditioning_method(name='diffcom') 79 | class DiffCom(nn.Module): 80 | def __init__(self): 81 | super().__init__() 82 | self.conditioning_method = 'latent' 83 | 84 | def conditioning(self, config, i, ns, 
x_t, h_t, power, 85 | measurement, unet, diffusion, operator, loss_wrapper, last_timestep): 86 | h_0_hat = h_t 87 | h_t_minus_1_prime = h_t 88 | h_t_minus_1 = h_t 89 | 90 | t_step = ns.seq[i] 91 | sigma_t = ns.reduced_alpha_cumprod[t_step].cpu().numpy() 92 | x_t = x_t.requires_grad_() 93 | x_t_minus_1_prime, x_0_hat, _ = utils_model.model_fn(x_t, 94 | noise_level=sigma_t * 255, 95 | model_out_type='pred_x_prev_and_start', \ 96 | model_diffusion=unet, 97 | diffusion=diffusion, 98 | ddim_sample=config.ddim_sample) 99 | if last_timestep: 100 | loss = loss_wrapper.forward(measurement, x_0_hat, h_0_hat, operator, self.conditioning_method) 101 | return x_0_hat, h_0_hat, x_t_minus_1_prime, h_t_minus_1_prime, loss 102 | else: 103 | loss = loss_wrapper.forward(measurement, x_0_hat, h_t, operator, self.conditioning_method) 104 | total_loss = sum(loss.values()) 105 | x_grad = torch.autograd.grad(outputs=total_loss, inputs=x_t)[0] 106 | learning_rate = get_lr(config.diffcom_series[config.conditioning_method], t_step, 107 | ns.t_start - 1) 108 | x_t_minus_1 = x_t_minus_1_prime - x_grad * learning_rate 109 | x_t_minus_1 = x_t_minus_1.detach_() 110 | return x_0_hat, h_0_hat, x_t_minus_1, h_t_minus_1, loss 111 | 112 | 113 | @register_conditioning_method(name='hifi_diffcom') 114 | class HiFiDiffCom(DiffCom): 115 | def __init__(self): 116 | super().__init__() 117 | self.conditioning_method = 'joint' 118 | 119 | 120 | @register_conditioning_method(name='blind_diffcom') 121 | class BlindDiffCom(DiffCom): 122 | def __init__(self): 123 | super().__init__() 124 | 125 | def conditioning(self, config, i, ns, x_t, h_t, power, 126 | measurement, unet, diffusion, operator, loss_wrapper, last_timestep): 127 | t_step = ns.seq[i] 128 | sigma_t = ns.reduced_alpha_cumprod[t_step].cpu().numpy() 129 | x_t = x_t.requires_grad_() 130 | x_t_minus_1_prime, x_0_hat, _ = utils_model.model_fn(x_t, 131 | noise_level=sigma_t * 255, 132 | model_out_type='pred_x_prev_and_start', \ 133 | model_diffusion=unet, 134 | diffusion=diffusion, 135 | ddim_sample=config.ddim_sample) 136 | 137 | assert (config.conditioning_method == 'blind_diffcom') 138 | 139 | h_t = h_t.requires_grad_() 140 | h_score = - h_t / (power ** 2) 141 | h_0_hat = (1 / ns.alphas_cumprod[t_step]) * ( 142 | h_t + ns.sqrt_1m_alphas_cumprod[t_step] * h_score) 143 | h_t_minus_1_prime = ns.posterior_mean_coef2[t_step] * h_t + ns.posterior_mean_coef1[t_step] * h_0_hat + \ 144 | ns.posterior_variance[t_step] * (torch.randn_like(h_t) + 1j * torch.randn_like(h_t)) 145 | 146 | if last_timestep: 147 | loss = loss_wrapper.forward(measurement, x_0_hat, h_0_hat, operator, self.conditioning_method) 148 | return x_0_hat, h_0_hat, x_t_minus_1_prime, h_t_minus_1_prime, loss 149 | else: 150 | loss = loss_wrapper.forward(measurement, x_0_hat, h_0_hat, operator, self.conditioning_method) 151 | total_loss = sum(loss.values()) 152 | x_grad, h_t_grad = torch.autograd.grad(outputs=total_loss, inputs=[x_t, h_t]) 153 | learning_rate = config.diffcom_series['blind_diffcom']['learning_rate'] 154 | learning_rate = (learning_rate - 0) * (t_step / (ns.t_start - 1)) 155 | x_t_minus_1 = x_t_minus_1_prime - x_grad * learning_rate 156 | x_t_minus_1 = x_t_minus_1.detach_() 157 | lr_h = config.diffcom_series['blind_diffcom']['h_lr'] 158 | lr_h = (lr_h - 0) * (t_step / (ns.t_start - 1)) 159 | h_t_minus_1 = h_t_minus_1_prime - h_t_grad * lr_h 160 | h_t_minus_1 = h_t_minus_1.detach_() 161 | return x_0_hat, h_0_hat, x_t_minus_1, h_t_minus_1, loss 162 | 
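# Usage sketch for the conditioning-method registry defined above (added for
# illustration; the name 'my_diffcom' and the subclass below are hypothetical,
# while register_conditioning_method / get_conditioning_method, DiffCom and
# ConsistencyLoss are the ones defined in this file):
#
#   @register_conditioning_method(name='my_diffcom')
#   class MyDiffCom(DiffCom):
#       def __init__(self):
#           super().__init__()
#           self.conditioning_method = 'pixel'   # keep only the pixel-domain term of ConsistencyLoss
#
#   cond = get_conditioning_method('my_diffcom')   # instantiates MyDiffCom by its registered name
#   assert isinstance(cond, DiffCom)
#
# At each reverse-diffusion step, DiffCom.conditioning() evaluates ConsistencyLoss on the
# denoised estimate x_0_hat, backpropagates the total loss to x_t, and subtracts
# get_lr(...) * grad from the DDPM/DDIM proposal to obtain x_{t-1}; BlindDiffCom additionally
# takes a gradient step on the channel estimate h_t.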
-------------------------------------------------------------------------------- /website/src/components/section3/Section3.js: -------------------------------------------------------------------------------- 1 | import React, {useState} from "react"; 2 | import {Button, Grid, Stack, ToggleButton, ToggleButtonGroup} from '@mui/material'; 3 | import ReactSwipe from 'react-swipe' 4 | import {ReactCompareSlider, ReactCompareSliderImage} from 'react-compare-slider'; 5 | import {AiFillLeftCircle, AiFillRightCircle} from 'react-icons/ai' 6 | 7 | const CenterWrapper = (props) => { 8 | return ( 9 |
10 |
11 |
12 |
13 | {props.content} 14 |
15 |
16 |
17 |
18 | ); 19 | } 20 | 21 | const IamgeComareSlider = ({imgs}) => { 22 | return ( 23 | } 25 | itemTwo={} 26 | /> 27 | ); 28 | } 29 | 30 | const Carousel = ({images, kernels, task, method, index, onButton}) => { 31 | let reactSwipeEl; 32 | 33 | const nextIndex = (index, change, length) => { 34 | let next_idx = (index + change); 35 | if (next_idx < 0) { 36 | next_idx = length + next_idx; 37 | } else { 38 | next_idx = next_idx % length; 39 | } 40 | return next_idx; 41 | } 42 | 43 | const pushPrev = () => { 44 | reactSwipeEl.prev(); 45 | onButton(nextIndex(index, -1, images.length)); 46 | } 47 | 48 | const pushNext = () => { 49 | reactSwipeEl.next(); 50 | onButton(nextIndex(index, 1, images.length)); 51 | } 52 | 53 | return ( 54 | 55 | 56 | {/**/} 57 | {/**/} 58 | 59 | {/*

Slide the button for comparison

*/} 60 |

Left: HiFi-DiffCom (with {method} Encoder); Right: {method}

61 | (reactSwipeEl = el)} 65 | childCount={images.length} 66 | > 67 | {images.map((image_pair) => { 68 | return ( 69 |
70 | 71 |
72 | ); 73 | })} 74 |
75 |
76 | 77 |
78 | 79 |
80 |
81 |
82 | 83 | 84 | 86 | 88 | 89 | 90 |
91 | ); 92 | } 93 | 94 | const GridKernel = ({kernels}) => { 95 | return ( 96 | 97 | 98 | 99 |

VTM + 5G LDPC

100 | {"loading.."}/ 103 |
104 |
105 | 106 | 107 |

Original Image

108 | {"loading.."}/ 111 |
112 |
113 |
114 | ); 115 | } 116 | 117 | function range(start, end) { 118 | let array = []; 119 | for (let i = start; i < end; i++) { 120 | array.push(i); 121 | } 122 | return array; 123 | } 124 | 125 | const ImageDisplay = ({method}) => { 126 | const task = 'SNR1'; 127 | const [index, setIndex] = useState(0); 128 | 129 | const images = range(0, 3).map((idx) => { 130 | return ({ 131 | 'input': process.env.PUBLIC_URL + '/imgs/results/' + method + '/' + task + '/input_' + idx + '.png', 132 | 'recon': process.env.PUBLIC_URL + '/imgs/results/' + method + '/' + task + '/recon_' + idx + '.png', 133 | }); 134 | }) 135 | 136 | const kernels = range(0, 3).map((idx) => { 137 | return ({ 138 | 'recon': process.env.PUBLIC_URL + '/imgs/results/' + method + '/' + task + '/vtm_' + idx + '.png', 139 | 'truth': process.env.PUBLIC_URL + '/imgs/results/' + method + '/' + task + '/ori_' + idx + '.png', 140 | }); 141 | }); 142 | 143 | return ( 144 | 145 | ) 146 | } 147 | 148 | 149 | const Content = () => { 150 | // const task_pair = { 151 | // 'SNR1': '0dB', 152 | // 'SNR2': '10dB' 153 | // } 154 | const method_pair = { 155 | 'DeepJSCC': 'DeepJSCC', 156 | 'NTSCC': 'NTSCC' 157 | }; 158 | 159 | // const tasks = ['SNR1', 'SNR2']; 160 | // const [task, setTask] = useState('SNR1'); 161 | 162 | const methods = ['DeepJSCC', 'NTSCC']; 163 | const [method, setMethod] = useState('DeepJSCC'); 164 | 165 | // const onTaskToggle = (button_val) => { 166 | // setTask(button_val); 167 | // }; 168 | 169 | const onMethodToggle = (button_val) => { 170 | setMethod(button_val); 171 | }; 172 | 173 | return ( 174 |
175 |

DiffCom Exhibits Superior Transmission Performance

176 | 181 |

182 | HiFi-DiffCom vs.   183 |

184 | {methods.map(t => ( 185 | { 186 | onMethodToggle(t) 187 | }} id={t} key={t}> 188 | {method_pair[t]} 189 | )) 190 | } 191 |
192 | , AWGN channel, CSNR = 0dB, CBR = 1/48; 193 | 194 |
195 | ); 196 | } 197 | 198 | const Section3 = () => { 199 | return ( 200 | }/> 201 | ); 202 | } 203 | 204 | export default Section3 205 | -------------------------------------------------------------------------------- /_pdjscc/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from config import config 3 | from data.dataset import get_loader 4 | from net.network import ADJSCC 5 | import torch.optim as optim 6 | from utils import * 7 | import time 8 | from loss_utils.utils import * 9 | 10 | 11 | 12 | # initialize model 13 | config.batch_size = 8 14 | os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_list 15 | logger = logger_configuration(config, save_log=True) 16 | logger.info(config.__dict__) 17 | net = ADJSCC(config) 18 | if len(config.gpu_list.split(',')) > 1: 19 | net = torch.nn.DataParallel(net).cuda() 20 | else: 21 | net = net.cuda() 22 | logger.info(net) 23 | predict = torch.load(config.predict) 24 | net.load_state_dict(predict, strict=False) 25 | net = net.cuda() 26 | G_params = set(p for n, p in net.named_parameters() if not n.startswith("Discriminator")) 27 | 28 | optimizer_G = optim.Adam(G_params, lr=config.g_learning_rate) 29 | if config.use_discriminator: 30 | D_params = set(p for n, p in net.named_parameters() if n.startswith("Discriminator")) 31 | optimizer_D = optim.Adam(D_params, lr=config.d_learning_rate) 32 | 33 | global_step = 0 34 | train_generator = False 35 | train_discriminator_steps = 0 36 | train_generator_steps = 0 37 | 38 | # load dataset 39 | trainloader, testloader = get_loader(config) 40 | global_step = 0 41 | 42 | # model training 43 | for epoch in range(config.epochs): 44 | net.train() 45 | device = next(net.parameters()).device 46 | elapsed, losses,ms_ssim_losses,mse_losses,lpips_losses,d_loss, g_loss,d_gen,g_real,d_real,psnrs = [AverageMeter() for _ in range(11)] 47 | metrics = [elapsed, losses,mse_losses,ms_ssim_losses,lpips_losses,d_loss, g_loss,d_gen,g_real,d_real,psnrs] 48 | for batch_idx, input_image in enumerate(trainloader): 49 | start_time = time.time() 50 | input_image = input_image.to(device) 51 | global_step += 1 52 | if config.use_discriminator: 53 | ms_ssim_loss,mse_loss,lpips_loss, x_hat,D_loss, G_loss,D_gen,D_real = net.forward(input_image,train_generator) 54 | if train_generator == True and D_gen < config.dis_acc: 55 | loss = config.K_M*mse_loss+config.K_P*lpips_loss +config.beta*G_loss 56 | optimizer_G.zero_grad() 57 | loss.backward() 58 | # torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5) 59 | optimizer_G.step() 60 | train_generator_steps += 1 61 | if train_generator_steps == config.generator_steps : 62 | train_generator = False 63 | train_generator_steps = 0 64 | else: 65 | loss = D_loss 66 | optimizer_D.zero_grad() 67 | loss.backward() 68 | # for name, parms in net.named_parameters(): 69 | # if name.startswith("Discriminator"): 70 | # print('-->name:', name, '-->grad_requirs:', parms.requires_grad, \ 71 | # ' -->grad_value:', parms.grad) 72 | optimizer_D.step() 73 | train_discriminator_steps +=1 74 | if train_discriminator_steps == config.discriminator_steps : 75 | train_generator = True 76 | train_discriminator_steps = 0 77 | elapsed.update(time.time() - start_time) 78 | losses.update(loss.item()) 79 | mse_losses.update(mse_loss.item()) 80 | ms_ssim_losses.update(ms_ssim_loss.item()) 81 | lpips_losses.update(lpips_loss.item()) 82 | d_gen.update(D_gen) 83 | d_real.update(D_real) 84 | g_loss.update(G_loss.item()) 85 | d_loss.update(D_loss.item()) 86 | else: 87 | 
ms_ssim_loss, mse_loss, lpips_loss, x_hat = net.forward(input_image) 88 | # print('recon_Image:{}'.format(x_hat)) 89 | # print(mse_loss) 90 | loss = config.K_M * mse_loss + config.K_P * lpips_loss 91 | optimizer_G.zero_grad() 92 | loss.backward() 93 | # for name, parms in net.named_parameters(): 94 | # if name.startswith("Decoder.conv_block_out"): 95 | # print('-->name:', name, '-->grad_requirs:', parms.requires_grad, \ 96 | # ' -->grad_value:', parms.grad) 97 | torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5) 98 | optimizer_G.step() 99 | elapsed.update(time.time() - start_time) 100 | losses.update(loss.item()) 101 | mse_losses.update(mse_loss.item()) 102 | ms_ssim_losses.update(ms_ssim_loss.item()) 103 | lpips_losses.update(lpips_loss.item()) 104 | 105 | if mse_loss.item() > 0: 106 | psnr = 10 * (torch.log(255. * 255. / mse_loss) / np.log(10)) 107 | psnrs.update(psnr.item()) 108 | else: 109 | psnrs.update(100) 110 | 111 | if (global_step % config.print_step) == 0: 112 | process = (global_step % trainloader.__len__()) / (trainloader.__len__()) * 100.0 113 | log = (' | '.join([ 114 | f'Step [{global_step % trainloader.__len__()}/{trainloader.__len__()}={process:.2f}%]', 115 | f'Loss {losses.val:.3f} ({losses.avg:.3f})', 116 | f'Time {elapsed.avg:.2f}', 117 | f'PSNR {psnrs.val:.2f} ({psnrs.avg:.2f})', 118 | f'LPIPS {lpips_losses.val:.3f} ({lpips_losses.avg:.2f})', 119 | f'MS-SSIM {ms_ssim_losses.val:.3f} ({ms_ssim_losses.avg:.3f})', 120 | f'G_loss {g_loss.val:.3f}({g_loss.avg:.3f})', 121 | f'D_loss {d_loss.val:.3f}({d_loss.avg:.3f})', 122 | f'D_gen {d_gen.val:.3f} ({d_gen.avg:.3f})', 123 | f'D_real {d_real.val:.3f} ({d_real.avg:.3f})', 124 | f'Epoch {epoch}', 125 | f'Lr {config.g_learning_rate}', 126 | ])) 127 | logger.info(log) 128 | for i in metrics: 129 | i.clear() 130 | if (global_step % config.test_step) == 0: 131 | net.eval() 132 | with torch.no_grad(): 133 | for batch_idx, input_image in enumerate(testloader): 134 | start_time = time.time() 135 | input_image = input_image.cuda() 136 | ms_ssim_loss,mse_loss, lpips_loss, x_hat = net.forward(input_image) 137 | elapsed.update(time.time() - start_time) 138 | losses.update(loss.item()) 139 | ms_ssim_losses.update(ms_ssim_loss.item()) 140 | mse_losses.update(mse_loss.item()) 141 | lpips_losses.update(lpips_loss.item()) 142 | if mse_loss.item() > 0: 143 | psnr = 10 * (torch.log(255. * 255. / mse_loss) / np.log(10)) 144 | psnrs.update(psnr.item()) 145 | else: 146 | psnrs.update(100) 147 | log = (' | '.join([ 148 | f'test_Loss {losses.val:.3f} ({losses.avg:.3f})', 149 | f'test_Time {elapsed.avg:.2f}', 150 | f'test_PSNR {psnrs.val:.3f} ({psnrs.avg:.3f})', 151 | f'test_LPIPS {lpips_losses.val:.3f} ({lpips_losses.avg:.3f})', 152 | f'test_MS-SSIM {ms_ssim_losses.val:.3f} ({ms_ssim_losses.avg:.3f})' 153 | ])) 154 | logger.info(log) 155 | for i in metrics: 156 | i.clear() 157 | net.train() 158 | 159 | if global_step % 5000 == 0 and global_step > 1: 160 | save_model(net, 161 | save_path=config.models + '/{}_EP{}_Step{}.model'.format(config.filename, epoch, global_step)) 162 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/perceptual_similarity/networks_basic.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | from torch.autograd import Variable 9 | import numpy as np 10 | from skimage import color 11 | from . 
import pretrained_networks as pn 12 | from . import perceptual_loss as pl 13 | 14 | def spatial_average(in_tens, keepdim=True): 15 | return in_tens.mean([2,3],keepdim=keepdim) 16 | 17 | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W 18 | in_H, in_W = in_tens.shape[2], in_tens.shape[3] 19 | scale_factor_H, scale_factor_W = 1.*out_HW[0]/in_H, 1.*out_HW[1]/in_W 20 | 21 | return nn.Upsample(scale_factor=(scale_factor_H, scale_factor_W), mode='bilinear', align_corners=False)(in_tens) 22 | 23 | # Learned perceptual metric 24 | class PNetLin(nn.Module): 25 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): 26 | super(PNetLin, self).__init__() 27 | 28 | self.pnet_type = pnet_type 29 | self.pnet_tune = pnet_tune 30 | self.pnet_rand = pnet_rand 31 | self.spatial = spatial 32 | self.lpips = lpips 33 | self.version = version 34 | self.scaling_layer = ScalingLayer() 35 | 36 | if(self.pnet_type in ['vgg','vgg16']): 37 | net_type = pn.vgg16 38 | self.chns = [64,128,256,512,512] 39 | elif(self.pnet_type=='alex'): 40 | net_type = pn.alexnet 41 | self.chns = [64,192,384,256,256] 42 | elif(self.pnet_type=='squeeze'): 43 | net_type = pn.squeezenet 44 | self.chns = [64,128,256,384,384,512,512] 45 | self.L = len(self.chns) 46 | 47 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 48 | 49 | if(lpips): 50 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 51 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 52 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 53 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 54 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 55 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 56 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 57 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 58 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 59 | self.lins+=[self.lin5,self.lin6] 60 | 61 | def forward(self, in0, in1, retPerLayer=False): 62 | # v0.0 - original release had a bug, where input was not scaled 63 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 64 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 65 | feats0, feats1, diffs = {}, {}, {} 66 | 67 | for kk in range(self.L): 68 | feats0[kk], feats1[kk] = pl.normalize_tensor(outs0[kk]), pl.normalize_tensor(outs1[kk]) 69 | diffs[kk] = (feats0[kk]-feats1[kk])**2 70 | 71 | if(self.lpips): 72 | if(self.spatial): 73 | res = [upsample(self.lins[kk].model(diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] 74 | else: 75 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] 76 | else: 77 | if(self.spatial): 78 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] 79 | else: 80 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 81 | 82 | val = res[0] 83 | for l in range(1,self.L): 84 | val += res[l] 85 | 86 | if(retPerLayer): 87 | return (val, res) 88 | else: 89 | return val 90 | 91 | class ScalingLayer(nn.Module): 92 | def __init__(self): 93 | super(ScalingLayer, self).__init__() 94 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 95 | self.register_buffer('scale', 
torch.Tensor([.458,.448,.450])[None,:,None,None]) 96 | 97 | def forward(self, inp): 98 | return (inp - self.shift) / self.scale 99 | 100 | 101 | class NetLinLayer(nn.Module): 102 | ''' A single linear layer which does a 1x1 conv ''' 103 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 104 | super(NetLinLayer, self).__init__() 105 | 106 | layers = [nn.Dropout(),] if(use_dropout) else [] 107 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 108 | self.model = nn.Sequential(*layers) 109 | 110 | 111 | class Dist2LogitLayer(nn.Module): 112 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 113 | def __init__(self, chn_mid=32, use_sigmoid=True): 114 | super(Dist2LogitLayer, self).__init__() 115 | 116 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 117 | layers += [nn.LeakyReLU(0.2,True),] 118 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 119 | layers += [nn.LeakyReLU(0.2,True),] 120 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 121 | if(use_sigmoid): 122 | layers += [nn.Sigmoid(),] 123 | self.model = nn.Sequential(*layers) 124 | 125 | def forward(self,d0,d1,eps=0.1): 126 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 127 | 128 | class BCERankingLoss(nn.Module): 129 | def __init__(self, chn_mid=32): 130 | super(BCERankingLoss, self).__init__() 131 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 132 | # self.parameters = list(self.net.parameters()) 133 | self.loss = torch.nn.BCELoss() 134 | 135 | def forward(self, d0, d1, judge): 136 | per = (judge+1.)/2. 137 | self.logit = self.net.forward(d0,d1) 138 | return self.loss(self.logit, per) 139 | 140 | # L2, DSSIM metrics 141 | class FakeNet(nn.Module): 142 | def __init__(self, use_gpu=True, colorspace='Lab'): 143 | super(FakeNet, self).__init__() 144 | self.use_gpu = use_gpu 145 | self.colorspace=colorspace 146 | 147 | class L2(FakeNet): 148 | 149 | def forward(self, in0, in1, retPerLayer=None): 150 | assert(in0.size()[0]==1) # currently only supports batchSize 1 151 | 152 | if(self.colorspace=='RGB'): 153 | (N,C,X,Y) = in0.size() 154 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 155 | return value 156 | elif(self.colorspace=='Lab'): 157 | value = pl.l2(pl.tensor2np(pl.tensor2tensorlab(in0.data,to_norm=False)), 158 | pl.tensor2np(pl.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 159 | ret_var = Variable( torch.Tensor((value,) ) ) 160 | if(self.use_gpu): 161 | ret_var = ret_var.cuda() 162 | return ret_var 163 | 164 | class DSSIM(FakeNet): 165 | 166 | def forward(self, in0, in1, retPerLayer=None): 167 | assert(in0.size()[0]==1) # currently only supports batchSize 1 168 | 169 | if(self.colorspace=='RGB'): 170 | value = pl.dssim(1.*pl.tensor2im(in0.data), 1.*pl.tensor2im(in1.data), range=255.).astype('float') 171 | elif(self.colorspace=='Lab'): 172 | value = pl.dssim(pl.tensor2np(pl.tensor2tensorlab(in0.data,to_norm=False)), 173 | pl.tensor2np(pl.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 174 | ret_var = Variable( torch.Tensor((value,) ) ) 175 | if(self.use_gpu): 176 | ret_var = ret_var.cuda() 177 | return ret_var 178 | 179 | def print_network(net): 180 | num_params = 0 181 | for param in net.parameters(): 182 | num_params += param.numel() 183 | print('Network',net) 184 | print('Total number of parameters: %d' % num_params) 185 | 
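# What PNetLin.forward computes, distilled into a standalone sketch (illustrative only:
# the 1x1 'lin' weights below are random, not the released LPIPS v0.1 weights, and the
# feature tensors are stand-ins for the backbone activations; copy into a separate script
# to run, since this module itself relies on relative package imports):
#
#   import torch
#   import torch.nn as nn
#
#   def unit_norm(f, eps=1e-10):
#       # same role as perceptual_loss.normalize_tensor
#       return f / torch.sqrt((f ** 2).sum(dim=1, keepdim=True) + eps)
#
#   f0 = unit_norm(torch.randn(1, 64, 32, 32))          # features of image 0 at one layer
#   f1 = unit_norm(torch.randn(1, 64, 32, 32))          # features of image 1 at one layer
#   lin = nn.Conv2d(64, 1, kernel_size=1, bias=False)   # plays the role of NetLinLayer
#   layer_term = lin((f0 - f1) ** 2).mean(dim=[2, 3], keepdim=True)   # shape (1, 1, 1, 1)
#   print(layer_term.view(-1).item())
#
# The LPIPS value returned by forward() is the sum of such per-layer terms over the
# 5 VGG/AlexNet layers (7 for SqueezeNet), optionally kept as a spatial map and upsampled
# to the input resolution when spatial=True.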
-------------------------------------------------------------------------------- /website/src/components/section4/Section4.js: -------------------------------------------------------------------------------- 1 | import React, {useState} from "react"; 2 | import { Grid, Slider, Stack, ToggleButton, ToggleButtonGroup } from "@mui/material"; 3 | import { GiBackwardTime } from "react-icons/gi" 4 | // import Chart from "react-apexcharts"; 5 | import { MathJax } from 'better-react-mathjax'; 6 | 7 | // let _ = require('lodash'); 8 | 9 | const CenterWrapper = (props) => { 10 | return ( 11 |
12 |
13 |
14 |
15 | {props.content} 16 |
17 |
18 |
19 |
20 | ); 21 | } 22 | 23 | // const ErrorGraph = ({data, task}) => { 24 | // let color = (task==="deblur") ? "#d4526e":"#33b2bf"; 25 | // let state = { 26 | // series: data, 27 | // options: { 28 | // chart: { 29 | // height: '10rem', 30 | // type: 'rangeArea', 31 | // animations: { 32 | // // speed: 500 33 | // enabled: false 34 | // }, 35 | // }, 36 | // grid: {show: false}, 37 | // zoom: {enabled: false}, 38 | // xaxis: { 39 | // overwriteCategories: ["1000", "900", "800", "700", "600", "500", "400", "300", "200", "100", "0"], 40 | // tickAmount: 10, 41 | // max:1000, 42 | // }, 43 | // yaxis: { 44 | // title: { 45 | // text: 'Residue' 46 | // } 47 | // }, 48 | // colors: [color, color], 49 | // dataLabels: { 50 | // enabled: false 51 | // }, 52 | // fill: { 53 | // opacity: [0.24, 1] 54 | // }, 55 | // legend: { 56 | // show: false 57 | // }, 58 | // stroke: { 59 | // curve: 'straight', 60 | // width: [0, 2] 61 | // }, 62 | // tooltip: { 63 | // x: { 64 | // formatter: function(value) { 65 | // return 't='+ (1000-value).toString(); 66 | // }, 67 | // }, 68 | // } 69 | // }, 70 | // } 71 | // 72 | // return ( 73 | //
74 | // 75 | //
); 76 | // }; 77 | 78 | 79 | function generate_path(task, time){ 80 | let base = process.env.PUBLIC_URL + '/imgs/progress/' + task 81 | let time_str = (typeof time === 'number') ? (time).toString():time; 82 | let number = time_str.padStart(4, '0'); 83 | 84 | return { 85 | // 'img_input': base + '/io/input.png', 86 | 'img_progress': base + '/img/x_0^' + number + '.png', 87 | 'ker_progress': base + '/ker/cof_hat_' + number + '.png', 88 | 'img_label': base + '/io/input.png', 89 | 'ker_label': base + '/io/truth_ker.png' 90 | } 91 | } 92 | 93 | const ImageGrid = ({task, time}) => { 94 | let paths = generate_path(task, 1000-time); 95 | return ( 96 | 97 | {/**/} 98 | {/*

Reconstruction

*/} 99 | {/*
*/} 100 | {/* input*/} 101 | {/*
*/} 102 | {/*
*/} 103 | 104 |

{"$ \\mathbf{\\hat{x}}_{0|t}$"}

105 |
106 | x0 107 |
108 |
109 | 110 |

Original Image

111 |
112 | truth 113 |
114 |
115 | 116 |

Estimated Channel Response Vector

117 |
118 | h_est 119 |
120 |
121 | 122 |

Ground Truth Channel Response Vector

123 |
124 | h_gt 125 |
126 |
127 |
128 | ); 129 | } 130 | 131 | // const KernelGrid = ({task, time}) => { 132 | // let paths = generate_path(task, 1000-time); 133 | // return ( 134 | // 135 | // {/**/} 136 | // {/* /!* input *!/*/} 137 | // {/**/} 138 | // 139 | // x0 140 | // 141 | // 142 | // truth 143 | // 144 | // 145 | // ); 146 | // } 147 | 148 | const Content = () => { 149 | 150 | const [time, setTime] = useState(1000); 151 | const [task, setTask] = useState("Sample1"); 152 | const tasks = ['Sample1', 'Sample2', 'Sample3']; 153 | // const data = {'Sample1': require('./deblur_data.json'), 154 | // 'Sample2': require('./turbulence_data.json')}; 155 | // const [partialData, setPartialData] = useState(data['Sample1']); 156 | 157 | // function sliceData(idx, task){ 158 | // let discrete_idx = parseInt(idx/10); 159 | // let current = _.cloneDeep(data[task]); 160 | // if (discrete_idx > 2){ 161 | // current[0].data = current[0].data.slice(0, discrete_idx); 162 | // current[1].data = current[1].data.slice(0, discrete_idx); 163 | // } 164 | // return current; 165 | // } 166 | 167 | const handleSlider = (e, v) => { 168 | setTime(v); 169 | // setPartialData(sliceData(v, task)) 170 | } 171 | 172 | const onTaskToggle = (task) => { 173 | setTask(task); 174 | // setPartialData(sliceData(time, task)); 175 | } 176 | 177 | return ( 178 |
179 |

Blind-DiffCom Achieves Pilot-Free Transmission

180 | 184 | {tasks.map(t => ( 185 | {onTaskToggle(t)}} id={t} key={t}> 186 | {t} 187 | )) 188 | } 189 | 190 | {/*
*/} 191 | {/**/} 192 | 193 | {/**/} 194 | {/*

Measured on 20 samples. Mean {'$\\pm 1.0\\sigma$'} is displayed.

*/} 195 | {/*
*/} 196 | 197 | 198 | 199 | 200 | 201 | {/**/} 202 | {/* */} 203 | {/**/} 204 | 205 | 206 | 207 | 208 | 209 | 210 |

⎻⎻⎻ Drag time slider ⎻⎻→

211 |

Progress of channel estimation-free transmission over a multipath fading channel with L = 8 paths.

212 |
213 | ) 214 | } 215 | 216 | const Section4 = () => { 217 | return ( 218 | } /> 219 | ); 220 | } 221 | 222 | export default Section4; 223 | 224 | 225 | -------------------------------------------------------------------------------- /_pdjscc/data/dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | from glob import glob 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from torch.utils.data.dataset import Dataset 10 | from torchvision import transforms, datasets 11 | 12 | NUM_DATASET_WORKERS = 6 13 | SCALE_MIN = 0.75 14 | SCALE_MAX = 0.95 15 | 16 | 17 | class OpenImages(Dataset): 18 | """OpenImages dataset from [1]. 19 | 20 | Parameters 21 | ---------- 22 | root : string 23 | Root directory of dataset. 24 | 25 | References 26 | ---------- 27 | [1] https://storage.googleapis.com/openimages/web/factsfigures.html 28 | 29 | """ 30 | files = {"train": "train", "test": "test", "val": "validation"} 31 | 32 | def __init__(self, config, data_dir): 33 | self.imgs = [] 34 | for dir in data_dir: 35 | self.imgs += glob(os.path.join(dir, '*.jpg')) 36 | self.imgs += glob(os.path.join(dir, '*.png')) 37 | _, self.im_height, self.im_width = config.img_size 38 | self.crop_size = self.im_height 39 | self.image_dims = (3, self.im_height, self.im_width) 40 | self.scale_min = SCALE_MIN 41 | self.scale_max = SCALE_MAX 42 | self.normalize = config.normalize 43 | self.require_H = False 44 | 45 | def _transforms(self, scale, H, W): 46 | """ 47 | Up(down)scale and randomly crop to `crop_size` x `crop_size` 48 | """ 49 | transforms_list = [transforms.ToPILImage(), 50 | transforms.RandomHorizontalFlip(), 51 | transforms.Resize((math.ceil(scale * H), math.ceil(scale * W))), 52 | transforms.RandomCrop((self.im_height, self.im_width))] 53 | # transforms.Resize((self.im_height, self.im_width)), 54 | # transforms.ToTensor()] 55 | transforms_list += [transforms.ToTensor()] 56 | if self.normalize is True: 57 | transforms_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 58 | 59 | return transforms.Compose(transforms_list) 60 | 61 | def __getitem__(self, idx): 62 | """ TODO: This definitely needs to be optimized. 63 | Get the image of `idx` 64 | 65 | Return 66 | ------ 67 | sample : torch.Tensor 68 | Tensor in [0.,1.] of shape `img_size`. 69 | 70 | """ 71 | # img values already between 0 and 255 72 | img_path = self.imgs[idx] 73 | # filesize = os.path.getsize(img_path) 74 | # This is faster but less convenient 75 | # H X W X C `ndarray` 76 | # img = imread(img_path) 77 | # img_dims = img.shape 78 | # H, W = img_dims[0], img_dims[1] 79 | # PIL 80 | try: 81 | img = Image.open(img_path) 82 | img = img.convert('RGB') 83 | except: 84 | img_path = self.imgs[idx + 1] 85 | img = Image.open(img_path) 86 | img = img.convert('RGB') 87 | print("ERROR!") 88 | # except: 89 | # img_path = self.imgs[idx + 1] 90 | # img = Image.open(img_path) 91 | # img = img.convert('RGB') 92 | # print("ERROR!") 93 | # img = cv2.imread(img_path) 94 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 95 | img = np.asarray(img) 96 | W, H, C = img.shape 97 | # bpp = filesize * 8. 
/ (H * W) 98 | 99 | shortest_side_length = min(H, W) 100 | 101 | minimum_scale_factor = float(self.crop_size) / float(shortest_side_length) 102 | scale_low = max(minimum_scale_factor, self.scale_min) 103 | scale_high = max(scale_low, self.scale_max) 104 | scale = np.random.uniform(scale_low, scale_high) 105 | self.dynamic_transform = self._transforms(scale, H, W) 106 | transformed = self.dynamic_transform(img) 107 | return transformed 108 | 109 | def __len__(self): 110 | return len(self.imgs) 111 | 112 | 113 | def get_cifar10_loader(config): 114 | dataset_ = datasets.CIFAR10 115 | dataset_dir = '/home/wangsixian/Dataset/CIFAR10' 116 | transform_train = transforms.Compose([ 117 | transforms.RandomHorizontalFlip(), 118 | transforms.ToTensor(), 119 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 120 | 121 | transform_test = transforms.Compose([ 122 | transforms.ToTensor(), 123 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 124 | 125 | train_dataset = dataset_(root=dataset_dir, 126 | train=True, 127 | transform=transform_train, 128 | download=False) 129 | 130 | test_dataset = dataset_(root=dataset_dir, 131 | train=False, 132 | transform=transform_test, 133 | download=False) 134 | 135 | train_loader = data.DataLoader(dataset=train_dataset, 136 | batch_size=config.batch_size, 137 | num_workers=NUM_DATASET_WORKERS, 138 | drop_last=True, 139 | shuffle=True) 140 | 141 | test_loader = data.DataLoader(dataset=test_dataset, 142 | batch_size=config.batch_size, 143 | shuffle=False) 144 | 145 | return train_loader, test_loader 146 | 147 | 148 | def get_loader(config): 149 | train_dataset = OpenImages(config, config.train_data_dir) 150 | test_dataset = Datasets(config.test_data_dir) 151 | 152 | def worker_init_fn_seed(worker_id): 153 | seed = 10 154 | seed += worker_id 155 | np.random.seed(seed) 156 | 157 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 158 | num_workers=NUM_DATASET_WORKERS, 159 | pin_memory=True, 160 | batch_size=config.batch_size, 161 | worker_init_fn=worker_init_fn_seed, 162 | shuffle=True) 163 | 164 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 165 | batch_size=1, 166 | shuffle=False) 167 | 168 | return train_loader, test_loader 169 | 170 | 171 | def get_test_loader(config): 172 | test_dataset = Datasets(config.test_data_dir) 173 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 174 | batch_size=1, 175 | shuffle=False) 176 | 177 | return test_loader 178 | 179 | 180 | class Datasets(Dataset): 181 | def __init__(self, data_dir): 182 | self.data_dir = data_dir 183 | self.imgs = [] 184 | for dir in data_dir: 185 | self.imgs += glob(os.path.join(dir, '*.jpg')) 186 | self.imgs += glob(os.path.join(dir, '*.png')) 187 | self.imgs.sort() 188 | self.normalize = False 189 | _, self.im_height, self.im_width = (3, 256, 256) 190 | self.transform = transforms.Compose([ # transforms.RandomHorizontalFlip(0), 191 | # transforms.CenterCrop((256, 256)), 192 | transforms.ToTensor()]) 193 | 194 | def __getitem__(self, item): 195 | image_ori = self.imgs[item] 196 | image = Image.open(image_ori).convert('RGB') 197 | img = self.transform(image) 198 | if self.normalize: 199 | return 2 * img - 1 200 | else: 201 | return img 202 | 203 | def __len__(self): 204 | return len(self.imgs) 205 | 206 | 207 | class TestKodakDataset(Dataset): 208 | def __init__(self, data_dir): 209 | self.image_path = sorted(glob(data_dir)) # [:100] 210 | 211 | def __getitem__(self, item): 212 | image_ori = self.image_path[item] 213 | image = 
Image.open(image_ori).convert('RGB') 214 | transform = transforms.Compose([ 215 | transforms.CenterCrop(256), 216 | transforms.ToTensor(), 217 | ]) 218 | return transform(image) 219 | 220 | def __len__(self): 221 | return len(self.image_path) 222 | 223 | 224 | if __name__ == '__main__': 225 | import os 226 | import sys 227 | 228 | sys.path.append("/media/D/wangsixian/DJSCC") 229 | from ADJSCC.config import config 230 | 231 | config.train_data_dir = ['/home/wangsixian/Dataset/openimages/**'] 232 | dataset = OpenImages(config, config.train_data_dir) 233 | print(dataset.__len__()) 234 | for i in range(dataset.__len__()): 235 | # try: 236 | img = dataset.__getitem__(i) 237 | print(img.shape) 238 | break 239 | -------------------------------------------------------------------------------- /guided_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
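# (zip() stops as soon as either argument is exhausted, so an already-consumed
# generator here would make the loop below a silent no-op.)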
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * 
loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | for p in self.master_params: 203 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 204 | opt.step() 205 | zero_master_grads(self.master_params) 206 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 207 | self.lg_loss_scale += self.fp16_scale_growth 208 | return True 209 | 210 | def _optimize_normal(self, opt: th.optim.Optimizer): 211 | grad_norm, param_norm = self._compute_norms() 212 | logger.logkv_mean("grad_norm", grad_norm) 213 | logger.logkv_mean("param_norm", param_norm) 214 | opt.step() 215 | return True 216 | 217 | def _compute_norms(self, grad_scale=1.0): 218 | grad_norm = 0.0 219 | param_norm = 0.0 220 | for p in self.master_params: 221 | with th.no_grad(): 222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 223 | if p.grad is not None: 224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 226 | 227 | def master_params_to_state_dict(self, master_params): 228 | return master_params_to_state_dict( 229 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 230 | ) 231 | 232 | def state_dict_to_master_params(self, state_dict): 233 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 234 | 235 | 236 | def check_overflow(value): 237 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 238 | -------------------------------------------------------------------------------- /_pdjscc/loss_utils/loss/distortion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | @torch.jit.script 6 | def create_window(window_size: int, sigma: float, channel: int): 7 | ''' 8 | Create 1-D gauss kernel 9 | :param window_size: the size of gauss kernel 10 | :param sigma: sigma of normal distribution 11 | :param channel: input channel 12 | :return: 1D kernel 13 | ''' 14 | coords = torch.arange(window_size, dtype=torch.float) 15 | coords -= window_size // 2 16 | 17 | g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) 18 | g /= g.sum() 19 | 20 | g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1) 21 | return g 22 | 23 | 24 | @torch.jit.script 25 | def _gaussian_filter(x, window_1d, use_padding: bool): 26 | ''' 27 | Blur input with 1-D kernel 28 | :param x: batch of tensors to be blured 29 | :param window_1d: 1-D gauss kernel 30 | :param use_padding: padding image before conv 31 | :return: blured tensors 32 | ''' 33 | C = x.shape[1] 34 | padding = 0 35 | if use_padding: 36 | window_size = window_1d.shape[3] 37 | padding = window_size // 2 38 | out = F.conv2d(x, window_1d, stride=1, padding=(0, 
padding), groups=C) 39 | out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C) 40 | return out 41 | 42 | 43 | @torch.jit.script 44 | def ssim(X, Y, window, data_range: float, use_padding: bool = False): 45 | ''' 46 | Calculate ssim index for X and Y 47 | :param X: images 48 | :param Y: images 49 | :param window: 1-D gauss kernel 50 | :param data_range: value range of input images. (usually 1.0 or 255) 51 | :param use_padding: padding image before conv 52 | :return: 53 | ''' 54 | 55 | K1 = 0.01 56 | K2 = 0.03 57 | compensation = 1.0 58 | 59 | C1 = (K1 * data_range) ** 2 60 | C2 = (K2 * data_range) ** 2 61 | 62 | mu1 = _gaussian_filter(X, window, use_padding) 63 | mu2 = _gaussian_filter(Y, window, use_padding) 64 | sigma1_sq = _gaussian_filter(X * X, window, use_padding) 65 | sigma2_sq = _gaussian_filter(Y * Y, window, use_padding) 66 | sigma12 = _gaussian_filter(X * Y, window, use_padding) 67 | 68 | mu1_sq = mu1.pow(2) 69 | mu2_sq = mu2.pow(2) 70 | mu1_mu2 = mu1 * mu2 71 | 72 | sigma1_sq = compensation * (sigma1_sq - mu1_sq) 73 | sigma2_sq = compensation * (sigma2_sq - mu2_sq) 74 | sigma12 = compensation * (sigma12 - mu1_mu2) 75 | 76 | cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) 77 | # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan. 78 | cs_map = F.relu(cs_map) 79 | ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map 80 | 81 | ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW 82 | cs = cs_map.mean(dim=(1, 2, 3)) 83 | 84 | return ssim_val, cs 85 | 86 | 87 | @torch.jit.script 88 | def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8): 89 | ''' 90 | interface of ms-ssim 91 | :param X: a batch of images, (N,C,H,W) 92 | :param Y: a batch of images, (N,C,H,W) 93 | :param window: 1-D gauss kernel 94 | :param data_range: value range of input images. (usually 1.0 or 255) 95 | :param weights: weights for different levels 96 | :param use_padding: padding image before conv 97 | :param eps: use for avoid grad nan. 98 | :return: 99 | ''' 100 | weights = weights[:, None] 101 | 102 | levels = weights.shape[0] 103 | vals = [] 104 | for i in range(levels): 105 | ss, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding) 106 | 107 | if i < levels - 1: 108 | vals.append(cs) 109 | X = F.avg_pool2d(X, kernel_size=2, stride=2, ceil_mode=True) 110 | Y = F.avg_pool2d(Y, kernel_size=2, stride=2, ceil_mode=True) 111 | else: 112 | vals.append(ss) 113 | 114 | vals = torch.stack(vals, dim=0) 115 | # Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf. 116 | vals = vals.clamp_min(eps) 117 | # The origin ms-ssim op. 118 | ms_ssim_val = torch.prod(vals[:-1] ** weights[:-1] * vals[-1:] ** weights[-1:], dim=0) 119 | # The new ms-ssim op. But I don't know which is best. 120 | # ms_ssim_val = torch.prod(vals ** weights, dim=0) 121 | # In this file's image training demo. I feel the old ms-ssim more better. So I keep use old ms-ssim op. 122 | return ms_ssim_val 123 | 124 | 125 | class SSIM(torch.jit.ScriptModule): 126 | __constants__ = ['data_range', 'use_padding'] 127 | 128 | def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False): 129 | ''' 130 | :param window_size: the size of gauss kernel 131 | :param window_sigma: sigma of normal distribution 132 | :param data_range: value range of input images. 
(usually 1.0 or 255) 133 | :param channel: input channels (default: 3) 134 | :param use_padding: padding image before conv 135 | ''' 136 | super().__init__() 137 | assert window_size % 2 == 1, 'Window size must be odd.' 138 | window = create_window(window_size, window_sigma, channel) 139 | self.register_buffer('window', window) 140 | self.data_range = data_range 141 | self.use_padding = use_padding 142 | 143 | @torch.jit.script_method 144 | def forward(self, X, Y): 145 | r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding) 146 | return r[0] 147 | 148 | 149 | class MS_SSIM(torch.jit.ScriptModule): 150 | __constants__ = ['data_range', 'use_padding', 'eps'] 151 | 152 | def __init__(self, window_size=11, window_sigma=1.5, data_range=1.0, channel=3, use_padding=False, weights=None, 153 | levels=None, eps=1e-8): 154 | """ 155 | class for ms-ssim 156 | :param window_size: the size of gauss kernel 157 | :param window_sigma: sigma of normal distribution 158 | :param data_range: value range of input images. (usually 1.0 or 255) 159 | :param channel: input channels 160 | :param use_padding: padding image before conv 161 | :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 162 | :param levels: number of downsampling 163 | :param eps: Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf. 164 | """ 165 | super().__init__() 166 | assert window_size % 2 == 1, 'Window size must be odd.' 167 | self.data_range = data_range 168 | self.use_padding = use_padding 169 | self.eps = eps 170 | 171 | window = create_window(window_size, window_sigma, channel) 172 | self.register_buffer('window', window) 173 | 174 | if weights is None: 175 | weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] 176 | weights = torch.tensor(weights, dtype=torch.float) 177 | 178 | if levels is not None: 179 | weights = weights[:levels] 180 | weights = weights / weights.sum() 181 | 182 | self.register_buffer('weights', weights) 183 | 184 | @torch.jit.script_method 185 | def forward(self, X, Y): 186 | return 1 - ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights, 187 | use_padding=self.use_padding, eps=self.eps) 188 | 189 | 190 | class MSE(torch.nn.Module): 191 | def __init__(self, normalization=True,scaled = True): 192 | super(MSE, self).__init__() 193 | self.squared_difference = torch.nn.MSELoss(reduction='none') 194 | self.normalization = normalization 195 | self.scaled = scaled 196 | 197 | def forward(self, X, Y): 198 | # [-1 1] to [0 1] 199 | if self.normalization: 200 | X = (X + 1) / 2 201 | Y = (Y + 1) / 2 202 | return torch.mean(self.squared_difference(X * 255., Y * 255.)) # / 255. 203 | 204 | 205 | class Distortion(torch.nn.Module): 206 | def __init__(self, config): 207 | super(Distortion, self).__init__() 208 | if config.distortion_metric == 'MSE': 209 | self.dist = MSE(normalization=False) 210 | elif config.distortion_metric == 'SSIM': 211 | self.dist = SSIM() 212 | elif config.distortion_metric == 'MS-SSIM': 213 | self.dist = MS_SSIM(data_range=1., levels=4, channel=3).cuda() 214 | else: 215 | config.logger.info("Unknown distortion type!") 216 | raise ValueError 217 | 218 | def forward(self, X, Y,normalization = False): 219 | return self.dist.forward(X, Y).mean() # / 255. 
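# Note: Distortion expects `config.distortion_metric` to be one of 'MSE', 'SSIM'
# or 'MS-SSIM'; X and Y are (N, C, H, W) tensors, and the MS-SSIM branch assumes
# values in [0, 1] (data_range=1.0) and a CUDA device.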
220 | 221 | 222 | if __name__ == '__main__': 223 | rand_im1 = (torch.randint(0, 255, [4, 3, 256, 128], dtype=torch.float32) / 255.).cuda() 224 | rand_im2 = (torch.randint(0, 255, [4, 3, 256, 128], dtype=torch.float32) / 255.).cuda() 225 | losser = MS_SSIM(data_range=1., levels=4, channel=3).cuda() 226 | loss = losser(rand_im1, rand_im2) 227 | print(loss) 228 | --------------------------------------------------------------------------------
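The fp16_util.py module above centres on MixedPrecisionTrainer, which wraps the usual zero_grad/backward/optimize cycle: with use_fp16 enabled it keeps flattened float32 master parameters, scales the loss by 2**lg_loss_scale, and backs the scale off whenever an overflow is detected. Below is a minimal usage sketch; the toy model, batch, and optimizer names are illustrative only, and it assumes the guided_diffusion package also provides the logger module that fp16_util.py imports (not listed in this tree).

import torch as th
import torch.nn as nn
from guided_diffusion.fp16_util import MixedPrecisionTrainer

# Illustrative toy model; convert_to_fp16() support is only required when use_fp16=True.
model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))
trainer = MixedPrecisionTrainer(model=model, use_fp16=False)
opt = th.optim.Adam(trainer.master_params, lr=1e-4)

x = th.rand(2, 3, 64, 64)  # stand-in batch
for _ in range(3):
    trainer.zero_grad()
    loss = ((model(x) - x) ** 2).mean()
    trainer.backward(loss)   # multiplies the loss by 2**lg_loss_scale when fp16 is on
    trainer.optimize(opt)    # unscales grads; on the fp16 path, skips the step if an overflow is found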