├── Structure.png
├── util
│   ├── __init__.py
│   ├── util.pyc
│   ├── __init__.pyc
│   ├── __pycache__
│   │   ├── hdm.cpython-35.pyc
│   │   ├── hdm.cpython-36.pyc
│   │   ├── html.cpython-35.pyc
│   │   ├── html.cpython-36.pyc
│   │   ├── ldpc.cpython-35.pyc
│   │   ├── mod.cpython-35.pyc
│   │   ├── nnls.cpython-35.pyc
│   │   ├── nnls.cpython-36.pyc
│   │   ├── util.cpython-35.pyc
│   │   ├── util.cpython-36.pyc
│   │   ├── polar.cpython-35.pyc
│   │   ├── polar.cpython-36.pyc
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── channel.cpython-35.pyc
│   │   ├── image_pool.cpython-35.pyc
│   │   ├── image_pool.cpython-36.pyc
│   │   ├── visualizer.cpython-35.pyc
│   │   ├── visualizer.cpython-36.pyc
│   │   └── inception_score.cpython-35.pyc
│   ├── image_pool.py
│   ├── html.py
│   ├── util.py
│   ├── get_data.py
│   └── visualizer.py
├── models
│   ├── Pilot_bit.pt
│   ├── utils.py
│   ├── __init__.py
│   ├── JSCC_model.py
│   ├── base_model.py
│   ├── channel.py
│   ├── JSCCOFDM_model.py
│   └── networks.py
├── data
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── data_loader.cpython-35.pyc
│   │   ├── data_loader.cpython-36.pyc
│   │   ├── CelebA_dataset.cpython-35.pyc
│   │   ├── base_dataset.cpython-35.pyc
│   │   ├── base_dataset.cpython-36.pyc
│   │   ├── image_folder.cpython-35.pyc
│   │   ├── image_folder.cpython-36.pyc
│   │   ├── aligned_dataset.cpython-35.pyc
│   │   ├── aligned_dataset.cpython-36.pyc
│   │   ├── base_data_loader.cpython-35.pyc
│   │   ├── base_data_loader.cpython-36.pyc
│   │   ├── unaligned_dataset.cpython-35.pyc
│   │   ├── custom_dataset_data_loader.cpython-35.pyc
│   │   └── custom_dataset_data_loader.cpython-36.pyc
│   ├── data_loader.py
│   ├── base_data_loader.py
│   ├── preprocess_celeba.sh
│   ├── custom_dataset_data_loader.py
│   ├── single_dataset.py
│   ├── image_folder.py
│   ├── CelebA_dataset.py
│   ├── aligned_dataset.py
│   ├── colorization_dataset.py
│   ├── unaligned_dataset.py
│   ├── template_dataset.py
│   ├── __init__.py
│   └── base_dataset.py
├── options
│   ├── __init__.py
│   ├── test_options.py
│   ├── train_options.py
│   └── base_options.py
├── README.md
├── test.py
└── train.py

--------------------------------------------------------------------------------
/Structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingyuyng/Deep-JSCC-for-images-with-OFDM/HEAD/Structure.png

--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |

--------------------------------------------------------------------------------
/models/Pilot_bit.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingyuyng/Deep-JSCC-for-images-with-OFDM/HEAD/models/Pilot_bit.pt
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes option modules: training options, test options, and base options (used in both training and test)."""
2 |
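A minimal sketch of how these option modules are typically consumed. This assumes the conventional `parse()` entry point on `BaseOptions` in `base_options.py`, which is not shown in this dump:

```python
# Hypothetical usage sketch; TrainOptions is defined in options/train_options.py below.
from options.train_options import TrainOptions

opt = TrainOptions().parse()  # parse command-line flags into an options namespace
print(opt.lr, opt.n_epochs)   # parsed values are exposed as plain attributes
```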
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def CreateDataLoader(opt):
4 |     from data.custom_dataset_data_loader import CustomDatasetDataLoader
5 |     data_loader = CustomDatasetDataLoader()
6 |     print(data_loader.name())
7 |     data_loader.initialize(opt)
8 |     return data_loader
9 |

--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | class BaseDataLoader():
3 |     def __init__(self):
4 |         pass
5 |
6 |     def initialize(self, opt):
7 |         self.opt = opt
8 |         pass
9 |
10 |     def load_data(self):
11 |         return None
12 |
13 |
14 |
15 |
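A minimal usage sketch for the loader above; the `opt` object is assumed to come from the option parsers under `options/`:

```python
from data.data_loader import CreateDataLoader

data_loader = CreateDataLoader(opt)  # prints 'CustomDatasetDataLoader'
dataset = data_loader.load_data()    # the underlying torch DataLoader
print('#images = %d' % len(data_loader))
for i, batch in enumerate(dataset):
    pass  # each batch is a dict of tensors/paths produced by the dataset class
```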
--------------------------------------------------------------------------------
/data/preprocess_celeba.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | cd data
4 | unzip img_align_celeba.zip
5 | mkdir CelebA_trainval
6 | mv img_align_celeba CelebA_trainval
7 | mkdir -p CelebA_test/img_align_celeba
8 |
9 | trainval_dir="CelebA_trainval/img_align_celeba"
10 | test_dir="CelebA_test/img_align_celeba"
11 | while IFS=' ' read -r fpath status; do
12 |     if [ "$status" -eq 2 ]; then    # list_eval_partition.txt: 0 = train, 1 = val, 2 = test
13 |         mv "$trainval_dir/$fpath" "$test_dir/$fpath"
14 |     fi
15 | done < list_eval_partition.txt
16 |

--------------------------------------------------------------------------------
/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 |
4 |
5 | def CreateDataset(opt):
6 |     dataset = None
7 |     from data.aligned_dataset import AlignedDataset
8 |     dataset = AlignedDataset()
9 |
10 |     print("dataset [%s] was created" % (dataset.name()))
11 |     dataset.initialize(opt)
12 |     return dataset
13 |
14 | class CustomDatasetDataLoader(BaseDataLoader):
15 |     def name(self):
16 |         return 'CustomDatasetDataLoader'
17 |
18 |     def initialize(self, opt):
19 |         BaseDataLoader.initialize(self, opt)
20 |         self.dataset = CreateDataset(opt)
21 |         self.dataloader = torch.utils.data.DataLoader(
22 |             self.dataset,
23 |             batch_size=opt.batchSize,
24 |             shuffle=not opt.serial_batches,
25 |             num_workers=int(opt.nThreads))
26 |
27 |     def load_data(self):
28 |         return self.dataloader
29 |
30 |     def __len__(self):
31 |         return min(len(self.dataset), self.opt.max_dataset_size)
32 |

--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 |     """This class includes test options.
6 |
7 |     It also includes shared options defined in BaseOptions.
8 |     """
9 |
10 |     def initialize(self, parser):
11 |         parser = BaseOptions.initialize(self, parser)  # define shared options
12 |         parser.add_argument('--output_path', type=str, default='./results/', help='saves results here.')
13 |         parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
14 |         parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
15 |         # Dropout and BatchNorm have different behaviors during training and test.
16 |         parser.add_argument('--num_test', type=int, default=10000, help='how many test images to run')
17 |         parser.add_argument('--how_many_channel', type=int, default=5, help='number of transmissions per image')
18 |
19 |         # rewrite default values
20 |         parser.set_defaults(model='test')
21 |         # To avoid cropping, the load_size should be the same as crop_size
22 |         parser.set_defaults(load_size=parser.get_default('crop_size'))
23 |         self.isTrain = False
24 |         return parser
25 |
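For reference, a hedged example of a test invocation built from these options. The flag values are illustrative and mirror the training command in the README below; `--feedforward`, `--SNR`, and the architecture flags are assumed to be defined in `base_options.py`, which is not shown in this dump:

    python test.py --gpu_ids '0' --feedforward 'EXPLICIT-RES' --n_downsample 2 --C_channel 12 --S 6 --SNR 20 --dataset_mode 'CIFAR10' --num_test 10000 --how_many_channel 5 --output_path './results/'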
--------------------------------------------------------------------------------
/data/single_dataset.py:
--------------------------------------------------------------------------------
1 | from data.base_dataset import BaseDataset, get_transform
2 | from data.image_folder import make_dataset
3 | from PIL import Image
4 |
5 |
6 | class SingleDataset(BaseDataset):
7 |     """This dataset class can load a set of images specified by the path --dataroot /path/to/data.
8 |
9 |     It can be used for generating CycleGAN results only for one side with the model option '--model test'.
10 |     """
11 |
12 |     def __init__(self, opt):
13 |         """Initialize this dataset class.
14 |
15 |         Parameters:
16 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
17 |         """
18 |         BaseDataset.__init__(self, opt)
19 |         self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
20 |         input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
21 |         self.transform = get_transform(opt, grayscale=(input_nc == 1))
22 |
23 |     def __getitem__(self, index):
24 |         """Return a data point and its metadata information.
25 |
26 |         Parameters:
27 |             index -- a random integer for data indexing
28 |
29 |         Returns a dictionary that contains A and A_paths
30 |             A (tensor) -- an image in one domain
31 |             A_paths (str) -- the path of the image
32 |         """
33 |         A_path = self.A_paths[index]
34 |         A_img = Image.open(A_path).convert('RGB')
35 |         A = self.transform(A_img)
36 |         return {'A': A, 'A_paths': A_path}
37 |
38 |     def __len__(self):
39 |         """Return the total number of images in the dataset."""
40 |         return len(self.A_paths)
41 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep-JSCC-for-images-with-OFDM
2 |
3 | ![Structure](Structure.png)
4 |
5 | ## Environments
6 |
7 | python=3.8.0
8 |
9 | numpy=1.24.4
10 |
11 | pytorch=1.13.1+cu117
12 |
13 | cuda=12.2
14 |
15 | dominate=2.8.0
16 |
17 | scipy=1.10.1
18 |
19 | visdom=0.2.4
20 |
21 | ## Test the OFDM model
22 |
23 | The test script for the OFDM system implementation is in `models/channel.py`.
24 |
25 | ## Datasets
26 |
27 | This repository contains code for CIFAR-10 and CelebA. For CelebA, you will need to download the dataset under the `data` folder. You can also use other datasets, but you will need to customize the dataloader; one example is `data/CelebA_dataset.py`.
28 |
29 | ## Train the model
30 |
31 | All available options are under the `options` folder. Change `--feedforward` to select the model: set it to 'IMPLICIT' for the IMPLICIT model in the paper, and to 'EXPLICIT-RES' for the EXPLICIT model in the paper.
32 |
33 | One example for training:
34 |
35 |     python train.py --gpu_ids '0' --feedforward 'EXPLICIT-RES' --N_pilot 2 --n_downsample 2 --C_channel 12 --S 6
36 |     --SNR 20 --dataset_mode 'CIFAR10' --n_epochs 200 --n_epochs_decay 200 --lr 1e-3
37 |
38 | Suppose the input image has a size of C x W x H. To keep the size consistent, you need to satisfy the requirement: (W x H / 2^(2 x n_downsample)) x C_channel = S x 128.
39 |
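As a quick sanity check of this constraint with the CIFAR-10 command above: W = H = 32, n_downsample = 2, and C_channel = 12, so (32 x 32 / 2^4) x 12 = 64 x 12 = 768 = 6 x 128, which is consistent with S = 6.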
40 | ## Reference
41 |
42 | > Mingyu Yang, Chenghong Bian, Hun-Seok Kim, "Deep Joint Source Channel Coding for Wireless Image Transmission with OFDM", accepted by ICC 2021
43 |
44 |     @article{yang2021deep,
45 |       title={Deep Joint Source Channel Coding for Wireless Image Transmission with OFDM},
46 |       author={Yang, Mingyu and Bian, Chenghong and Kim, Hun-Seok},
47 |       journal={arXiv preprint arXiv:2101.03909},
48 |       year={2021}
49 |     }
50 |

--------------------------------------------------------------------------------
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both the current directory and its subdirectories.
5 | """
6 |
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 | import os.path
12 |
13 | IMG_EXTENSIONS = [
14 |     '.jpg', '.JPG', '.jpeg', '.JPEG',
15 |     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 |     '.tif', '.TIF', '.tiff', '.TIFF',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir, max_dataset_size=float("inf")):
25 |     images = []
26 |     assert os.path.isdir(dir), '%s is not a valid directory' % dir
27 |
28 |     for root, _, fnames in sorted(os.walk(dir)):
29 |         for fname in fnames:
30 |             if is_image_file(fname):
31 |                 path = os.path.join(root, fname)
32 |                 images.append(path)
33 |     return images[:min(max_dataset_size, len(images))]
34 |
35 |
36 | def default_loader(path):
37 |     return Image.open(path).convert('RGB')
38 |
39 |
40 | class ImageFolder(data.Dataset):
41 |
42 |     def __init__(self, root, transform=None, return_paths=False,
43 |                  loader=default_loader):
44 |         imgs = make_dataset(root)
45 |         if len(imgs) == 0:
46 |             raise(RuntimeError("Found 0 images in: " + root + "\n"
47 |                                "Supported image extensions are: " +
48 |                                ",".join(IMG_EXTENSIONS)))
49 |
50 |         self.root = root
51 |         self.imgs = imgs
52 |         self.transform = transform
53 |         self.return_paths = return_paths
54 |         self.loader = loader
55 |
56 |     def __getitem__(self, index):
57 |         path = self.imgs[index]
58 |         img = self.loader(path)
59 |         if self.transform is not None:
60 |             img = self.transform(img)
61 |         if self.return_paths:
62 |             return img, path
63 |         else:
64 |             return img
65 |
66 |     def __len__(self):
67 |         return len(self.imgs)
68 |

--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 |
5 | class ImagePool():
6 |     """This class implements an image buffer that stores previously generated images.
7 |
8 |     This buffer enables us to update discriminators using a history of generated images
9 |     rather than the ones produced by the latest generators.
10 |     """
11 |
12 |     def __init__(self, pool_size):
13 |         """Initialize the ImagePool class
14 |
15 |         Parameters:
16 |             pool_size (int) -- the size of the image buffer; if pool_size=0, no buffer will be created
17 |         """
18 |         self.pool_size = pool_size
19 |         if self.pool_size > 0:  # create an empty pool
20 |             self.num_imgs = 0
21 |             self.images = []
22 |
23 |     def query(self, images):
24 |         """Return an image from the pool.
25 |
26 |         Parameters:
27 |             images: the latest generated images from the generator
28 |
29 |         Returns images from the buffer.
30 |
31 |         With probability 0.5, the buffer returns the input images.
32 |         With probability 0.5, the buffer returns images previously stored in the buffer
33 |         and inserts the current images into the buffer.
34 |         """
35 |         if self.pool_size == 0:  # if the buffer size is 0, do nothing
36 |             return images
37 |         return_images = []
38 |         for image in images:
39 |             image = torch.unsqueeze(image.data, 0)
40 |             if self.num_imgs < self.pool_size:  # if the buffer is not full, keep inserting current images into the buffer
41 |                 self.num_imgs = self.num_imgs + 1
42 |                 self.images.append(image)
43 |                 return_images.append(image)
44 |             else:
45 |                 p = random.uniform(0, 1)
46 |                 if p > 0.5:  # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
47 |                     random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
48 |                     tmp = self.images[random_id].clone()
49 |                     self.images[random_id] = image
50 |                     return_images.append(tmp)
51 |                 else:  # by another 50% chance, the buffer will return the current image
52 |                     return_images.append(image)
53 |         return_images = torch.cat(return_images, 0)  # collect all the images and return
54 |         return return_images
55 |
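A minimal sketch of how this buffer is typically used when training a discriminator (shapes and names are illustrative):

```python
import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)
fake = torch.randn(8, 3, 64, 64)        # a freshly generated batch
fake_for_D = pool.query(fake)           # a mix of current and historical fakes
assert fake_for_D.shape == fake.shape   # query() preserves the batch shape
```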
--------------------------------------------------------------------------------
/data/CelebA_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_params, get_transform
3 | from data.image_folder import make_dataset
4 | from PIL import Image
5 | from torchvision import transforms
6 |
7 | class CelebADataset(BaseDataset):
8 |     """A dataset class for the CelebA face dataset.
9 |
10 |     It loads single CelebA images from opt.dataroot, center-crops them to 140 x 140,
11 |     and resizes them to 64 x 64.
12 |     """
13 |
14 |     def __init__(self, opt):
15 |         """Initialize this dataset class.
16 |
17 |         Parameters:
18 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
19 |         """
20 |         BaseDataset.__init__(self, opt)
21 |         self.dir_AB = opt.dataroot  # get the image directory
22 |         self.paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
23 |         assert(self.opt.load_size >= self.opt.crop_size)  # crop_size should be smaller than the size of loaded image
24 |         self.input_nc = self.opt.input_nc
25 |         self.output_nc = self.opt.output_nc
26 |
27 |     def __getitem__(self, index):
28 |         """Return a data point and its metadata information.
29 |
30 |         Parameters:
31 |             index -- a random integer for data indexing
32 |
33 |         Returns a dictionary that contains data and path
34 |             data (tensor) -- the transformed image
35 |             path (str) -- the path of the image
36 |         """
37 |         # read an image given a random integer index
38 |         path = self.paths[index]
39 |         img = Image.open(path).convert('RGB')
40 |
41 |         # center-crop, resize, and normalize the image to [-1, 1]
42 |         transform = transforms.Compose([
43 |             transforms.CenterCrop((140, 140)),
44 |             transforms.Resize((64, 64)),
45 |             transforms.ToTensor(),
46 |             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
47 |
48 |         img = transform(img)
49 |
50 |         return {'data': img, 'path': path}
51 |
52 |     def __len__(self):
53 |         """Return the total number of images in the dataset."""
54 |         return len(self.paths)
55 |

--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 | import math
7 |
8 |
9 | def clipping(clipping_ratio, x):
10 |     # Clip the amplitude of x to clipping_ratio times its RMS value.
11 |     # The bias is computed under no_grad, so the forward pass returns the
12 |     # clipped signal while gradients flow straight through to x.
13 |     amp = x.abs()
14 |     sigma = torch.sqrt(torch.mean(amp**2, -1, True))
15 |     ratio = sigma * clipping_ratio / amp
16 |     scale = torch.min(ratio, torch.ones_like(ratio))
17 |
18 |     with torch.no_grad():
19 |         bias = x * scale - x
20 |
21 |     return x + bias
22 |
23 |
24 | def add_cp(x, cp_len):
25 |     # Prepend the cyclic prefix (the last cp_len samples) to each OFDM symbol
26 |     return torch.cat((x[..., -cp_len:], x), dim=-1)
27 |
28 |
29 | def rm_cp(x, cp_len):
30 |     # Remove the cyclic prefix
31 |     return x[..., cp_len:]
32 |
33 |
34 | def batch_conv1d(x, weights):
35 |     '''
36 |     Enable batch-wise convolution using group convolution operations
37 |     x: BxN
38 |     weights: BxL
39 |     '''
40 |     assert x.shape[0] == weights.shape[0]
41 |
42 |     b, n = x.shape
43 |     l = weights.shape[1]
44 |
45 |     x = x.unsqueeze(0)  # 1xBxN
46 |     weights = weights.unsqueeze(1)  # Bx1xL
47 |     x = F.pad(x, (l - 1, 0), "constant", 0)  # 1xBx(N+L-1)
48 |     out = F.conv1d(x, weight=weights, bias=None, stride=1, dilation=1, groups=b, padding=0)  # 1xBxN
49 |
50 |     return out
51 |
52 | def PAPR(x):
53 |     # Peak-to-average power ratio in dB, computed over the last dimension
54 |     power = torch.mean((x.abs())**2, -1)
55 |     pwr_max, _ = torch.max((x.abs())**2, -1)
56 |     return 10 * torch.log10(pwr_max / power)
57 |
58 | def normalize(x, power):
59 |     # Scale x so that its mean power over the last dimension equals `power`
60 |     pwr = torch.mean(x.abs()**2, -1, True)
61 |     return np.sqrt(power) * x / torch.sqrt(pwr)
62 |
63 |
64 | def ZF_equalization(H_est, Y):
65 |     # H_est: NxPx1xM
66 |     # Y: NxPxSxM
67 |     return Y / H_est
68 |
69 | def MMSE_equalization(H_est, Y, noise_pwr):
70 |     # H_est: NxPx1xM
71 |     # Y: NxPxSxM
72 |     no = Y * H_est.conj()
73 |     de = H_est.abs()**2 + noise_pwr.unsqueeze(-1)
74 |     return no / de
75 |
76 | def LS_channel_est(pilot_tx, pilot_rx):
77 |     # pilot_tx: NxPx1xM
78 |     # pilot_rx: NxPxS'xM
79 |     return torch.mean(pilot_rx, 2, True) / pilot_tx
80 |
81 | def LMMSE_channel_est(pilot_tx, pilot_rx, noise_pwr):
82 |     # pilot_tx: NxPx1xM
83 |     # pilot_rx: NxPxS'xM
84 |     return torch.mean(pilot_rx, 2, True) * pilot_tx.conj() / (1 + (noise_pwr.unsqueeze(-1) / pilot_rx.shape[2]))
85 |
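A short sketch exercising the OFDM helpers above; all shapes and values are illustrative, not values used by the models:

```python
import torch
from models.utils import add_cp, rm_cp, PAPR, normalize, MMSE_equalization

torch.manual_seed(0)
B, N, CP = 4, 64, 16                        # batch, symbol length, cyclic prefix

x = torch.randn(B, N, dtype=torch.cfloat)   # time-domain OFDM symbols
x = normalize(x, power=1)                   # unit average power per symbol

x_cp = add_cp(x, CP)                        # prepend last CP samples -> (B, N+CP)
assert torch.equal(rm_cp(x_cp, CP), x)      # CP removal inverts CP insertion

print(PAPR(x))                              # per-symbol PAPR in dB, shape (B,)

# Equalization over illustrative NxPxSxM shapes:
H = torch.randn(2, 1, 1, 64, dtype=torch.cfloat)  # channel estimate, NxPx1xM
Y = torch.randn(2, 1, 6, 64, dtype=torch.cfloat)  # received symbols, NxPxSxM
X_hat = MMSE_equalization(H, Y, torch.full((2, 1, 1), 0.1))
print(X_hat.shape)                          # torch.Size([2, 1, 6, 64])
```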
--------------------------------------------------------------------------------
/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_params, get_transform
3 | from data.image_folder import make_dataset
4 | from PIL import Image
5 |
6 |
7 | class AlignedDataset(BaseDataset):
8 |     """A dataset class for paired image dataset.
9 |
10 |     It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
11 |     During test time, you need to prepare a directory '/path/to/data/test'.
12 |     """
13 |
14 |     def __init__(self, opt):
15 |         """Initialize this dataset class.
16 |
17 |         Parameters:
18 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
19 |         """
20 |         BaseDataset.__init__(self, opt)
21 |         self.dir_AB = os.path.join(opt.dataroot, opt.phase)  # get the image directory
22 |         self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
23 |         assert(self.opt.load_size >= self.opt.crop_size)  # crop_size should be smaller than the size of loaded image
24 |         self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
25 |         self.output_nc = self.opt.input_nc if self.opt.direction == 'BtoA' else self.opt.output_nc
26 |
27 |     def __getitem__(self, index):
28 |         """Return a data point and its metadata information.
29 |
30 |         Parameters:
31 |             index -- a random integer for data indexing
32 |
33 |         Returns a dictionary that contains A, B, A_paths and B_paths
34 |             A (tensor) -- an image in the input domain
35 |             B (tensor) -- its corresponding image in the target domain
36 |             A_paths (str) -- image paths
37 |             B_paths (str) -- image paths (same as A_paths)
38 |         """
39 |         # read an image given a random integer index
40 |         AB_path = self.AB_paths[index]
41 |         AB = Image.open(AB_path).convert('RGB')
42 |         # split the AB image into A and B
43 |         w, h = AB.size
44 |         w2 = int(w / 2)
45 |         A = AB.crop((0, 0, w2, h))
46 |         B = AB.crop((w2, 0, w, h))
47 |
48 |         # apply the same transform to both A and B
49 |         transform_params = get_params(self.opt, A.size)
50 |         A_transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1))
51 |         B_transform = get_transform(self.opt, transform_params, grayscale=(self.output_nc == 1))
52 |
53 |         A = A_transform(A)
54 |         B = B_transform(B)
55 |
56 |         return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path}
57 |
58 |     def __len__(self):
59 |         """Return the total number of images in the dataset."""
60 |         return len(self.AB_paths)
61 |
--------------------------------------------------------------------------------
/data/colorization_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_transform
3 | from data.image_folder import make_dataset
4 | from skimage import color  # requires skimage
5 | from PIL import Image
6 | import numpy as np
7 | import torchvision.transforms as transforms
8 |
9 |
10 | class ColorizationDataset(BaseDataset):
11 |     """This dataset class can load a set of natural images in RGB, and convert the RGB format into (L, ab) pairs in Lab color space.
12 |
13 |     This dataset is required by the pix2pix-based colorization model ('--model colorization').
14 |     """
15 |     @staticmethod
16 |     def modify_commandline_options(parser, is_train):
17 |         """Add new dataset-specific options, and rewrite default values for existing options.
18 |
19 |         Parameters:
20 |             parser -- original option parser
21 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
22 |
23 |         Returns:
24 |             the modified parser.
25 |
26 |         By default, the number of channels for the input image is 1 (L) and
27 |         the number of channels for the output image is 2 (ab). The direction is from A to B.
28 |         """
29 |         parser.set_defaults(input_nc=1, output_nc=2, direction='AtoB')
30 |         return parser
31 |
32 |     def __init__(self, opt):
33 |         """Initialize this dataset class.
34 |
35 |         Parameters:
36 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
37 |         """
38 |         BaseDataset.__init__(self, opt)
39 |         self.dir = os.path.join(opt.dataroot, opt.phase)
40 |         self.AB_paths = sorted(make_dataset(self.dir, opt.max_dataset_size))
41 |         assert(opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == 'AtoB')
42 |         self.transform = get_transform(self.opt, convert=False)
43 |
44 |     def __getitem__(self, index):
45 |         """Return a data point and its metadata information.
46 |
47 |         Parameters:
48 |             index -- a random integer for data indexing
49 |
50 |         Returns a dictionary that contains A, B, A_paths and B_paths
51 |             A (tensor) -- the L channel of an image
52 |             B (tensor) -- the ab channels of the same image
53 |             A_paths (str) -- image paths
54 |             B_paths (str) -- image paths (same as A_paths)
55 |         """
56 |         path = self.AB_paths[index]
57 |         im = Image.open(path).convert('RGB')
58 |         im = self.transform(im)
59 |         im = np.array(im)
60 |         lab = color.rgb2lab(im).astype(np.float32)
61 |         lab_t = transforms.ToTensor()(lab)
62 |         A = lab_t[[0], ...] / 50.0 - 1.0
63 |         B = lab_t[[1, 2], ...] / 110.0
64 |         return {'A': A, 'B': B, 'A_paths': path, 'B_paths': path}
65 |
66 |     def __len__(self):
67 |         """Return the total number of images in the dataset."""
68 |         return len(self.AB_paths)
69 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 |     -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 |     -- <set_input>: unpack data from dataset and apply preprocessing.
7 |     -- <forward>: produce intermediate results.
8 |     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 |     -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 |     -- self.model_names (str list): define networks used in our training.
14 |     -- self.visual_names (str list): specify the images that you want to display and save.
15 |     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 |     """Import the module "models/[model_name]_model.py".
27 |
28 |     In the file, the class called DatasetNameModel() will
29 |     be instantiated. It has to be a subclass of BaseModel,
30 |     and it is case-insensitive.
31 |     """
32 |     model_filename = "models." + model_name + "_model"
33 |     modellib = importlib.import_module(model_filename)
34 |     model = None
35 |     target_model_name = model_name.replace('_', '') + 'model'
36 |     for name, cls in modellib.__dict__.items():
37 |         if name.lower() == target_model_name.lower() \
38 |            and issubclass(cls, BaseModel):
39 |             model = cls
40 |
41 |     if model is None:
42 |         print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 |         exit(0)
44 |
45 |     return model
46 |
47 |
48 | def get_option_setter(model_name):
49 |     """Return the static method <modify_commandline_options> of the model class."""
50 |     model_class = find_model_using_name(model_name)
51 |     return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 |     """Create a model given the option.
56 |
57 |     This function wraps the model class found by <find_model_using_name>.
58 |     This is the main interface between this package and 'train.py'/'test.py'
59 |
60 |     Example:
61 |         >>> from models import create_model
62 |         >>> model = create_model(opt)
63 |     """
64 |     model = find_model_using_name(opt.model)
65 |     instance = model(opt)
66 |     print("model [%s] was created" % type(instance).__name__)
67 |     return instance
68 |
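Following the recipe in the package docstring above, a skeleton of such a 'dummy' model might look like this (a sketch only; `models/dummy_model.py` does not exist in this repository):

```python
from models.base_model import BaseModel


class DummyModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)   # first call the base initializer
        self.loss_names = []            # training losses to plot and save
        self.model_names = []           # networks used in training
        self.visual_names = []          # images to display and save
        self.optimizers = []            # one optimizer per network

    def set_input(self, input):
        pass                            # unpack data from the dataloader

    def forward(self):
        pass                            # produce intermediate results

    def optimize_parameters(self):
        pass                            # compute losses/gradients, update weights
```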
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 |
5 |
6 | class HTML:
7 |     """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 |     It consists of functions such as <add_header> (add a text header to the HTML file),
10 |     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 |     It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 |     """
13 |
14 |     def __init__(self, web_dir, title, refresh=0):
15 |         """Initialize the HTML class
16 |
17 |         Parameters:
18 |             web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str) -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 |
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 |             with self.doc.head:
33 |                 meta(http_equiv="refresh", content=str(refresh))
34 |
35 |     def get_image_dir(self):
36 |         """Return the directory that stores images"""
37 |         return self.img_dir
38 |
39 |     def add_header(self, text):
40 |         """Insert a header into the HTML file
41 |
42 |         Parameters:
43 |             text (str) -- the header text
44 |         """
45 |         with self.doc:
46 |             h3(text)
47 |
48 |     def add_images(self, ims, txts, links, width=400):
49 |         """Add images to the HTML file
50 |
51 |         Parameters:
52 |             ims (str list) -- a list of image paths
53 |             txts (str list) -- a list of image names shown on the website
54 |             links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page
55 |         """
56 |         self.t = table(border=1, style="table-layout: fixed;")  # insert a table
57 |         self.doc.add(self.t)
58 |         with self.t:
59 |             with tr():
60 |                 for im, txt, link in zip(ims, txts, links):
61 |                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 |                         with p():
63 |                             with a(href=os.path.join('images', link)):
64 |                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
65 |                             br()
66 |                             p(txt)
67 |
68 |     def save(self):
69 |         """Save the current content to the HTML file"""
70 |         html_file = '%s/index.html' % self.web_dir
71 |         f = open(html_file, 'wt')
72 |         f.write(self.doc.render())
73 |         f.close()
74 |
75 |
76 | if __name__ == '__main__':  # we show an example usage here
77 |     html = HTML('web/', 'test_html')
78 |     html.add_header('hello world')
79 |
80 |     ims, txts, links = [], [], []
81 |     for n in range(4):
82 |         ims.append('image_%d.png' % n)
83 |         txts.append('text_%d' % n)
84 |         links.append('image_%d.png' % n)
85 |     html.add_images(ims, txts, links)
86 |     html.save()
87 |
--------------------------------------------------------------------------------
/data/unaligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_transform
3 | from data.image_folder import make_dataset
4 | from PIL import Image
5 | import random
6 |
7 |
8 | class UnalignedDataset(BaseDataset):
9 |     """
10 |     This dataset class can load unaligned/unpaired datasets.
11 |
12 |     It requires two directories to host training images from domain A '/path/to/data/trainA'
13 |     and from domain B '/path/to/data/trainB' respectively.
14 |     You can train the model with the dataset flag '--dataroot /path/to/data'.
15 |     Similarly, you need to prepare two directories:
16 |     '/path/to/data/testA' and '/path/to/data/testB' during test time.
17 |     """
18 |
19 |     def __init__(self, opt):
20 |         """Initialize this dataset class.
21 |
22 |         Parameters:
23 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
24 |         """
25 |         BaseDataset.__init__(self, opt)
26 |         self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')  # create a path '/path/to/data/trainA'
27 |         self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')  # create a path '/path/to/data/trainB'
28 |
29 |         self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))  # load images from '/path/to/data/trainA'
30 |         self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))  # load images from '/path/to/data/trainB'
31 |         self.A_size = len(self.A_paths)  # get the size of dataset A
32 |         self.B_size = len(self.B_paths)  # get the size of dataset B
33 |         btoA = self.opt.direction == 'BtoA'
34 |         input_nc = self.opt.output_nc if btoA else self.opt.input_nc  # get the number of channels of input image
35 |         output_nc = self.opt.input_nc if btoA else self.opt.output_nc  # get the number of channels of output image
36 |         self.transform_A = get_transform(self.opt, grayscale=(input_nc == 1))
37 |         self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1))
38 |
39 |     def __getitem__(self, index):
40 |         """Return a data point and its metadata information.
41 |
42 |         Parameters:
43 |             index (int) -- a random integer for data indexing
44 |
45 |         Returns a dictionary that contains A, B, A_paths and B_paths
46 |             A (tensor) -- an image in the input domain
47 |             B (tensor) -- its corresponding image in the target domain
48 |             A_paths (str) -- image paths
49 |             B_paths (str) -- image paths
50 |         """
51 |         A_path = self.A_paths[index % self.A_size]  # make sure index is within the range
52 |         if self.opt.serial_batches:  # use fixed pairs when serial_batches is enabled
53 |             index_B = index % self.B_size
54 |         else:  # randomize the index for domain B to avoid fixed pairs.
55 |             index_B = random.randint(0, self.B_size - 1)
56 |         B_path = self.B_paths[index_B]
57 |         A_img = Image.open(A_path).convert('RGB')
58 |         B_img = Image.open(B_path).convert('RGB')
59 |         # apply image transformation
60 |         A = self.transform_A(A_img)
61 |         B = self.transform_B(B_img)
62 |
63 |         return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
64 |
65 |     def __len__(self):
66 |         """Return the total number of images in the dataset.
67 |
68 |         As we have two datasets with potentially different numbers of images,
69 |         we take the maximum of the two.
70 |         """
71 |         return max(self.A_size, self.B_size)
72 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions"""
2 | from __future__ import print_function
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import os
7 |
8 |
9 | def tensor2im(input_image, imtype=np.uint8):
10 |     """Converts a Tensor array into a numpy image array.
11 |
12 |     Parameters:
13 |         input_image (tensor) -- the input image tensor array
14 |         imtype (type) -- the desired type of the converted numpy array
15 |     """
16 |     if not isinstance(input_image, np.ndarray):
17 |         if isinstance(input_image, torch.Tensor):  # get the data from a variable
18 |             image_tensor = input_image.data
19 |         else:
20 |             return input_image
21 |         image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
22 |         if image_numpy.shape[0] == 1:  # grayscale to RGB
23 |             image_numpy = np.tile(image_numpy, (3, 1, 1))
24 |         image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
25 |     else:  # if it is a numpy array, do nothing
26 |         image_numpy = input_image
27 |     return image_numpy.astype(imtype)
28 |
29 |
30 | def diagnose_network(net, name='network'):
31 |     """Calculate and print the mean of the average absolute gradients
32 |
33 |     Parameters:
34 |         net (torch network) -- Torch network
35 |         name (str) -- the name of the network
36 |     """
37 |     mean = 0.0
38 |     count = 0
39 |     for param in net.parameters():
40 |         if param.grad is not None:
41 |             mean += torch.mean(torch.abs(param.grad.data))
42 |             count += 1
43 |     if count > 0:
44 |         mean = mean / count
45 |     print(name)
46 |     print(mean)
47 |
48 |
49 | def save_image(image_numpy, image_path, aspect_ratio=1.0):
50 |     """Save a numpy image to the disk
51 |
52 |     Parameters:
53 |         image_numpy (numpy array) -- input numpy array
54 |         image_path (str) -- the path of the image
55 |         aspect_ratio (float) -- the aspect ratio of the saved image
56 |     """
57 |     image_pil = Image.fromarray(image_numpy)
58 |     h, w, _ = image_numpy.shape
59 |
60 |     if aspect_ratio > 1.0:
61 |         image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
62 |     if aspect_ratio < 1.0:
63 |         image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
64 |     image_pil.save(image_path)
65 |
66 |
67 | def print_numpy(x, val=True, shp=False):
68 |     """Print the mean, min, max, median, std, and size of a numpy array
69 |
70 |     Parameters:
71 |         val (bool) -- if print the values of the numpy array
72 |         shp (bool) -- if print the shape of the numpy array
73 |     """
74 |     x = x.astype(np.float64)
75 |     if shp:
76 |         print('shape,', x.shape)
77 |     if val:
78 |         x = x.flatten()
79 |         print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
80 |             np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
81 |
82 |
83 | def mkdirs(paths):
84 |     """create empty directories if they don't exist
85 |
86 |     Parameters:
87 |         paths (str list) -- a list of directory paths
88 |     """
89 |     if isinstance(paths, list) and not isinstance(paths, str):
90 |         for path in paths:
91 |             mkdir(path)
92 |     else:
93 |         mkdir(paths)
94 |
95 |
96 | def mkdir(path):
97 |     """create a single empty directory if it doesn't exist
98 |
99 |     Parameters:
100 |         path (str) -- a single directory path
101 |     """
102 |     if not os.path.exists(path):
103 |         os.makedirs(path)
104 |
--------------------------------------------------------------------------------
/data/template_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 |
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <DatasetMode>Dataset
8 | You need to implement the following functions:
9 |     -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 |     -- <__init__>: Initialize this dataset class.
11 |     -- <__getitem__>: Return a data point and its metadata information.
12 |     -- <__len__>: Return the number of images.
13 | """
14 | from data.base_dataset import BaseDataset, get_transform
15 | # from data.image_folder import make_dataset
16 | # from PIL import Image
17 |
18 |
19 | class TemplateDataset(BaseDataset):
20 |     """A template dataset class for you to implement custom datasets."""
21 |     @staticmethod
22 |     def modify_commandline_options(parser, is_train):
23 |         """Add new dataset-specific options, and rewrite default values for existing options.
24 |
25 |         Parameters:
26 |             parser -- original option parser
27 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28 |
29 |         Returns:
30 |             the modified parser.
31 |         """
32 |         parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33 |         parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
34 |         return parser
35 |
36 |     def __init__(self, opt):
37 |         """Initialize this dataset class.
38 |
39 |         Parameters:
40 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41 |
42 |         A few things can be done here.
43 |         - save the options (has been done in BaseDataset)
44 |         - get image paths and meta information of the dataset.
45 |         - define the image transformation.
46 |         """
47 |         # save the option and dataset root
48 |         BaseDataset.__init__(self, opt)
49 |         # get the image paths of your dataset;
50 |         self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51 |         # define the default transform function. You can use <base_dataset.get_transform>; you can also define your custom transform function
52 |         self.transform = get_transform(opt)
53 |
54 |     def __getitem__(self, index):
55 |         """Return a data point and its metadata information.
56 |
57 |         Parameters:
58 |             index -- a random integer for data indexing
59 |
60 |         Returns:
61 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
62 |
63 |         Step 1: get a random image path: e.g., path = self.image_paths[index]
64 |         Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65 |         Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
66 |         Step 4: return a data point as a dictionary.
67 |         """
68 |         path = 'temp'  # needs to be a string
69 |         data_A = None  # needs to be a tensor
70 |         data_B = None  # needs to be a tensor
71 |         return {'data_A': data_A, 'data_B': data_B, 'path': path}
72 |
73 |     def __len__(self):
74 |         """Return the total number of images."""
75 |         return len(self.image_paths)
76 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 |     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 |     -- <__len__>: return the size of the dataset.
7 |     -- <__getitem__>: get a data point from the data loader.
8 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 |
17 |
18 | def find_dataset_using_name(dataset_name):
19 |     """Import the module "data/[dataset_name]_dataset.py".
20 |
21 |     In the file, the class called DatasetNameDataset() will
22 |     be instantiated. It has to be a subclass of BaseDataset,
23 |     and it is case-insensitive.
24 |     """
25 |     dataset_filename = "data." + dataset_name + "_dataset"
26 |     datasetlib = importlib.import_module(dataset_filename)
27 |
28 |     dataset = None
29 |     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30 |     for name, cls in datasetlib.__dict__.items():
31 |         if name.lower() == target_dataset_name.lower() \
32 |            and issubclass(cls, BaseDataset):
33 |             dataset = cls
34 |
35 |     if dataset is None:
36 |         raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37 |
38 |     return dataset
39 |
40 |
41 | def get_option_setter(dataset_name):
42 |     """Return the static method <modify_commandline_options> of the dataset class."""
43 |     dataset_class = find_dataset_using_name(dataset_name)
44 |     return dataset_class.modify_commandline_options
45 |
46 |
47 | def create_dataset(opt):
48 |     """Create a dataset given the option.
49 |
50 |     This function wraps the class CustomDatasetDataLoader.
51 |     This is the main interface between this package and 'train.py'/'test.py'
52 |
53 |     Example:
54 |         >>> from data import create_dataset
55 |         >>> dataset = create_dataset(opt)
56 |     """
57 |     data_loader = CustomDatasetDataLoader(opt)
58 |     dataset = data_loader.load_data()
59 |     return dataset
60 |
61 |
62 | class CustomDatasetDataLoader():
63 |     """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 |
65 |     def __init__(self, opt):
66 |         """Initialize this class
67 |
68 |         Step 1: create a dataset instance given the name [dataset_mode]
69 |         Step 2: create a multi-threaded data loader.
70 |         """
71 |         self.opt = opt
72 |         dataset_class = find_dataset_using_name(opt.dataset_mode)
73 |         self.dataset = dataset_class(opt)
74 |         print("dataset [%s] was created" % type(self.dataset).__name__)
75 |         self.dataloader = torch.utils.data.DataLoader(
76 |             self.dataset,
77 |             batch_size=opt.batch_size,
78 |             shuffle=not opt.serial_batches,
79 |             num_workers=int(opt.num_threads))
80 |
81 |     def load_data(self):
82 |         return self
83 |
84 |     def __len__(self):
85 |         """Return the number of data in the dataset"""
86 |         return min(len(self.dataset), self.opt.max_dataset_size)
87 |
88 |     def __iter__(self):
89 |         """Return a batch of data"""
90 |         for i, data in enumerate(self.dataloader):
91 |             if i * self.opt.batch_size >= self.opt.max_dataset_size:
92 |                 break
93 |             yield data
94 |
6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | # visdom and HTML visualization parameters 13 | parser.add_argument('--display_freq', type=int, default=50, help='frequency of showing training results on screen') 14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 15 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 17 | parser.add_argument('--display_env', type=str, default='C6SNR5_NotNG', help='visdom display environment name (default is "main")') 18 | parser.add_argument('--display_port', type=int, default=8998, help='visdom port of the web display') 19 | parser.add_argument('--update_html_freq', type=int, default=100, help='frequency of saving training results to html') 20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 22 | # network saving and loading parameters 23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 24 | parser.add_argument('--save_epoch_freq', type=int, default=40, help='frequency of saving checkpoints at the end of epochs') 25 | parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration') 26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') 28 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 29 | # training parameters 30 | parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs with the initial learning rate') 31 | parser.add_argument('--n_epochs_decay', type=int, default=200, help='number of epochs to linearly decay learning rate to zero') 32 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 33 | parser.add_argument('--lr', type=float, default=1e-3, help='initial learning rate for adam') 34 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 35 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy.
[linear | step | plateau | cosine]') 36 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 37 | parser.add_argument('--lambda_L2', type=float, default=128, help='weight for the L2 loss') 38 | parser.add_argument('--lambda_feat', type=float, default=1, help='weight for the feature matching loss') 39 | parser.add_argument('--lambda_papr', type=float, default=0, help='weight for the PAPR loss') 40 | parser.add_argument('--lambda_ce', type=float, default=10, help='weight for the channel estimation (CE) loss') 41 | parser.add_argument('--lambda_eq', type=float, default=10, help='weight for the equalization (EQ) loss') 42 | parser.add_argument('--is_Feat', action='store_true', default=1, help='whether to use feature matching loss for generative training') 43 | 44 | self.isTrain = True 45 | return parser 46 | -------------------------------------------------------------------------------- /util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """A Python script for downloading CycleGAN or pix2pix datasets. 13 | 14 | Parameters: 15 | technique (str) -- One of: 'cyclegan' or 'pix2pix'. 16 | verbose (bool) -- If True, print additional information. 17 | 18 | Examples: 19 | >>> from util.get_data import GetData 20 | >>> gd = GetData(technique='cyclegan') 21 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 22 | 23 | Alternatively, you can use bash scripts: 'scripts/download_pix2pix_model.sh' 24 | and 'scripts/download_cyclegan_model.sh'. 25 | """ 26 | 27 | def __init__(self, technique='cyclegan', verbose=True): 28 | url_dict = { 29 | 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', 30 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 31 | } 32 | self.url = url_dict.get(technique.lower()) 33 | self._verbose = verbose 34 | 35 | def _print(self, text): 36 | if self._verbose: 37 | print(text) 38 | 39 | @staticmethod 40 | def _get_options(r): 41 | soup = BeautifulSoup(r.text, 'lxml') 42 | options = [h.text for h in soup.find_all('a', href=True) 43 | if h.text.endswith(('.zip', 'tar.gz'))] 44 | return options 45 | 46 | def _present_options(self): 47 | r = requests.get(self.url) 48 | options = self._get_options(r) 49 | print('Options:\n') 50 | for i, o in enumerate(options): 51 | print("{0}: {1}".format(i, o)) 52 | choice = input("\nPlease enter the number of the " 53 | "dataset above you wish to download:") 54 | return options[int(choice)] 55 | 56 | def _download_data(self, dataset_url, save_path): 57 | if not isdir(save_path): 58 | os.makedirs(save_path) 59 | 60 | base = basename(dataset_url) 61 | temp_save_path = join(save_path, base) 62 | 63 | with open(temp_save_path, "wb") as f: 64 | r = requests.get(dataset_url) 65 | f.write(r.content) 66 | 67 | if base.endswith('.tar.gz'): 68 | obj = tarfile.open(temp_save_path) 69 | elif base.endswith('.zip'): 70 | obj = ZipFile(temp_save_path, 'r') 71 | else: 72 | raise ValueError("Unknown File Type: {0}.".format(base)) 73 | 74 | self._print("Unpacking Data...") 75 | obj.extractall(save_path) 76 | obj.close() 77 | os.remove(temp_save_path) 78 | 79 | def get(self, save_path, dataset=None): 80 | """ 81 | 82 | Download a dataset.
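# --------------------------------------------------------------------------
# Note on the '--lr_policy linear' default defined in train_options.py above:
# per the n_epochs / n_epochs_decay help strings, the learning rate is held
# constant for n_epochs and then decayed linearly to zero over n_epochs_decay
# further epochs. A sketch of such a multiplier rule (the actual scheduler
# lives in models/networks.py, which is not shown in this excerpt):
#
#     def linear_decay_rule(epoch, n_epochs=200, n_epochs_decay=200):
#         return 1.0 - max(0, epoch - n_epochs) / float(n_epochs_decay + 1)
#
# e.g., epoch 200 -> 1.0, epoch 300 -> ~0.50, epoch 400 -> ~0.005.
# --------------------------------------------------------------------------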
83 | 84 | Parameters: 85 | save_path (str) -- A directory to save the data to. 86 | dataset (str) -- (optional). A specific dataset to download. 87 | Note: this must include the file extension. 88 | If None, options will be presented for you 89 | to choose from. 90 | 91 | Returns: 92 | save_path_full (str) -- the absolute path to the downloaded data. 93 | 94 | """ 95 | if dataset is None: 96 | selected_dataset = self._present_options() 97 | else: 98 | selected_dataset = dataset 99 | 100 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 101 | 102 | if isdir(save_path_full): 103 | warn("\n'{0}' already exists. Voiding Download.".format( 104 | save_path_full)) 105 | else: 106 | self._print('Downloading Data...') 107 | url = "{0}/{1}".format(self.url, selected_dataset) 108 | self._download_data(url, save_path=save_path) 109 | 110 | return abspath(save_path_full) 111 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | import time 4 | from models import create_model 5 | from options.test_options import TestOptions 6 | from data import create_dataset 7 | import util.util as util 8 | from util.visualizer import Visualizer 9 | import os 10 | import numpy as np 11 | import torch 12 | import torchvision 13 | import torchvision.transforms as transforms 14 | import scipy.io as sio 15 | import models.channel as chan 16 | import shutil 17 | from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM 18 | import math 19 | 20 | # Extract the options 21 | opt = TestOptions().parse() 22 | 23 | opt.batch_size = 1 # batch size 24 | 25 | if opt.dataset_mode == 'CIFAR10': 26 | opt.dataroot='./data' 27 | opt.size = 32 28 | transform = transforms.Compose( 29 | [transforms.ToTensor(), 30 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 31 | 32 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, 33 | download=True, transform=transform) 34 | dataset = torch.utils.data.DataLoader(testset, batch_size=opt.batch_size, 35 | shuffle=False, num_workers=2) 36 | dataset_size = len(dataset) 37 | print('#test images = %d' % dataset_size) 38 | 39 | elif opt.dataset_mode == 'CelebA': 40 | opt.dataroot = './data/celeba/CelebA_test' 41 | opt.load_size = 80 42 | opt.crop_size = 64 43 | opt.size = 64 44 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 45 | dataset_size = len(dataset) 46 | print('#test images = %d' % dataset_size) 47 | else: 48 | raise Exception('Not implemented yet') 49 | 50 | ######################################## OFDM setting ########################################### 51 | model = create_model(opt) # create a model given opt.model and other options 52 | model.setup(opt) # regular setup: load and print networks; create schedulers 53 | model.eval() 54 | output_path = os.path.join('./results', opt.name) # output image directory; adjust as needed 55 | if not os.path.exists(output_path): 56 | os.makedirs(output_path) 57 | else: 58 | shutil.rmtree(output_path) 59 | os.makedirs(output_path) 60 | 61 | PSNR_list = [] 62 | SSIM_list = [] 63 | for i, data in enumerate(dataset): 64 | if i >= opt.num_test: # only apply our model to opt.num_test images.
65 | break 66 | 67 | start_time = time.time() 68 | 69 | if opt.dataset_mode == 'CIFAR10': 70 | input = data[0] 71 | elif opt.dataset_mode == 'CelebA': 72 | input = data['data'] 73 | 74 | model.set_input(input.repeat(opt.how_many_channel,1,1,1)) 75 | model.forward() 76 | fake = model.fake 77 | 78 | # Get the uint8 generated images 79 | img_gen_numpy = fake.detach().cpu().float().numpy() 80 | img_gen_numpy = (np.transpose(img_gen_numpy, (0, 2, 3, 1)) + 1) / 2.0 * 255.0 81 | img_gen_int8 = img_gen_numpy.astype(np.uint8) 82 | 83 | origin_numpy = input.detach().cpu().float().numpy() 84 | origin_numpy = (np.transpose(origin_numpy, (0, 2, 3, 1)) + 1) / 2.0 * 255.0 85 | origin_int8 = origin_numpy.astype(np.uint8) 86 | 87 | 88 | diff = np.mean((np.float64(img_gen_int8)-np.float64(origin_int8))**2, (1,2,3)) 89 | 90 | PSNR = 10*np.log10((255**2)/diff) # e.g., MSE=100 -> 10*log10(65025/100) ~ 28.1 dB 91 | PSNR_list.append(np.mean(PSNR)) 92 | 93 | img_gen_tensor = torch.from_numpy(np.transpose(img_gen_int8, (0, 3, 1, 2))).float() 94 | origin_tensor = torch.from_numpy(np.transpose(origin_int8, (0, 3, 1, 2))).float() 95 | 96 | ssim_val = ssim(img_gen_tensor, origin_tensor.repeat(opt.how_many_channel,1,1,1), data_range=255, size_average=False) # return (N,) 97 | SSIM_list.append(torch.mean(ssim_val).item()) 98 | 99 | # Save the first sampled image 100 | save_path = output_path + '/' + str(i) + '_PSNR_' + str(PSNR[0]) +'_SSIM_' + str(ssim_val[0].item())+'.png' 101 | util.save_image(util.tensor2im(fake[0].unsqueeze(0)), save_path, aspect_ratio=1) 102 | 103 | save_path = output_path + '/' + str(i) + '.png' 104 | util.save_image(util.tensor2im(input), save_path, aspect_ratio=1) 105 | if i%100 == 0: 106 | print(i) 107 | 108 | 109 | print('PSNR: '+str(np.mean(PSNR_list))) 110 | print('SSIM: '+str(np.mean(SSIM_list))) 111 | # print('MSE CE: '+str(np.mean(H_err_list))) # H_err_list is never populated in this script 112 | # print('MSE EQ: '+str(np.mean(x_err_list))) # x_err_list is never populated in this script 113 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
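# --------------------------------------------------------------------------
# A typical invocation of this script, using only flags defined under
# options/ (values are illustrative, not prescribed by the repository):
#
#     python train.py --dataset_mode CIFAR10 --name JSCC_OFDM --gpu_ids 0 \
#                     --batch_size 128 --lr 1e-3 --SNR 20
# --------------------------------------------------------------------------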
3 | import time 4 | from models import create_model 5 | from data import create_dataset 6 | from options.train_options import TrainOptions 7 | from data.data_loader import CreateDataLoader 8 | import util.util as util 9 | from util.visualizer import Visualizer 10 | import os 11 | import numpy as np 12 | import torch 13 | import torchvision 14 | import torchvision.transforms as transforms 15 | import scipy.io as sio 16 | 17 | 18 | # Extract the options 19 | opt = TrainOptions().parse() 20 | 21 | 22 | if opt.dataset_mode == 'CIFAR10': 23 | opt.dataroot='./data' 24 | opt.size = 32 25 | transform = transforms.Compose( 26 | [transforms.RandomHorizontalFlip(p=0.5), 27 | transforms.RandomCrop(opt.size, padding=5, pad_if_needed=True, fill=0, padding_mode='reflect'), 28 | transforms.ToTensor(), 29 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 30 | 31 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 32 | download=True, transform=transform) 33 | dataset = torch.utils.data.DataLoader(trainset, batch_size=opt.batch_size, 34 | shuffle=True, num_workers=2, drop_last=True) 35 | dataset_size = len(dataset) 36 | print('#training images = %d' % dataset_size) 37 | 38 | elif opt.dataset_mode == 'CelebA': 39 | opt.dataroot = './data/celeba/CelebA_train' 40 | opt.load_size = 80 41 | opt.crop_size = 64 42 | opt.size = 64 43 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 44 | dataset_size = len(dataset) 45 | print('#training images = %d' % dataset_size) 46 | else: 47 | raise Exception('Not implemented yet') 48 | 49 | 50 | model = create_model(opt) # create a model given opt.model and other options 51 | model.setup(opt) # regular setup: load and print networks; create schedulers 52 | visualizer = Visualizer(opt) # create a visualizer that displays/saves images and plots 53 | total_iters = 0 # the total number of training iterations 54 | 55 | ################ Train with the Discriminator 56 | for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq> 57 | epoch_start_time = time.time() # timer for entire epoch 58 | iter_data_time = time.time() # timer for data loading per iteration 59 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch 60 | visualizer.reset() # reset the visualizer: make sure it saves the results to HTML at least once every epoch 61 | 62 | 63 | for i, data in enumerate(dataset): # inner loop within one epoch 64 | iter_start_time = time.time() # timer for computation per iteration 65 | if total_iters % opt.print_freq == 0: 66 | t_data = iter_start_time - iter_data_time 67 | 68 | total_iters += opt.batch_size 69 | epoch_iter += opt.batch_size 70 | 71 | if opt.dataset_mode == 'CIFAR10': 72 | input = data[0] 73 | elif opt.dataset_mode == 'CelebA': 74 | input = data['data'] 75 | 76 | model.set_input(input) # unpack data from dataset and apply preprocessing 77 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights 78 | 79 | if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file 80 | save_result = total_iters % opt.update_html_freq == 0 81 | model.compute_visuals() 82 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 83 | 84 | if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk 85 | losses = model.get_current_losses() 86 | t_comp = (time.time() -
iter_start_time) / opt.batch_size 87 | visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) 88 | if opt.display_id > 0: 89 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses) 90 | 91 | if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations 92 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) 93 | save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' 94 | model.save_networks(save_suffix) 95 | iter_data_time = time.time() 96 | 97 | if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs 98 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) 99 | model.save_networks('latest') 100 | model.save_networks(epoch) 101 | 102 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)) 103 | model.update_learning_rate() 104 | 105 | 106 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. 21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | self.root = opt.dataroot 31 | 32 | @staticmethod 33 | def modify_commandline_options(parser, is_train): 34 | """Add new dataset-specific options, and rewrite default values for existing options. 35 | 36 | Parameters: 37 | parser -- original option parser 38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 39 | 40 | Returns: 41 | the modified parser. 42 | """ 43 | return parser 44 | 45 | @abstractmethod 46 | def __len__(self): 47 | """Return the total number of images in the dataset.""" 48 | return 0 49 | 50 | @abstractmethod 51 | def __getitem__(self, index): 52 | """Return a data point and its metadata information. 53 | 54 | Parameters: 55 | index -- a random integer for data indexing 56 | 57 | Returns: 58 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
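# --------------------------------------------------------------------------
# Usage sketch for the transform helpers below, assuming an opt configured
# like the CelebA branch of train.py above (load_size=80, crop_size=64) plus
# preprocess='resize_and_crop' and no_flip=False, which get_transform reads:
#
#     params = get_params(opt, (178, 218))    # crop position + flip for one image
#     transform = get_transform(opt, params)  # deterministic transform from params
#     # tensor = transform(pil_image)         # -> normalized 3x64x64 tensor
# --------------------------------------------------------------------------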
59 | """ 60 | pass 61 | 62 | 63 | def get_params(opt, size): 64 | w, h = size 65 | new_h = h 66 | new_w = w 67 | if opt.preprocess == 'resize_and_crop': 68 | new_h = new_w = opt.load_size 69 | elif opt.preprocess == 'scale_width_and_crop': 70 | new_w = opt.load_size 71 | new_h = opt.load_size * h // w 72 | 73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 75 | 76 | flip = random.random() > 0.5 77 | 78 | return {'crop_pos': (x, y), 'flip': flip} 79 | 80 | 81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 82 | transform_list = [] 83 | if grayscale: 84 | transform_list.append(transforms.Grayscale(1)) 85 | if 'resize' in opt.preprocess: 86 | osize = [opt.load_size, opt.load_size] 87 | transform_list.append(transforms.Resize(osize, method)) 88 | elif 'scale_width' in opt.preprocess: 89 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) 90 | 91 | if 'crop' in opt.preprocess: 92 | if params is None: 93 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 94 | else: 95 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 96 | 97 | if opt.preprocess == 'none': 98 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 99 | 100 | if not opt.no_flip: 101 | if params is None: 102 | transform_list.append(transforms.RandomHorizontalFlip()) 103 | elif params['flip']: 104 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 105 | 106 | if convert: 107 | transform_list += [transforms.ToTensor()] 108 | if grayscale: 109 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 110 | else: 111 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 112 | return transforms.Compose(transform_list) 113 | 114 | 115 | def __make_power_2(img, base, method=Image.BICUBIC): 116 | ow, oh = img.size 117 | h = int(round(oh / base) * base) 118 | w = int(round(ow / base) * base) 119 | if h == oh and w == ow: 120 | return img 121 | 122 | __print_size_warning(ow, oh, w, h) 123 | return img.resize((w, h), method) 124 | 125 | 126 | def __scale_width(img, target_size, crop_size, method=Image.BICUBIC): 127 | ow, oh = img.size 128 | if ow == target_size and oh >= crop_size: 129 | return img 130 | w = target_size 131 | h = int(max(target_size * oh / ow, crop_size)) 132 | return img.resize((w, h), method) 133 | 134 | 135 | def __crop(img, pos, size): 136 | ow, oh = img.size 137 | x1, y1 = pos 138 | tw = th = size 139 | if (ow > tw or oh > th): 140 | return img.crop((x1, y1, x1 + tw, y1 + th)) 141 | return img 142 | 143 | 144 | def __flip(img, flip): 145 | if flip: 146 | return img.transpose(Image.FLIP_LEFT_RIGHT) 147 | return img 148 | 149 | 150 | def __print_size_warning(ow, oh, w, h): 151 | """Print warning information about image size(only print once)""" 152 | if not hasattr(__print_size_warning, 'has_printed'): 153 | print("The image size needs to be a multiple of 4. " 154 | "The loaded image size was (%d, %d), so it was adjusted to " 155 | "(%d, %d). 
This adjustment will be done to all images " 156 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 157 | __print_size_warning.has_printed = True 158 | -------------------------------------------------------------------------------- /models/JSCC_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | import numpy as np 4 | import torch 5 | import os 6 | from torch.autograd import Variable 7 | from util.image_pool import ImagePool 8 | from .base_model import BaseModel 9 | from . import channel 10 | from . import networks 11 | 12 | 13 | class PLAINModel(BaseModel): 14 | 15 | def __init__(self, opt): 16 | BaseModel.__init__(self, opt) 17 | 18 | 19 | # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> 20 | self.loss_names = ['G_GAN', 'G_L2', 'G_Feat', 'D_real', 'D_fake'] 21 | # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> 22 | self.visual_names = ['real_A', 'fake', 'real_B'] 23 | 24 | # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> 25 | if self.opt.gan_mode != 'none': 26 | self.model_names = ['E', 'G', 'D'] 27 | else: # during test time, only load G 28 | self.model_names = ['E', 'G'] 29 | 30 | # define networks (both generator and discriminator) 31 | self.netE = networks.define_E(input_nc=opt.input_nc, ngf=opt.ngf, max_ngf=opt.max_ngf, 32 | n_downsample=opt.n_downsample, C_channel=opt.C_channel, 33 | n_blocks=opt.n_blocks, norm=opt.norm_EG, init_type=opt.init_type, 34 | init_gain=opt.init_gain, gpu_ids=self.gpu_ids, first_kernel=opt.first_kernel) 35 | 36 | self.netG = networks.define_G(output_nc=opt.output_nc, ngf=opt.ngf, max_ngf=opt.max_ngf, 37 | n_downsample=opt.n_downsample, C_channel=opt.C_channel, 38 | n_blocks=opt.n_blocks, norm=opt.norm_EG, init_type=opt.init_type, 39 | init_gain=opt.init_gain, gpu_ids=self.gpu_ids, first_kernel=opt.first_kernel, activation=opt.activation) 40 | 41 | if self.opt.gan_mode != 'none': # define a discriminator; 42 | 43 | self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.n_layers_D, 44 | opt.norm_D, opt.init_type, opt.init_gain, self.gpu_ids) 45 | 46 | 47 | print('---------- Networks initialized -------------') 48 | 49 | # set loss functions and optimizers 50 | if self.isTrain: 51 | self.criterionGAN = networks.GANLoss(opt.gan_mode, opt.label_smooth, 1-opt.label_smooth).to(self.device) 52 | self.criterionFeat = torch.nn.L1Loss() 53 | self.criterionL2 = torch.nn.MSELoss() 54 | 55 | params = list(self.netE.parameters()) + list(self.netG.parameters()) 56 | self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) 57 | self.optimizers.append(self.optimizer_G) 58 | 59 | if self.opt.gan_mode != 'none': 60 | params = list(self.netD.parameters()) 61 | self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) 62 | self.optimizers.append(self.optimizer_D) 63 | 64 | 65 | self.normalize = networks.Normalize() 66 | self.opt = opt 67 | 68 | self.channel = channel.plain_channel(opt, self.device, pwr=1) 69 | 70 | 71 | def name(self): 72 | return 'PLAIN_Model' 73 | 74 | def set_input(self, image): 75 | self.real_A = image.clone().to(self.device) 76 | self.real_B = image.clone().to(self.device) 77 | 78 | def set_encode(self, image): 79 | self.real_A = image.clone().to(self.device) 80 | 
self.real_B = image.clone().to(self.device) 81 | 82 | def set_decode(self, latent): 83 | self.latent = latent.to(self.device) 84 | 85 | def set_img_path(self, path): 86 | self.image_paths = path 87 | 88 | def forward(self): 89 | 90 | # 1. Generate latent vector 91 | self.latent = self.netE(self.real_A) 92 | 93 | # 2. Pass the channel 94 | latent = self.channel(self.latent, self.opt.SNR) 95 | 96 | # 3. Reconstruction 97 | self.fake = self.netG(latent) 98 | 99 | def backward_D(self): 100 | """Calculate GAN loss for the discriminator""" 101 | # Fake; stop backprop to the generator by detaching fake_B 102 | 103 | _, pred_fake = self.netD(self.fake.detach()) 104 | self.loss_D_fake = self.criterionGAN(pred_fake, False) 105 | 106 | real_data = self.real_B 107 | _, pred_real = self.netD(real_data) 108 | self.loss_D_real = self.criterionGAN(pred_real, True) 109 | 110 | if self.opt.gan_mode in ['lsgan', 'vanilla']: 111 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 112 | self.loss_D.backward() 113 | elif self.opt.gan_mode == 'wgangp': 114 | penalty, grad = networks.cal_gradient_penalty(self.netD, real_data, self.fake.detach(), self.device, type='mixed', constant=1.0, lambda_gp=10.0) 115 | self.loss_D = self.loss_D_fake + self.loss_D_real + penalty 116 | self.loss_D.backward(retain_graph=True) 117 | 118 | def backward_G(self): 119 | """Calculate GAN and L1 loss for the generator""" 120 | # First, G(A) should fake the discriminator 121 | 122 | if self.opt.gan_mode != 'none': 123 | feat_fake, pred_fake = self.netD(self.fake) 124 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) 125 | 126 | if self.opt.is_Feat: 127 | feat_real, pred_real = self.netD(self.real_B) 128 | self.loss_G_Feat = 0 129 | for j in range(len(feat_real)): 130 | self.loss_G_Feat += self.criterionFeat(feat_real[j].detach(), feat_fake[j]) * self.opt.lambda_feat 131 | else: 132 | self.loss_G_Feat = 0 133 | 134 | else: 135 | self.loss_G_GAN = 0 136 | self.loss_G_Feat = 0 137 | 138 | self.loss_G_L2 = self.criterionL2(self.fake, self.real_B) * self.opt.lambda_L2 139 | # combine loss and calculate gradients 140 | self.loss_G = self.loss_G_GAN + self.loss_G_Feat + self.loss_G_L2 141 | self.loss_G.backward() 142 | 143 | def optimize_parameters(self): 144 | self.forward() # compute fake images: G(A) 145 | # update D 146 | if self.opt.gan_mode != 'none': 147 | self.set_requires_grad(self.netD, True) # enable backprop for D 148 | self.optimizer_D.zero_grad() # set D's gradients to zero 149 | self.backward_D() # calculate gradients for D 150 | self.optimizer_D.step() # update D's weights 151 | self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G 152 | else: 153 | self.loss_D_fake = 0 154 | self.loss_D_real = 0 155 | # update G 156 | 157 | self.optimizer_G.zero_grad() # set G's gradients to zero 158 | self.backward_G() # calculate gradients for G 159 | self.optimizer_G.step() # update G's weights 160 | 161 | 162 | def get_encoded(self): 163 | return self.netE(self.real_A) 164 | 165 | def get_decoded(self, latent): 166 | if self.opt.channel == 'awgn': # AWGN channel 167 | self.latent = self.normalize(latent, 1) 168 | elif self.opt.channel == 'bsc': # BSC channel 169 | self.latent = torch.sigmoid(latent) 170 | 171 | # 2. Pass the channel 172 | latent_input = self.channel(self.latent) 173 | 174 | # 3.
Reconstruction 175 | return self.netG(latent_input) 176 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | import os 4 | import torch 5 | import sys 6 | from . import networks 7 | from collections import OrderedDict 8 | 9 | class BaseModel(torch.nn.Module): 10 | def name(self): 11 | return 'BaseModel' 12 | 13 | def __init__(self, opt): 14 | super().__init__() 15 | """Initialize the BaseModel class. 16 | 17 | Parameters: 18 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | 20 | When creating your custom class, you need to implement your own initialization. 21 | In this function, you should first call <BaseModel.__init__(self, opt)> 22 | Then, you need to define four lists: 23 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 24 | -- self.model_names (str list): define networks used in our training. 25 | -- self.visual_names (str list): specify the images that you want to display and save. 26 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 27 | """ 28 | self.opt = opt 29 | self.gpu_ids = opt.gpu_ids 30 | self.isTrain = opt.isTrain 31 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 32 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir 33 | self.loss_names = [] 34 | self.model_names = [] 35 | self.visual_names = [] 36 | self.optimizers = [] 37 | self.image_paths = [] 38 | self.metric = 0 # used for learning rate policy 'plateau' 39 | 40 | def set_input(self, input): 41 | pass 42 | 43 | def forward(self): 44 | pass 45 | 46 | # used in test time, no backprop 47 | def test(self): 48 | pass 49 | 50 | def compute_visuals(self): 51 | """Calculate additional output images for visdom and HTML visualization""" 52 | pass 53 | 54 | def get_image_paths(self): 55 | """ Return image paths that are used to load current data""" 56 | return self.image_paths 57 | 58 | def optimize_parameters(self): 59 | pass 60 | 61 | def get_current_visuals(self): 62 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 63 | visual_ret = OrderedDict() 64 | for name in self.visual_names: 65 | if isinstance(name, str): 66 | visual_ret[name] = getattr(self, name) 67 | return visual_ret 68 | 69 | def get_current_losses(self): 70 | """Return training losses / errors. train.py will print out these errors on console, and save them to a file""" 71 | errors_ret = OrderedDict() 72 | for name in self.loss_names: 73 | if isinstance(name, str): 74 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) 
works for both scalar tensor and float number 75 | return errors_ret 76 | 77 | def setup(self, opt): 78 | """Load and print networks; create schedulers 79 | 80 | Parameters: 81 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 82 | """ 83 | if self.isTrain: 84 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 85 | if not self.isTrain or opt.continue_train: 86 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 87 | self.load_networks(load_suffix) 88 | self.print_networks(opt.verbose) 89 | 90 | def eval(self): 91 | """Make models eval mode during test time""" 92 | for name in self.model_names: 93 | if isinstance(name, str): 94 | net = getattr(self, 'net' + name) 95 | net.eval() 96 | 97 | def update_learning_rate(self): 98 | """Update learning rates for all the networks; called at the end of every epoch""" 99 | for scheduler in self.schedulers: 100 | if self.opt.lr_policy == 'plateau': 101 | scheduler.step(self.metric) 102 | else: 103 | scheduler.step() 104 | 105 | lr = self.optimizers[0].param_groups[0]['lr'] 106 | print('learning rate = %.7f' % lr) 107 | 108 | # helper saving function that can be used by subclasses 109 | def save_networks(self, epoch): 110 | """Save all the networks to the disk. 111 | 112 | Parameters: 113 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 114 | """ 115 | for name in self.model_names: 116 | if isinstance(name, str): 117 | save_filename = '%s_net_%s.pth' % (epoch, name) 118 | save_path = os.path.join(self.save_dir, save_filename) 119 | net = getattr(self, 'net' + name) 120 | 121 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 122 | torch.save(net.module.cpu().state_dict(), save_path) 123 | net.cuda(self.gpu_ids[0]) 124 | else: 125 | torch.save(net.cpu().state_dict(), save_path) 126 | 127 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 128 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 129 | key = keys[i] 130 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 131 | if module.__class__.__name__.startswith('InstanceNorm') and \ 132 | (key == 'running_mean' or key == 'running_var'): 133 | if getattr(module, key) is None: 134 | state_dict.pop('.'.join(keys)) 135 | if module.__class__.__name__.startswith('InstanceNorm') and \ 136 | (key == 'num_batches_tracked'): 137 | state_dict.pop('.'.join(keys)) 138 | else: 139 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 140 | 141 | def load_networks(self, epoch): 142 | """Load all the networks from the disk. 
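# --------------------------------------------------------------------------
# Checkpoint naming sketch (illustrative values): with epoch=40 and
# model_names=['E', 'G'], save_networks above writes '40_net_E.pth' and
# '40_net_G.pth' under [checkpoints_dir]/[name], and load_networks below
# resolves the same pattern, e.g.
#
#     load_filename = '%s_net_%s.pth' % ('latest', 'G')   # -> 'latest_net_G.pth'
# --------------------------------------------------------------------------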
143 | 144 | Parameters: 145 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 146 | """ 147 | for name in self.model_names: 148 | if isinstance(name, str): 149 | load_filename = '%s_net_%s.pth' % (epoch, name) 150 | load_path = os.path.join(self.save_dir, load_filename) 151 | net = getattr(self, 'net' + name) 152 | if isinstance(net, torch.nn.DataParallel): 153 | net = net.module 154 | print('loading the model from %s' % load_path) 155 | # if you are using PyTorch newer than 0.4 (e.g., built from 156 | # GitHub source), you can remove str() on self.device 157 | state_dict = torch.load(load_path, map_location=str(self.device)) 158 | if hasattr(state_dict, '_metadata'): 159 | del state_dict._metadata 160 | 161 | # patch InstanceNorm checkpoints prior to 0.4 162 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 163 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 164 | net.load_state_dict(state_dict) 165 | 166 | def print_networks(self, verbose): 167 | """Print the total number of parameters in the network and (if verbose) network architecture 168 | 169 | Parameters: 170 | verbose (bool) -- if verbose: print the network architecture 171 | """ 172 | print('---------- Networks initialized -------------') 173 | for name in self.model_names: 174 | if isinstance(name, str): 175 | net = getattr(self, 'net' + name) 176 | num_params = 0 177 | for param in net.parameters(): 178 | num_params += param.numel() 179 | if verbose: 180 | print(net) 181 | print('[Network %s] Total number of parameters : %.6f M' % (name, num_params / 1e6)) 182 | print('-----------------------------------------------') 183 | 184 | def set_requires_grad(self, nets, requires_grad=False): 185 | """Set requires_grad=False for all the networks to avoid unnecessary computations 186 | Parameters: 187 | nets (network list) -- a list of networks 188 | requires_grad (bool) -- whether the networks require gradients or not 189 | """ 190 | if not isinstance(nets, list): 191 | nets = [nets] 192 | for net in nets: 193 | if net is not None: 194 | for param in net.parameters(): 195 | param.requires_grad = requires_grad -------------------------------------------------------------------------------- /models/channel.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import math 8 | 9 | import sys 10 | sys.path.append('./') 11 | 12 | from models.utils import clipping, add_cp, rm_cp, batch_conv1d, PAPR, normalize 13 | 14 | 15 | 16 | # Realization of multipath channel as a nn module 17 | class Channel(nn.Module): 18 | def __init__(self, opt, device): 19 | super(Channel, self).__init__() 20 | self.opt = opt 21 | 22 | # Generate unit power profile 23 | power = torch.exp(-torch.arange(opt.L).float()/opt.decay).view(1,1,opt.L) # 1x1xL 24 | self.power = power/torch.sum(power) # Normalize the path power to sum to 1 25 | self.device = device 26 | 27 | def sample(self, N, P, M, L): 28 | # Sample the channel coefficients 29 | cof = torch.sqrt(self.power/2) * (torch.randn(N, P, L) + 1j*torch.randn(N, P, L)) 30 | cof_zp = torch.cat((cof, torch.zeros((N,P,M-L))), -1) 31 | H_t = torch.fft.fft(cof_zp, dim=-1) 32 | return cof, H_t 33 | 34 | def forward(self, input, cof=None): 35 | # Input size: NxPx(Sx(M+K)) 36 | # Output size: NxPx(Sx(M+K)) 37 | # Also return the true channel 38 | # Generate Channel Matrix 39 | 
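# (Illustration of the power-delay profile built in __init__ above, using the
# repo defaults L=8, decay=4; numbers are rounded and for intuition only.)
#
#     power = torch.exp(-torch.arange(8).float() / 4)  # [1.00, 0.78, 0.61, ...]
#     power = power / power.sum()                      # taps now sum to 1
#
# Each successive tap therefore carries exp(-1/4) ~ 0.78x the power of the
# previous one, i.e. an exponentially decaying multipath profile.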
40 | N, P, SMK = input.shape 41 | 42 | # If the channel is not given, randomly sample one from the channel model 43 | if cof is None: 44 | cof, H_t = self.sample(N, P, self.opt.M, self.opt.L) 45 | else: 46 | cof_zp = torch.cat((cof, torch.zeros((N,P,self.opt.M-self.opt.L,2))), 2) 47 | cof_zp = torch.view_as_complex(cof_zp) 48 | H_t = torch.fft.fft(cof_zp, dim=-1) 49 | 50 | signal_real = input.real.float().view(N*P, -1) # (NxP)x(Sx(M+K)) 51 | signal_imag = input.imag.float().view(N*P, -1) # (NxP)x(Sx(M+K)) 52 | 53 | ind = torch.linspace(self.opt.L-1, 0, self.opt.L).long() 54 | cof_real = cof.real[...,ind].view(N*P, -1).float().to(self.device) # (NxP)xL 55 | cof_imag = cof.imag[...,ind].view(N*P, -1).float().to(self.device) # (NxP)xL 56 | 57 | output_real = batch_conv1d(signal_real, cof_real) - batch_conv1d(signal_imag, cof_imag) # (NxP)x(L+SMK-1) 58 | output_imag = batch_conv1d(signal_real, cof_imag) + batch_conv1d(signal_imag, cof_real) # (NxP)x(L+SMK-1) 59 | 60 | output = torch.cat((output_real.view(N*P,-1,1), output_imag.view(N*P,-1,1)), -1).view(N,P,SMK,2) # NxPxSMKx2 61 | output = torch.view_as_complex(output) 62 | 63 | return output, H_t 64 | 65 | 66 | # Realization of OFDM system as a nn module 67 | class OFDM(nn.Module): 68 | def __init__(self, opt, device, pilot_path): 69 | super(OFDM, self).__init__() 70 | self.opt = opt 71 | 72 | # Setup the channel layer 73 | self.channel = Channel(opt, device) 74 | 75 | # Generate the pilot signal 76 | if not os.path.exists(pilot_path): 77 | bits = torch.randint(2, (opt.M,2)) 78 | torch.save(bits,pilot_path) 79 | pilot = (2*bits-1).float() 80 | else: 81 | bits = torch.load(pilot_path) 82 | pilot = (2*bits-1).float() 83 | 84 | self.pilot = pilot.to(device) 85 | self.pilot = torch.view_as_complex(self.pilot) 86 | self.pilot = normalize(self.pilot, 1) 87 | self.pilot_cp = add_cp(torch.fft.ifft(self.pilot), self.opt.K).repeat(opt.P, opt.N_pilot,1) 88 | 89 | def forward(self, x, SNR, cof=None, batch_size=None): 90 | # Input size: NxPxSxM The information to be transmitted 91 | # cof denotes given channel coefficients 92 | 93 | # If x is None, we only send the pilots through the channel 94 | is_pilot = (x is None) 95 | 96 | if not is_pilot: 97 | 98 | # Change to new complex representations 99 | N = x.shape[0] 100 | 101 | # IFFT: NxPxSxM => NxPxSxM 102 | x = torch.fft.ifft(x, dim=-1) 103 | 104 | # Add Cyclic Prefix: NxPxSxM => NxPxSx(M+K) 105 | x = add_cp(x, self.opt.K) 106 | 107 | # Add pilot: NxPxSx(M+K) => NxPx(S+1)x(M+K) 108 | pilot = self.pilot_cp.repeat(N,1,1,1) 109 | x = torch.cat((pilot, x), 2) 110 | Ns = self.opt.S 111 | else: 112 | N = batch_size 113 | x = self.pilot_cp.repeat(N,1,1,1) 114 | Ns = 0 115 | 116 | # Reshape: NxPx(S+1)x(M+K) => NxPx(S+1)(M+K) 117 | x = x.view(N, self.opt.P, (Ns+self.opt.N_pilot)*(self.opt.M+self.opt.K)) 118 | 119 | # PAPR before clipping 120 | papr = PAPR(x) 121 | 122 | # Clipping (Optional): NxPx(S+1)(M+K) => NxPx(S+1)(M+K) 123 | if self.opt.is_clip: 124 | x = clipping(x, self.opt.CR) # assumes models.utils.clipping(signal, clip_ratio) 125 | 126 | # PAPR after clipping 127 | papr_cp = PAPR(x) 128 | 129 | # Pass through the Channel: NxPx(S+1)(M+K) => NxPx((S+1)(M+K)) 130 | y, H_t = self.channel(x, cof) 131 | 132 | # Calculate the power of received signal 133 | pwr = torch.mean(y.abs()**2, -1, True) 134 | noise_pwr = pwr*10**(-SNR/10) 135 | 136 | # Generate random noise 137 | noise = torch.sqrt(noise_pwr/2) * (torch.randn_like(y) + 1j*torch.randn_like(y)) 138 | y_noisy = y + noise 139 | 140 | # NxPx((S+S')(M+K)) => NxPx(S+S')x(M+K) 141 | output = y_noisy.view(N, self.opt.P,
Ns+self.opt.N_pilot, self.opt.M+self.opt.K) 142 | 143 | y_pilot = output[:,:,:self.opt.N_pilot,:] # NxPxS'x(M+K) 144 | y_sig = output[:,:,self.opt.N_pilot:,:] # NxPxSx(M+K) 145 | 146 | if not is_pilot: 147 | # Remove Cyclic Prefix: 148 | info_pilot = rm_cp(y_pilot, self.opt.K) # NxPxS'xM 149 | info_sig = rm_cp(y_sig, self.opt.K) # NxPxSxM 150 | 151 | # FFT: 152 | info_pilot = torch.fft.fft(info_pilot, dim=-1) 153 | info_sig = torch.fft.fft(info_sig, dim=-1) 154 | 155 | return info_pilot, info_sig, H_t, noise_pwr, papr, papr_cp 156 | else: 157 | info_pilot = rm_cp(y_pilot, self.opt.K) # NxPxS'xM 158 | info_pilot = torch.fft.fft(info_pilot, dim=-1) 159 | 160 | return info_pilot, H_t, noise_pwr 161 | 162 | 163 | # Realization of direct transmission over the multipath channel 164 | class PLAIN(nn.Module): 165 | 166 | def __init__(self, opt, device): 167 | super(PLAIN, self).__init__() 168 | self.opt = opt 169 | 170 | # Setup the channel layer 171 | self.channel = Channel(opt, device) 172 | 173 | def forward(self, x, SNR): 174 | 175 | # Input size: NxPxM 176 | N, P, M = x.shape 177 | y, _ = self.channel(x, None) # Channel.forward returns (signal, H_t) 178 | 179 | # Calculate the power of received signal 180 | pwr = torch.mean(y.abs()**2, -1, True) 181 | noise_pwr = pwr*10**(-SNR/10) 182 | 183 | # Generate random noise 184 | noise = torch.sqrt(noise_pwr/2) * (torch.randn_like(y) + 1j*torch.randn_like(y)) 185 | y_noisy = y + noise # NxPx(M+L-1), complex 186 | rx = y_noisy[:, :, :M] 187 | return rx 188 | 189 | 190 | 191 | 192 | 193 | if __name__ == "__main__": 194 | 195 | import argparse 196 | opt = argparse.Namespace() # simple attribute container for this smoke test 197 | 198 | opt.P = 1 199 | opt.S = 6 200 | opt.M = 64 201 | opt.K = 16 202 | opt.L = 8 203 | opt.decay = 4 204 | opt.N_pilot = 1 205 | opt.SNR = 10 206 | opt.is_clip = False 207 | 208 | ofdm = OFDM(opt, 0, './models/Pilot_bit.pt') 209 | 210 | input_f = torch.randn(128, opt.P, opt.S, opt.M) + 1j*torch.randn(128, opt.P, opt.S, opt.M) 211 | input_f = normalize(input_f, 1) 212 | input_f = input_f.cuda() 213 | 214 | info_pilot, info_sig, H_t, noise_pwr, papr, papr_cp = ofdm(input_f, opt.SNR) 215 | H_t = H_t.cuda() 216 | 217 | err = input_f*H_t.unsqueeze(0) - info_sig 218 | 219 | print(f'OFDM path error :{torch.mean(err.abs()**2).data}') 220 | 221 | from utils import ZF_equalization, MMSE_equalization, LS_channel_est, LMMSE_channel_est 222 | 223 | H_est_LS = LS_channel_est(ofdm.pilot, info_pilot) 224 | err_LS = torch.mean((H_est_LS.squeeze()-H_t.squeeze()).abs()**2) 225 | print(f'LS channel estimation error :{err_LS.data}') 226 | 227 | H_est_LMMSE = LMMSE_channel_est(ofdm.pilot, info_pilot, opt.M*noise_pwr) 228 | err_LMMSE = torch.mean((H_est_LMMSE.squeeze()-H_t.squeeze()).abs()**2) 229 | print(f'LMMSE channel estimation error :{err_LMMSE.data}') 230 | 231 | rx_ZF = ZF_equalization(H_t.unsqueeze(0), info_sig) 232 | err_ZF = torch.mean((rx_ZF.squeeze()-input_f.squeeze()).abs()**2) 233 | print(f'ZF error :{err_ZF.data}') 234 | 235 | rx_MMSE = MMSE_equalization(H_t.unsqueeze(0), info_sig, opt.M*noise_pwr) 236 | err_MMSE = torch.mean((rx_MMSE.squeeze()-input_f.squeeze()).abs()**2) 237 | print(f'MMSE error :{err_MMSE.data}') 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class
BaseOptions(): 10 | """This class defines options used during both training and test time. 11 | 12 | It also implements several helper functions such as parsing, printing, and saving the options. 13 | It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class. 14 | """ 15 | 16 | def __init__(self): 17 | """Reset the class; indicates the class hasn't been initialized""" 18 | self.initialized = False 19 | 20 | def initialize(self, parser): 21 | """Define the common options that are used in both training and test.""" 22 | # basic parameters 23 | #parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 24 | parser.add_argument('--name', type=str, default='JSCC_OFDM', help='name of the experiment. It decides where to store samples and models') 25 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 26 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 27 | # model parameters 28 | parser.add_argument('--model', type=str, default='JSCCOFDM', help='chooses which model to use. [JSCCOFDM | JSCC ]') 29 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') 30 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') 31 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') 32 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') 33 | parser.add_argument('--max_ngf', type=int, default=256, help='maximal # of gen filters in the last conv layer') 34 | parser.add_argument('--gan_mode', type=str, default='none', help='choose from [wgangp | lsgan | vanilla | none]') 35 | parser.add_argument('--label_smooth', type=int, default=1, help='label smoothing factor for lsgan and vanilla gan') 36 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if gan_mode != none') 37 | parser.add_argument('--norm_D', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') 38 | parser.add_argument('--n_downsample', type=int, default=2, help='number of downsampling') 39 | parser.add_argument('--n_blocks', type=int, default=2, help='number of residual blocks in either encoder or generator') 40 | parser.add_argument('--first_kernel', type=int, default=5, help='kernel size of the first conv layer in encoder') 41 | parser.add_argument('--C_channel', type=int, default=12, help='output channels for the latent vector') 42 | parser.add_argument('--activation', type=str, default='sigmoid', help='output activation, choose from [sigmoid | tanh]') 43 | parser.add_argument('--norm_EG', type=str, default='batch', help='instance normalization or batch normalization [instance | batch | none]') 44 | parser.add_argument('--init_type', type=str, default='kaiming', help='network initialization [normal | xavier | kaiming | orthogonal]') 45 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 46 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 47 | # OFDM parameters 48 | parser.add_argument('--P', type=int, default=1, help='number of packets for each transmitted image') 49 | parser.add_argument('--S', type=int,
default=6, help='number of OFDM symbols per packet') 50 | parser.add_argument('--M', type=int, default=64, help='number of subcarriers per symbol') 51 | parser.add_argument('--K', type=int, default=16, help='length of cyclic prefix') 52 | parser.add_argument('--L', type=int, default=8, help='length of multipath channel') 53 | parser.add_argument('--decay', type=int, default=4, help='decay constant for the multipath channel') 54 | parser.add_argument('--is_clip', action='store_true', help='whether to include clipping') 55 | parser.add_argument('--CR', type=float, default=1.0, help='clipping ratio') 56 | parser.add_argument('--N_pilot', type=int, default=2, help='number of pilot symbols for channel estimation') 57 | parser.add_argument('--pilot', type=str, default='QPSK', help='type of pilots, choose from [QPSK | ZadoffChu]') 58 | parser.add_argument('--CE', type=str, default='LMMSE', help='channel estimation method, choose from [LS | LMMSE | TRUE]') 59 | parser.add_argument('--EQ', type=str, default='MMSE', help='equalization method, choose from [ZF | MMSE]') 60 | parser.add_argument('--SNR', type=float, default=20.0, help='SNR') 61 | parser.add_argument('--is_feedback', action='store_true', help='Whether to provide CSI feedback to the encoder') 62 | parser.add_argument('--feedforward', type=str, default='EXPLICIT-RES', help='which decoder design to use, choose from [IMPLICIT | EXPLICIT-CE | EXPLICIT-CE-EQ | EXPLICIT-RES]') 63 | # dataset parameters 64 | parser.add_argument('--dataset_mode', type=str, default='CIFAR10', help='chooses how datasets are loaded. [CIFAR10 | CelebA]') 65 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 66 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 67 | parser.add_argument('--batch_size', type=int, default=128, help='input batch size') 68 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 69 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') 70 | # additional parameters 71 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 72 | parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') 73 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 74 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 75 | self.initialized = True 76 | return parser 77 | 78 | def gather_options(self): 79 | """Initialize our parser with basic options (only once). 80 | Add additional model-specific and dataset-specific options. 81 | These options are defined in the <modify_commandline_options> function 82 | in model and dataset classes.
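# --------------------------------------------------------------------------
# Sizing note for the OFDM defaults above (P=1, S=6, M=64, K=16, N_pilot=2):
# per packet, the time-domain frame carries (S + N_pilot) * (M + K) complex
# samples, matching the reshape in models/channel.py. Illustrative arithmetic:
#
#     samples_per_packet = (6 + 2) * (64 + 16)   # = 640
# --------------------------------------------------------------------------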
83 | """ 84 | if not self.initialized: # check if it has been initialized 85 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 86 | parser = self.initialize(parser) 87 | 88 | # get the basic options 89 | opt, _ = parser.parse_known_args() 90 | 91 | # modify model-related parser options 92 | #model_name = opt.model 93 | #model_option_setter = models.get_option_setter(model_name) 94 | #parser = model_option_setter(parser, self.isTrain) 95 | #opt, _ = parser.parse_known_args() # parse again with new defaults 96 | 97 | # modify dataset-related parser options 98 | #dataset_name = opt.dataset_mode 99 | #dataset_option_setter = data.get_option_setter(dataset_name) 100 | #parser = dataset_option_setter(parser, self.isTrain) 101 | 102 | # save and return the parser 103 | self.parser = parser 104 | return parser.parse_args() 105 | 106 | def print_options(self, opt): 107 | """Print and save options 108 | 109 | It will print both current options and default values(if different). 110 | It will save options into a text file / [checkpoints_dir] / opt.txt 111 | """ 112 | message = '' 113 | message += '----------------- Options ---------------\n' 114 | for k, v in sorted(vars(opt).items()): 115 | comment = '' 116 | default = self.parser.get_default(k) 117 | if v != default: 118 | comment = '\t[default: %s]' % str(default) 119 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 120 | message += '----------------- End -------------------' 121 | print(message) 122 | 123 | # save to the disk 124 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 125 | util.mkdirs(expr_dir) 126 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 127 | with open(file_name, 'wt') as opt_file: 128 | opt_file.write(message) 129 | opt_file.write('\n') 130 | 131 | def parse(self): 132 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 133 | opt = self.gather_options() 134 | opt.isTrain = self.isTrain # train or test 135 | 136 | # process opt.suffix 137 | if opt.suffix: 138 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 139 | opt.name = opt.name + suffix 140 | 141 | self.print_options(opt) 142 | 143 | # set gpu ids 144 | str_ids = opt.gpu_ids.split(',') 145 | opt.gpu_ids = [] 146 | for str_id in str_ids: 147 | id = int(str_id) 148 | if id >= 0: 149 | opt.gpu_ids.append(id) 150 | if len(opt.gpu_ids) > 0: 151 | torch.cuda.set_device(opt.gpu_ids[0]) 152 | 153 | self.opt = opt 154 | return self.opt 155 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | from . import util, html 7 | from subprocess import Popen, PIPE 8 | 9 | 10 | if sys.version_info[0] == 2: 11 | VisdomExceptionBase = Exception 12 | else: 13 | VisdomExceptionBase = ConnectionError 14 | 15 | 16 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 17 | """Save images to the disk. 
18 | 19 | Parameters: 20 | webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details) 21 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 22 | image_path (str) -- the string is used to create image paths 23 | aspect_ratio (float) -- the aspect ratio of saved images 24 | width (int) -- the images will be resized to width x width 25 | 26 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 27 | """ 28 | image_dir = webpage.get_image_dir() 29 | short_path = ntpath.basename(image_path) 30 | name = os.path.splitext(short_path)[0] 31 | 32 | webpage.add_header(name) 33 | ims, txts, links = [], [], [] 34 | 35 | for label, im_data in visuals.items(): 36 | im = util.tensor2im(im_data) 37 | image_name = '%s_%s.png' % (name, label) 38 | save_path = os.path.join(image_dir, image_name) 39 | util.save_image(im, save_path, aspect_ratio=aspect_ratio) 40 | ims.append(image_name) 41 | txts.append(label) 42 | links.append(image_name) 43 | webpage.add_images(ims, txts, links, width=width) 44 | 45 | 46 | class Visualizer(): 47 | """This class includes several functions that can display/save images and print/save logging information. 48 | 49 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. 50 | """ 51 | 52 | def __init__(self, opt): 53 | """Initialize the Visualizer class 54 | 55 | Parameters: 56 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 57 | Step 1: Cache the training/test options 58 | Step 2: connect to a visdom server 59 | Step 3: create an HTML object for saving HTML files 60 | Step 4: create a logging file to store training losses 61 | """ 62 | self.opt = opt # cache the option 63 | self.display_id = opt.display_id 64 | self.use_html = opt.isTrain and not opt.no_html 65 | self.win_size = opt.display_winsize 66 | self.name = opt.name 67 | self.port = opt.display_port 68 | self.saved = False 69 | if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server> 70 | import visdom 71 | self.ncols = opt.display_ncols 72 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) 73 | if not self.vis.check_connection(): 74 | self.create_visdom_connections() 75 | 76 | if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/ 77 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 78 | self.img_dir = os.path.join(self.web_dir, 'images') 79 | print('create web directory %s...' % self.web_dir) 80 | util.mkdirs([self.web_dir, self.img_dir]) 81 | # create a logging file to store training losses 82 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 83 | with open(self.log_name, "a") as log_file: 84 | now = time.strftime("%c") 85 | log_file.write('================ Training Loss (%s) ================\n' % now) 86 | 87 | def reset(self): 88 | """Reset the self.saved status""" 89 | self.saved = False 90 | 91 | def create_visdom_connections(self): 92 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ 93 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port 94 | print('\n\nCould not connect to Visdom server. 
94 | print('\n\nCould not connect to Visdom server.\n Trying to start a server....')
95 | print('Command: %s' % cmd)
96 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
97 | 
98 | def display_current_results(self, visuals, epoch, save_result):
99 | """Display current results on visdom; save current results to an HTML file.
100 | 
101 | Parameters:
102 | visuals (OrderedDict) - - dictionary of images to display or save
103 | epoch (int) - - the current epoch
104 | save_result (bool) - - whether to save the current results to an HTML file
105 | """
106 | if self.display_id > 0:  # show images in the browser using visdom
107 | ncols = self.ncols
108 | if ncols > 0:  # show all the images in one visdom panel
109 | ncols = min(ncols, len(visuals))
110 | h, w = next(iter(visuals.values())).shape[:2]
111 | table_css = """<style>
112 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
113 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
114 | </style>""" % (w, h)  # create a table css
115 | # create a table of images.
116 | title = self.name
117 | label_html = ''
118 | label_html_row = ''
119 | images = []
120 | idx = 0
121 | for label, image in visuals.items():
122 | image_numpy = util.tensor2im(image)
123 | label_html_row += '<td>%s</td>' % label
124 | images.append(image_numpy.transpose([2, 0, 1]))
125 | idx += 1
126 | if idx % ncols == 0:
127 | label_html += '<tr>%s</tr>' % label_html_row
128 | label_html_row = ''
129 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
130 | while idx % ncols != 0:
131 | images.append(white_image)
132 | label_html_row += '<td></td>'
133 | idx += 1
134 | if label_html_row != '':
135 | label_html += '<tr>%s</tr>' % label_html_row
136 | try:
137 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
138 | padding=2, opts=dict(title=title + ' images'))
139 | label_html = '<table>%s</table>' % label_html
140 | self.vis.text(table_css + label_html, win=self.display_id + 2,
141 | opts=dict(title=title + ' labels'))
142 | except VisdomExceptionBase:
143 | self.create_visdom_connections()
144 | 
145 | else:  # show each image in a separate visdom panel;
146 | idx = 1
147 | try:
148 | for label, image in visuals.items():
149 | image_numpy = util.tensor2im(image)
150 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
151 | win=self.display_id + idx)
152 | idx += 1
153 | except VisdomExceptionBase:
154 | self.create_visdom_connections()
155 | 
156 | if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
157 | self.saved = True
158 | # save images to the disk
159 | for label, image in visuals.items():
160 | image_numpy = util.tensor2im(image)
161 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
162 | util.save_image(image_numpy, img_path)
163 | 
164 | # update website
165 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
166 | for n in range(epoch, 0, -1):
167 | webpage.add_header('epoch [%d]' % n)
168 | ims, txts, links = [], [], []
169 | 
170 | for label, image_numpy in visuals.items():
171 | image_numpy = util.tensor2im(image_numpy)  # convert each visual itself, not a stale loop variable
172 | img_path = 'epoch%.3d_%s.png' % (n, label)
173 | ims.append(img_path)
174 | txts.append(label)
175 | links.append(img_path)
176 | webpage.add_images(ims, txts, links, width=self.win_size)
177 | webpage.save()
178 | 
179 | def plot_current_losses(self, epoch, counter_ratio, losses):
180 | """Display the current losses on the visdom display: dictionary of error labels and values
181 | 
182 | Parameters:
183 | epoch (int) -- current epoch
184 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
185 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
186 | """
187 | if not hasattr(self, 'plot_data'):
188 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
189 | self.plot_data['X'].append(epoch + counter_ratio)
190 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
191 | 
192 | try:
193 | self.vis.line(
194 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
195 | Y=np.array(self.plot_data['Y']),
196 | opts={
197 | 'title': self.name + ' loss over time',
198 | 'legend': self.plot_data['legend'],
199 | 'xlabel': 'epoch',
200 | 'ylabel': 'loss'},
201 | win=self.display_id)
202 | except VisdomExceptionBase:
203 | self.create_visdom_connections()
204 | 
205 | # losses: same format as |losses| of plot_current_losses
206 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
207 | """Print current losses on the console; also save the losses to the disk
208 | 
209 | Parameters:
210 | epoch (int) -- current epoch
211 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
212 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
213 | t_comp (float) -- computational time per data point (normalized by batch_size)
214 | t_data (float) -- data loading time per data point (normalized by batch_size)
215 | """
216 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
217 | for k, v in losses.items():
218 | message += '%s: %.3f ' % (k, v)
219 | 
220 | print(message)  # print the message
221 | with open(self.log_name, "a") as log_file:
222 | log_file.write('%s\n' % message)  # save the message
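# A minimal usage sketch (assumptions: `opt` comes from TrainOptions().parse() as in
# train.py, and `model` exposes get_current_visuals()/get_current_losses() as defined
# in base_model.py; the numeric arguments below are illustrative only):
#
#   visualizer = Visualizer(opt)
#   visualizer.reset()                              # once at the start of each epoch
#   visualizer.display_current_results(model.get_current_visuals(), epoch, True)
#   visualizer.plot_current_losses(epoch, 0.5, model.get_current_losses())
#   visualizer.print_current_losses(epoch, iters, model.get_current_losses(), 0.1, 0.01)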
223 | 
-------------------------------------------------------------------------------- /models/JSCCOFDM_model.py: --------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import numpy as np
4 | import torch
5 | import os
6 | from torch.autograd import Variable
7 | from util.image_pool import ImagePool
8 | from .base_model import BaseModel
9 | import models.networks as networks
10 | import models.channel as channel
11 | from models.utils import normalize, ZF_equalization, MMSE_equalization, LS_channel_est, LMMSE_channel_est
12 | 
13 | class JSCCOFDMModel(BaseModel):
14 | 
15 | def __init__(self, opt):
16 | BaseModel.__init__(self, opt)
17 | 
18 | # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
19 | self.loss_names = ['G_L2', 'PAPR', 'CE', 'EQ']
20 | # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
21 | self.visual_names = ['real_A', 'fake']
22 | 
23 | # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
24 | if self.opt.gan_mode != 'none':
25 | self.model_names = ['E', 'G', 'D']
26 | else:  # without GAN training, there is no discriminator to define, save, or load
27 | self.model_names = ['E', 'G']
28 | 
29 | if self.opt.feedforward in ['EXPLICIT-RES']:
30 | self.model_names += ['S1', 'S2']
31 | 
32 | if self.opt.feedforward in ['EXPLICIT-CE-EQ', 'EXPLICIT-RES']:
33 | C_decode = opt.C_channel
34 | elif self.opt.feedforward == 'IMPLICIT':
35 | C_decode = opt.C_channel + self.opt.N_pilot*self.opt.P*2 + self.opt.P*2
36 | elif self.opt.feedforward == 'EXPLICIT-CE':
37 | C_decode = opt.C_channel + self.opt.P*2
38 | 
39 | if self.opt.is_feedback:
40 | add_C = self.opt.P*2
41 | else:
42 | add_C = 0
43 | 
44 | # define networks (both generator and discriminator)
45 | self.netE = networks.define_E(input_nc=opt.input_nc, ngf=opt.ngf, max_ngf=opt.max_ngf,
46 | n_downsample=opt.n_downsample, C_channel=opt.C_channel,
47 | n_blocks=opt.n_blocks, norm=opt.norm_EG, init_type=opt.init_type,
48 | init_gain=opt.init_gain, gpu_ids=self.gpu_ids, first_kernel=opt.first_kernel, first_add_C=add_C)
49 | 
50 | self.netG = networks.define_G(output_nc=opt.output_nc, ngf=opt.ngf, max_ngf=opt.max_ngf,
51 | n_downsample=opt.n_downsample, C_channel=C_decode,
52 | n_blocks=opt.n_blocks, norm=opt.norm_EG, init_type=opt.init_type,
53 | init_gain=opt.init_gain, gpu_ids=self.gpu_ids, first_kernel=opt.first_kernel, activation=opt.activation)
54 | 
55 | #if self.isTrain and self.is_GAN:  # define a discriminator;
56 | if self.opt.gan_mode != 'none':
57 | self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.n_layers_D,
58 | opt.norm_D, opt.init_type, opt.init_gain, self.gpu_ids)
59 | 
60 | if self.opt.feedforward in ['EXPLICIT-RES']:
61 | self.netS1 = networks.define_Subnet(dim=(self.opt.N_pilot*self.opt.P+1)*2, dim_out=self.opt.P*2,
62 | norm=opt.norm_EG, init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids)
63 | 
64 | self.netS2 = networks.define_Subnet(dim=(self.opt.S+1)*self.opt.P*2, dim_out=self.opt.S*self.opt.P*2,
65 | norm=opt.norm_EG, init_type=opt.init_type, init_gain=opt.init_gain, gpu_ids=self.gpu_ids)
66 | 
67 | print('---------- Networks initialized -------------')
68 | 
69 | # set loss functions and optimizers
70 | if self.isTrain:
71 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
72 | self.criterionFeat
= torch.nn.L1Loss() 73 | self.criterionL2 = torch.nn.MSELoss() 74 | 75 | params = list(self.netE.parameters()) + list(self.netG.parameters()) 76 | 77 | if self.opt.feedforward in ['EXPLICIT-RES']: 78 | params+=list(self.netS1.parameters()) + list(self.netS2.parameters()) 79 | 80 | self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) 81 | self.optimizers.append(self.optimizer_G) 82 | 83 | if self.opt.gan_mode != 'none': 84 | params = list(self.netD.parameters()) 85 | self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) 86 | self.optimizers.append(self.optimizer_D) 87 | 88 | self.opt = opt 89 | self.ofdm = channel.OFDM(opt, self.device, './models/Pilot_bit.pt') 90 | 91 | def name(self): 92 | return 'JSCCOFDM_Model' 93 | 94 | def set_input(self, image): 95 | self.real_A = image.clone().to(self.device) 96 | self.real_B = image.clone().to(self.device) 97 | 98 | def forward(self): 99 | 100 | N = self.real_A.shape[0] 101 | 102 | if self.opt.is_feedback: 103 | with torch.no_grad(): 104 | cof, _ = self.ofdm.channel.sample(N, self.opt.P, self.opt.M, self.opt.L) 105 | out_pilot, H_t, noise_pwr = self.ofdm(None, SNR=self.opt.SNR, cof=cof, batch_size=N) 106 | H_est = self.channel_estimation(out_pilot, noise_pwr) 107 | H = torch.view_as_real(H_est).to(self.device) 108 | latent = self.netE(self.real_A, H) 109 | else: 110 | cof = None 111 | latent = self.netE(self.real_A) 112 | 113 | self.tx = latent.contiguous().view(N, self.opt.P, self.opt.S, 2, self.opt.M).contiguous().permute(0,1,2,4,3) 114 | self.tx_c = torch.view_as_complex(self.tx.contiguous()) 115 | self.tx_c = normalize(self.tx_c, 1) 116 | 117 | out_pilot, out_sig, self.H_true, noise_pwr, self.PAPR, self.PAPR_cp = self.ofdm(self.tx_c, SNR=self.opt.SNR, cof=cof) 118 | self.H_true = self.H_true.to(self.device) 119 | 120 | N, C, H, W = latent.shape 121 | 122 | if self.opt.feedforward == 'IMPLICIT': 123 | r1 = torch.view_as_real(self.ofdm.pilot).repeat(N,1,1,1,1) 124 | r2 = torch.view_as_real(out_pilot) 125 | r3 = torch.view_as_real(out_sig) 126 | dec_in = torch.cat((r1, r2, r3), 2).contiguous().permute(0,1,2,4,3).contiguous().view(N, -1, H, W) 127 | self.fake = self.netG(dec_in) 128 | elif self.opt.feedforward == 'EXPLICIT-CE': 129 | # Channel estimation 130 | self.H_est = self.channel_estimation(out_pilot, noise_pwr) 131 | r1 = torch.view_as_real(self.H_est) 132 | r2 = torch.view_as_real(out_sig) 133 | dec_in = torch.cat((r1, r2), 2).contiguous().permute(0,1,2,4,3).contiguous().view(N, -1, H, W) 134 | self.fake = self.netG(dec_in) 135 | elif self.opt.feedforward == 'EXPLICIT-CE-EQ': 136 | self.H_est = self.channel_estimation(out_pilot, noise_pwr) 137 | self.rx = self.equalization(self.H_est, out_sig, noise_pwr) 138 | r1 = torch.view_as_real(self.rx) 139 | dec_in = r1.contiguous().permute(0,1,2,4,3).contiguous().view(N, -1, H, W) 140 | self.fake = self.netG(dec_in) 141 | elif self.opt.feedforward == 'EXPLICIT-RES': 142 | self.H_est = self.channel_estimation(out_pilot, noise_pwr) 143 | sub11 = torch.view_as_real(self.ofdm.pilot).repeat(N,1,1,1,1) 144 | sub12 = torch.view_as_real(out_pilot) 145 | sub1_input = torch.cat((sub11, sub12), 2).contiguous().permute(0,1,2,4,3).contiguous().view(N, -1, H, W) 146 | sub1_output = self.netS1(sub1_input).view(N, self.opt.P, 1, 2, self.opt.M).permute(0,1,2,4,3) 147 | self.H_est = self.H_est + torch.view_as_complex(sub1_output.contiguous()) 148 | 149 | self.rx = self.equalization(self.H_est, out_sig, noise_pwr) 150 | sub21 = torch.view_as_real(self.H_est) 151 | 
sub22 = torch.view_as_real(out_sig)
152 | sub2_input = torch.cat((sub21, sub22), 2).contiguous().permute(0,1,2,4,3).contiguous().view(N, -1, H, W)
153 | sub2_output = self.netS2(sub2_input).view(N, self.opt.P, self.opt.S, 2, self.opt.M).permute(0,1,2,4,3)
154 | self.rx = self.rx + torch.view_as_complex(sub2_output.contiguous())
155 | 
156 | dec_in = torch.view_as_real(self.rx).permute(0,1,2,4,3).contiguous().view(latent.shape)
157 | self.fake = self.netG(dec_in)
158 | 
159 | def backward_D(self):
160 | """Calculate GAN loss for the discriminator"""
161 | # Fake; stop backprop to the generator by detaching fake_B
162 | 
163 | _, pred_fake = self.netD(self.fake.detach())
164 | self.loss_D_fake = self.criterionGAN(pred_fake, False)
165 | 
166 | real_data = self.real_B
167 | _, pred_real = self.netD(real_data)
168 | self.loss_D_real = self.criterionGAN(pred_real, True)
169 | 
170 | if self.opt.gan_mode in ['lsgan', 'vanilla']:
171 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
172 | self.loss_D.backward()
173 | elif self.opt.gan_mode == 'wgangp':
174 | penalty, grad = networks.cal_gradient_penalty(self.netD, real_data, self.fake.detach(), self.device, type='mixed', constant=1.0, lambda_gp=10.0)
175 | self.loss_D = self.loss_D_fake + self.loss_D_real + penalty
176 | self.loss_D.backward(retain_graph=True)
177 | 
178 | def backward_G(self):
179 | """Calculate GAN, L2, PAPR, and (optionally) CE/EQ losses for the generator"""
180 | # First, G(A) should fake the discriminator
181 | 
182 | if self.opt.gan_mode != 'none':
183 | feat_fake, pred_fake = self.netD(self.fake)
184 | self.loss_G_GAN = self.criterionGAN(pred_fake, True)
185 | 
186 | if getattr(self, 'is_Feat', False):  # optional feature-matching loss; off unless is_Feat is set
187 | feat_real, pred_real = self.netD(self.real_B)
188 | self.loss_G_Feat = 0
189 | 
190 | for j in range(len(feat_real)):
191 | self.loss_G_Feat += self.criterionFeat(feat_real[j].detach(), feat_fake[j]) * self.opt.lambda_feat
192 | else:
193 | self.loss_G_Feat = 0
194 | else:
195 | self.loss_G_GAN = 0
196 | self.loss_G_Feat = 0
197 | 
198 | self.loss_G_L2 = self.criterionL2(self.fake, self.real_B) * self.opt.lambda_L2
199 | self.loss_PAPR = torch.mean(self.PAPR_cp) * self.opt.lambda_papr
200 | if self.opt.feedforward == 'EXPLICIT-RES':
201 | self.loss_CE = self.criterionL2(torch.view_as_real(self.H_true.squeeze()), torch.view_as_real(self.H_est.squeeze())) * self.opt.lambda_ce
202 | self.loss_EQ = self.criterionL2(torch.view_as_real(self.rx), torch.view_as_real(self.tx_c)) * self.opt.lambda_eq
203 | else:
204 | self.loss_CE = 0
205 | self.loss_EQ = 0
206 | 
207 | self.loss_G = self.loss_G_GAN + self.loss_G_Feat + self.loss_G_L2 + self.loss_PAPR + self.loss_CE + self.loss_EQ
208 | self.loss_G.backward()
209 | 
210 | def optimize_parameters(self):
211 | self.forward()                   # compute fake images: G(A)
212 | # update D
213 | if self.opt.gan_mode != 'none':
214 | self.set_requires_grad(self.netD, True)  # enable backprop for D
215 | self.optimizer_D.zero_grad()     # set D's gradients to zero
216 | self.backward_D()                # calculate gradients for D
217 | self.optimizer_D.step()          # update D's weights
218 | self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
219 | else:
220 | self.loss_D_fake = 0
221 | self.loss_D_real = 0
222 | # update G
223 | self.optimizer_G.zero_grad()        # set G's gradients to zero
224 | self.backward_G()                   # calculate gradients for G
225 | self.optimizer_G.step()             # update G's weights
226 | 
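# Worked shape example for the OFDM mapping in forward() above (illustrative values,
# not the repo's defaults): with P=1 antenna, S=6 OFDM symbols, and M=64 subcarriers,
# the encoder latent must carry P*S*2*M = 768 real values per image, e.g. C_channel=12
# on an 8x8 feature map (12*8*8 = 768), which forward() reshapes into a complex tx_c
# of shape (N, 1, 6, 64) before passing it to self.ofdm(...).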
227 | def channel_estimation(self, out_pilot, noise_pwr):
228 | if self.opt.CE == 'LS':
229 | H_est = LS_channel_est(self.ofdm.pilot, out_pilot)
230 | elif self.opt.CE == 'LMMSE':
231 | H_est = LMMSE_channel_est(self.ofdm.pilot, out_pilot, self.opt.M*noise_pwr)
232 | elif self.opt.CE == 'TRUE':
233 | H_est = self.H_true.unsqueeze(2).to(self.device)
234 | else:
235 | raise NotImplementedError('The channel estimation method [%s] is not implemented' % self.opt.CE)
236 | 
237 | return H_est
238 | 
239 | def equalization(self, H_est, out_sig, noise_pwr):
240 | # Equalization
241 | if self.opt.EQ == 'ZF':
242 | rx = ZF_equalization(H_est, out_sig)
243 | elif self.opt.EQ == 'MMSE':
244 | rx = MMSE_equalization(H_est, out_sig, self.opt.M*noise_pwr)
245 | elif self.opt.EQ == 'None':
246 | rx = None
247 | else:
248 | raise NotImplementedError('The equalization method [%s] is not implemented' % self.opt.EQ)
249 | return rx
250 | 
-------------------------------------------------------------------------------- /models/networks.py: --------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | import functools
7 | from torch.optim import lr_scheduler
8 | import numpy as np
9 | from torch.nn import functional as F
10 | from typing import List, Callable, Union, Any, TypeVar, Tuple
11 | from math import exp
12 | 
13 | # from torch import tensor as Tensor
14 | 
15 | Tensor = TypeVar('torch.tensor')
16 | ###############################################################################
17 | # Functions
18 | ###############################################################################
19 | 
20 | class Identity(nn.Module):
21 | def forward(self, x):
22 | return x
23 | 
24 | class Flatten(nn.Module):
25 | def forward(self, x):
26 | N, C, H, W = x.size()  # read in N, C, H, W
27 | return x.view(N, -1)  # "flatten" the C * H * W values into a single vector per image
28 | 
29 | def get_norm_layer(norm_type='instance'):
30 | """Return a normalization layer
31 | 
32 | Parameters:
33 | norm_type (str) -- the name of the normalization layer: batch | instance | none
34 | 
35 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
36 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
37 | """
38 | if norm_type == 'batch':
39 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
40 | elif norm_type == 'instance':
41 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
42 | elif norm_type == 'none':
43 | def norm_layer(x): return Identity()
44 | else:
45 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
46 | return norm_layer
47 | 
48 | 
49 | def init_weights(net, init_type='normal', init_gain=0.02):
50 | """Initialize network weights.
51 | 
52 | Parameters:
53 | net (network) -- network to be initialized
54 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
55 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
56 | 
57 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
58 | work better for some applications. Feel free to try yourself.
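A usage sketch (hypothetical values): init_weights(net, init_type='xavier', init_gain=0.02)
applies init_func below to every module of `net` via net.apply, touching Conv/Linear
weights and BatchNorm2d affine parameters.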
59 | """ 60 | def init_func(m): # define the initialization function 61 | classname = m.__class__.__name__ 62 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 63 | if init_type == 'normal': 64 | init.normal_(m.weight.data, 0.0, init_gain) 65 | elif init_type == 'xavier': 66 | init.xavier_normal_(m.weight.data, gain=init_gain) 67 | elif init_type == 'kaiming': 68 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 69 | elif init_type == 'orthogonal': 70 | init.orthogonal_(m.weight.data, gain=init_gain) 71 | else: 72 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 73 | if hasattr(m, 'bias') and m.bias is not None: 74 | init.constant_(m.bias.data, 0.0) 75 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 76 | init.normal_(m.weight.data, 1.0, init_gain) 77 | init.constant_(m.bias.data, 0.0) 78 | 79 | print('initialize network with %s' % init_type) 80 | net.apply(init_func) # apply the initialization function 81 | 82 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 83 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 84 | Parameters: 85 | net (network) -- the network to be initialized 86 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 87 | gain (float) -- scaling factor for normal, xavier and orthogonal. 88 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 89 | 90 | Return an initialized network. 91 | """ 92 | if len(gpu_ids) > 0: 93 | assert(torch.cuda.is_available()) 94 | net.to(gpu_ids[0]) 95 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 96 | init_weights(net, init_type, init_gain=init_gain) 97 | return net 98 | 99 | 100 | def get_scheduler(optimizer, opt): 101 | """Return a learning rate scheduler 102 | 103 | Parameters: 104 | optimizer -- the optimizer of the network 105 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  106 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 107 | 108 | For 'linear', we keep the same learning rate for the first epochs 109 | and linearly decay the rate to zero over the next epochs. 110 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 111 | See https://pytorch.org/docs/stable/optim.html for more details. 
112 | """ 113 | if opt.lr_policy == 'linear': 114 | def lambda_rule(epoch): 115 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) 116 | return lr_l 117 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 118 | elif opt.lr_policy == 'step': 119 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 120 | elif opt.lr_policy == 'plateau': 121 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 122 | elif opt.lr_policy == 'cosine': 123 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 124 | else: 125 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 126 | return scheduler 127 | 128 | 129 | def define_E(input_nc, ngf, max_ngf, n_downsample, C_channel, n_blocks, norm='instance', init_type='kaiming', init_gain=0.02, gpu_ids=[], first_kernel=7, first_add_C=0): 130 | """Create a generator 131 | Parameters: 132 | input_nc (int) -- the number of channels in input images 133 | ngf (int) -- the number of filters in the last conv layer 134 | norm (str) -- the name of normalization layers used in the network: batch | instance | none 135 | init_type (str) -- the name of our initialization method. 136 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 137 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 138 | first_kernel (int) -- the kernel size of the first conv layer 139 | first_add_C (int) -- additional channels for the feedback mode 140 | 141 | Returns a generator 142 | Our current implementation provides two types of generators: 143 | We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). 
144 | """ 145 | net = None 146 | norm_layer = get_norm_layer(norm_type=norm) 147 | net = Encoder(input_nc=input_nc, ngf=ngf, max_ngf=max_ngf, C_channel=C_channel, n_blocks=n_blocks, n_downsampling=n_downsample, norm_layer=norm_layer, padding_type="reflect", first_kernel=first_kernel, first_add_C=first_add_C) 148 | return init_net(net, init_type, init_gain, gpu_ids) 149 | 150 | def define_G(output_nc, ngf, max_ngf, n_downsample, C_channel, n_blocks, norm="instance", init_type='kaiming', init_gain=0.02, gpu_ids=[], first_kernel=7, activation='sigmoid'): 151 | net = None 152 | norm_layer = get_norm_layer(norm_type=norm) 153 | net = Generator(output_nc=output_nc, ngf=ngf, max_ngf=max_ngf, C_channel=C_channel, n_blocks=n_blocks, n_downsampling=n_downsample, norm_layer=norm_layer, padding_type="reflect", first_kernel=first_kernel, activation_=activation) 154 | return init_net(net, init_type, init_gain, gpu_ids) 155 | 156 | def define_D(input_nc, ndf, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): 157 | net = None 158 | norm_layer = get_norm_layer(norm_type=norm) 159 | net = NLayerDiscriminator(input_nc, ndf, n_layers=n_layers_D, norm_layer=norm_layer) 160 | return init_net(net, init_type, init_gain, gpu_ids) 161 | 162 | def print_network(net): 163 | if isinstance(net, list): 164 | net = net[0] 165 | num_params = 0 166 | for param in net.parameters(): 167 | num_params += param.numel() 168 | print(net) 169 | print('Total number of parameters: %d' % num_params) 170 | 171 | 172 | 173 | ############################################################################## 174 | # Losses 175 | ############################################################################## 176 | class GANLoss(nn.Module): 177 | """Define different GAN objectives. 178 | 179 | The GANLoss class abstracts away the need to create the target label tensor 180 | that has the same size as the input. 181 | """ 182 | 183 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 184 | """ Initialize the GANLoss class. 185 | 186 | Parameters: 187 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 188 | target_real_label (bool) - - label for a real image 189 | target_fake_label (bool) - - label of a fake image 190 | 191 | Note: Do not use sigmoid as the last layer of Discriminator. 192 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 193 | """ 194 | super(GANLoss, self).__init__() 195 | self.register_buffer('real_label', torch.tensor(target_real_label)) 196 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 197 | self.gan_mode = gan_mode 198 | if gan_mode == 'lsgan': 199 | self.loss = nn.MSELoss() 200 | elif gan_mode == 'vanilla': 201 | self.loss = nn.BCEWithLogitsLoss() 202 | elif gan_mode in ['wgangp', 'none']: 203 | self.loss = None 204 | else: 205 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 206 | 207 | def get_target_tensor(self, prediction, target_is_real): 208 | """Create label tensors with the same size as the input. 
207 | def get_target_tensor(self, prediction, target_is_real):
208 | """Create label tensors with the same size as the input.
209 | 
210 | Parameters:
211 | prediction (tensor) - - typically the prediction from a discriminator
212 | target_is_real (bool) - - if the ground truth label is for real images or fake images
213 | 
214 | Returns:
215 | A label tensor filled with ground truth label, and with the size of the input
216 | """
217 | 
218 | if target_is_real:
219 | target_tensor = self.real_label
220 | else:
221 | target_tensor = self.fake_label
222 | return target_tensor.expand_as(prediction)
223 | 
224 | def __call__(self, prediction, target_is_real):
225 | """Calculate loss given Discriminator's output and ground truth labels.
226 | 
227 | Parameters:
228 | prediction (tensor) - - typically the prediction output from a discriminator
229 | target_is_real (bool) - - if the ground truth label is for real images or fake images
230 | 
231 | Returns:
232 | the calculated loss.
233 | """
234 | if self.gan_mode in ['lsgan', 'vanilla']:
235 | target_tensor = self.get_target_tensor(prediction, target_is_real)
236 | loss = self.loss(prediction, target_tensor)
237 | elif self.gan_mode == 'wgangp':
238 | if target_is_real:
239 | loss = -prediction.mean()
240 | else:
241 | loss = prediction.mean()
242 | return loss
243 | 
244 | 
245 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
246 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
247 | 
248 | Arguments:
249 | netD (network) -- discriminator network
250 | real_data (tensor array) -- real images
251 | fake_data (tensor array) -- generated images from the generator
252 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
253 | type (str) -- if we mix real and fake data or not [real | fake | mixed].
254 | constant (float) -- the constant used in the formula (||gradient||_2 - constant)^2
255 | lambda_gp (float) -- weight for this loss
256 | 
257 | Returns the gradient penalty loss
258 | """
259 | if lambda_gp > 0.0:
260 | if type == 'real':  # either use real images, fake images, or a linear interpolation of the two.
261 | interpolatesv = real_data
262 | elif type == 'fake':
263 | interpolatesv = fake_data
264 | elif type == 'mixed':
265 | alpha = torch.rand(real_data.shape[0], 1, device=device)
266 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
267 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
268 | else:
269 | raise NotImplementedError('{} not implemented'.format(type))
270 | interpolatesv.requires_grad_(True)
271 | disc_interpolates = netD(interpolatesv)
272 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
273 | grad_outputs=torch.ones(disc_interpolates.size()).to(device),
274 | create_graph=True, retain_graph=True, only_inputs=True)
275 | gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
276 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # added eps
277 | return gradient_penalty, gradients
278 | else:
279 | return 0.0, None
280 | 
281 | 
282 | ##############################################################################
283 | # Encoder
284 | ##############################################################################
285 | class Encoder(nn.Module):
286 | 
287 | def __init__(self, input_nc, ngf=64, max_ngf=512, C_channel=16, n_blocks=2, n_downsampling=2, norm_layer=nn.BatchNorm2d, padding_type="reflect", first_kernel=7, first_add_C=0):
288 | """Construct a Resnet-based encoder
289 | 
290 | Parameters:
291 | input_nc (int) -- the number of channels in input images
292 | ngf (int) -- the number of filters in the first conv layer
293 | norm_layer -- normalization layer
294 | C_channel (int) -- the number of channels of the projected latent
295 | n_blocks (int) -- the number of ResNet blocks
296 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
297 | """
298 | assert(n_downsampling >= 0)
299 | assert(n_blocks >= 0)
300 | super(Encoder, self).__init__()
301 | 
302 | if type(norm_layer) == functools.partial:
303 | use_bias = norm_layer.func == nn.InstanceNorm2d
304 | else:
305 | use_bias = norm_layer == nn.InstanceNorm2d
306 | 
307 | activation = nn.ReLU(True)
308 | 
309 | model = [nn.ReflectionPad2d((first_kernel-1)//2),
310 | nn.Conv2d(input_nc, ngf, kernel_size=first_kernel, padding=0, bias=use_bias),
311 | norm_layer(ngf),
312 | activation]
313 | 
314 | # add downsampling layers
315 | for i in range(n_downsampling):
316 | mult = 2**i
317 | model += [nn.Conv2d(min(ngf * mult, max_ngf), min(ngf * mult * 2, max_ngf), kernel_size=3, stride=2, padding=1, bias=use_bias),
318 | norm_layer(min(ngf * mult * 2, max_ngf)), activation]
319 | 
320 | self.model_down = nn.Sequential(*model)
321 | model = []
322 | # add ResNet blocks
323 | mult = 2 ** n_downsampling
324 | for i in range(n_blocks):
325 | model += [ResnetBlock(min(ngf * mult, max_ngf)+first_add_C, padding_type=padding_type, norm_layer=norm_layer, use_dropout=False, use_bias=use_bias)]
326 | 
327 | self.model_res = nn.Sequential(*model)
328 | 
329 | self.projection = nn.Conv2d(min(ngf * mult, max_ngf)+first_add_C, C_channel, kernel_size=3, padding=1, stride=1, bias=use_bias)
330 | 
331 | def forward(self, input, H=None):
332 | 
333 | z = self.model_down(input)
334 | if H is not None:
335 | N, C, HH, WW = z.shape
336 | z = torch.cat((z, H.contiguous().permute(0,1,2,4,3).view(N, -1, HH, WW)), 1)
337 | return self.projection(self.model_res(z))
338 | 
339 | ##############################################################################
340 | # Generator
341 | ##############################################################################
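# The Generator below mirrors the Encoder: a 3x3 conv lifts the C_channel latent back
# to the deepest feature width, ResNet blocks refine it, and transposed convs undo the
# downsampling. A construction sketch (hypothetical option values, via define_G above):
#
#   netG = define_G(output_nc=3, ngf=64, max_ngf=256, n_downsample=2, C_channel=12,
#                   n_blocks=2, norm='instance', activation='sigmoid')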
##############################################################################
342 | class Generator(nn.Module):
343 | def __init__(self, output_nc, ngf=64, max_ngf=512, C_channel=16, n_blocks=2, n_downsampling=2, norm_layer=nn.BatchNorm2d, padding_type="reflect", first_kernel=7, activation_='sigmoid'):
344 | assert(n_blocks >= 0)
345 | assert(n_downsampling >= 0)
346 | 
347 | super(Generator, self).__init__()
348 | 
349 | self.activation_ = activation_
350 | 
351 | if type(norm_layer) == functools.partial:
352 | use_bias = norm_layer.func == nn.InstanceNorm2d
353 | else:
354 | use_bias = norm_layer == nn.InstanceNorm2d
355 | 
356 | activation = nn.ReLU(True)
357 | 
358 | mult = 2 ** n_downsampling
359 | ngf_dim = min(ngf * mult, max_ngf)
360 | model = [nn.Conv2d(C_channel, ngf_dim, kernel_size=3, padding=1, stride=1, bias=use_bias)]
361 | 
362 | for i in range(n_blocks):
363 | model += [ResnetBlock(ngf_dim, padding_type=padding_type, norm_layer=norm_layer, use_dropout=False, use_bias=use_bias)]
364 | 
365 | for i in range(n_downsampling):
366 | mult = 2 ** (n_downsampling - i)
367 | model += [nn.ConvTranspose2d(min(ngf * mult, max_ngf), min(ngf * mult // 2, max_ngf),
368 | kernel_size=3, stride=2,
369 | padding=1, output_padding=1,
370 | bias=use_bias),
371 | norm_layer(min(ngf * mult // 2, max_ngf)),
372 | activation]
373 | 
374 | model += [nn.ReflectionPad2d((first_kernel-1)//2), nn.Conv2d(ngf, output_nc, kernel_size=first_kernel, padding=0)]
375 | 
376 | if activation_ == 'tanh':
377 | model += [nn.Tanh()]
378 | elif activation_ == 'sigmoid':
379 | model += [nn.Sigmoid()]
380 | 
381 | self.model = nn.Sequential(*model)
382 | 
383 | def forward(self, input):
384 | 
385 | if self.activation_ == 'tanh':
386 | return self.model(input)
387 | elif self.activation_ == 'sigmoid':
388 | return 2*self.model(input)-1  # map the sigmoid output from (0, 1) to (-1, 1)
389 | 
390 | #########################################################################################
391 | # Residual block
392 | #########################################################################################
393 | class ResnetBlock(nn.Module):
394 | """Define a Resnet block"""
395 | 
396 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
397 | """Initialize the Resnet block
398 | 
399 | A resnet block is a conv block with skip connections
400 | We construct a conv block with build_conv_block function,
401 | and implement skip connections in the <forward> function.
402 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
403 | """
404 | super(ResnetBlock, self).__init__()
405 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
406 | 
407 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
408 | """Construct a convolutional block.
409 | 
410 | Parameters:
411 | dim (int) -- the number of channels in the conv layer.
412 | padding_type (str) -- the name of padding layer: reflect | replicate | zero
413 | norm_layer -- normalization layer
414 | use_dropout (bool) -- if use dropout layers.
415 | use_bias (bool) -- if the conv layer uses bias or not 416 | 417 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) 418 | """ 419 | conv_block = [] 420 | p = 0 421 | if padding_type == 'reflect': 422 | conv_block += [nn.ReflectionPad2d(1)] 423 | elif padding_type == 'replicate': 424 | conv_block += [nn.ReplicationPad2d(1)] 425 | elif padding_type == 'zero': 426 | p = 1 427 | else: 428 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 429 | 430 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] 431 | if use_dropout: 432 | conv_block += [nn.Dropout(0.5)] 433 | 434 | p = 0 435 | if padding_type == 'reflect': 436 | conv_block += [nn.ReflectionPad2d(1)] 437 | elif padding_type == 'replicate': 438 | conv_block += [nn.ReplicationPad2d(1)] 439 | elif padding_type == 'zero': 440 | p = 1 441 | else: 442 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 443 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] 444 | 445 | return nn.Sequential(*conv_block) 446 | 447 | def forward(self, x): 448 | """Forward function (with skip connections)""" 449 | out = x + self.conv_block(x) # add skip connections 450 | return out 451 | 452 | 453 | class NLayerDiscriminator(nn.Module): 454 | """Defines a PatchGAN discriminator""" 455 | 456 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 457 | """Construct a PatchGAN discriminator 458 | 459 | Parameters: 460 | input_nc (int) -- the number of channels in input images 461 | ndf (int) -- the number of filters in the last conv layer 462 | n_layers (int) -- the number of conv layers in the discriminator 463 | norm_layer -- normalization layer 464 | """ 465 | super(NLayerDiscriminator, self).__init__() 466 | 467 | self.n_layers = n_layers 468 | 469 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 470 | use_bias = norm_layer.func == nn.InstanceNorm2d 471 | else: 472 | use_bias = norm_layer == nn.InstanceNorm2d 473 | 474 | kw = 4 475 | padw = 1 476 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 477 | nf_mult = 1 478 | nf_mult_prev = 1 479 | for n in range(1, n_layers): # gradually increase the number of filters 480 | nf_mult_prev = nf_mult 481 | nf_mult = min(2 ** n, 8) 482 | sequence += [[ 483 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 484 | norm_layer(ndf * nf_mult), 485 | nn.LeakyReLU(0.2, True) 486 | ]] 487 | 488 | nf_mult_prev = nf_mult 489 | nf_mult = min(2 ** n_layers, 8) 490 | sequence += [[ 491 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 492 | norm_layer(ndf * nf_mult), 493 | nn.LeakyReLU(0.2, True) 494 | ]] 495 | 496 | sequence += [[nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]] # output 1 channel prediction map 497 | 498 | 499 | for n in range(len(sequence)): 500 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) 501 | 502 | 503 | def forward(self, input): 504 | """Standard forward.""" 505 | res = [input] 506 | for n in range(self.n_layers+1): 507 | model = getattr(self, 'model'+str(n)) 508 | res.append(model(res[-1])) 509 | 510 | model = getattr(self, 'model'+str(self.n_layers+1)) 511 | out = model(res[-1]) 512 | 513 | return res[1:], out 514 | 515 | 516 | class 
PixelDiscriminator(nn.Module):
517 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
518 | 
519 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
520 | """Construct a 1x1 PatchGAN discriminator
521 | 
522 | Parameters:
523 | input_nc (int) -- the number of channels in input images
524 | ndf (int) -- the number of filters in the last conv layer
525 | norm_layer -- normalization layer
526 | """
527 | super(PixelDiscriminator, self).__init__()
528 | if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
529 | use_bias = norm_layer.func == nn.InstanceNorm2d
530 | else:
531 | use_bias = norm_layer == nn.InstanceNorm2d
532 | 
533 | self.net = [
534 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
535 | nn.LeakyReLU(0.2, True),
536 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
537 | norm_layer(ndf * 2),
538 | nn.LeakyReLU(0.2, True),
539 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
540 | 
541 | self.net = nn.Sequential(*self.net)
542 | 
543 | def forward(self, input):
544 | """Standard forward."""
545 | return self.net(input)
546 | 
547 | 
548 | #########################################################################################
549 | # Residual-like subnet
550 | #########################################################################################
551 | class Subnet(nn.Module):
552 | """Define a residual-like subnet"""
553 | 
554 | def __init__(self, dim, dim_out, padding_type, norm_layer, use_dropout, use_bias):
555 | """Initialize the Subnet
556 | 
557 | Unlike ResnetBlock above, this subnet has no skip connection: it simply maps
558 | dim input channels to dim_out output channels through the two conv layers
559 | assembled by the build_conv_block function.
560 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
561 | """
562 | super(Subnet, self).__init__()
563 | self.conv_block = self.build_conv_block(dim, dim_out, padding_type, norm_layer, use_dropout, use_bias)
564 | 
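# Usage in this repo (see JSCCOFDM_model.py): define_Subnet at the bottom of this file
# builds netS1, which refines the pilot-based channel estimate, and netS2, which refines
# the equalized symbols, both in the 'EXPLICIT-RES' feedforward mode. An illustrative
# call (symbol names hypothetical, standing in for the corresponding options):
#
#   netS1 = define_Subnet(dim=(N_pilot*P + 1)*2, dim_out=P*2, norm='instance')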
565 | def build_conv_block(self, dim, dim_out, padding_type, norm_layer, use_dropout, use_bias):
566 | """Construct a convolutional block.
567 | 
568 | Parameters:
569 | dim (int) / dim_out (int) -- the number of input / output channels of the block.
570 | padding_type (str) -- the name of padding layer: reflect | replicate | zero
571 | norm_layer -- normalization layer
572 | use_dropout (bool) -- if use dropout layers.
573 | use_bias (bool) -- if the conv layer uses bias or not
574 | 
575 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
576 | """
577 | conv_block = []
578 | p = 0
579 | if padding_type == 'reflect':
580 | conv_block += [nn.ReflectionPad2d(1)]
581 | elif padding_type == 'replicate':
582 | conv_block += [nn.ReplicationPad2d(1)]
583 | elif padding_type == 'zero':
584 | p = 1
585 | else:
586 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
587 | 
588 | conv_block += [nn.Conv2d(dim, 64, kernel_size=3, padding=p, bias=use_bias), norm_layer(64), nn.ReLU(True)]
589 | if use_dropout:
590 | conv_block += [nn.Dropout(0.5)]
591 | 
592 | p = 0
593 | if padding_type == 'reflect':
594 | conv_block += [nn.ReflectionPad2d(1)]
595 | elif padding_type == 'replicate':
596 | conv_block += [nn.ReplicationPad2d(1)]
597 | elif padding_type == 'zero':
598 | p = 1
599 | else:
600 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
601 | conv_block += [nn.Conv2d(64, dim_out, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim_out)]
602 | 
603 | return nn.Sequential(*conv_block)
604 | 
605 | def forward(self, x):
606 | """Forward function (no skip connection)"""
607 | out = self.conv_block(x)  # plain conv block; unlike ResnetBlock, no skip connection is added
608 | return out
609 | 
610 | 
611 | def define_Subnet(dim, dim_out, norm='instance', init_type='kaiming', init_gain=0.02, gpu_ids=[]):
612 | net = None
613 | norm_layer = get_norm_layer(norm_type=norm)
614 | if type(norm_layer) == functools.partial:
615 | use_bias = norm_layer.func == nn.InstanceNorm2d
616 | else:
617 | use_bias = norm_layer == nn.InstanceNorm2d
618 | net = Subnet(dim=dim, dim_out=dim_out, padding_type='zero', norm_layer=norm_layer, use_dropout=False, use_bias=use_bias)
619 | return init_net(net, init_type, init_gain, gpu_ids)
620 | 
--------------------------------------------------------------------------------