├── README.md
├── conda.yml
├── data
│   ├── __init__.py
│   ├── dataset_cifar10.py
│   ├── dataset_imagenet.py
│   ├── dataset_kodak.py
│   ├── dataset_mnist.py
│   └── imagenet_preprocessing.py
└── jscc.py

/README.md:
--------------------------------------------------------------------------------
1 | # DeepJSCC-f: Deep Joint Source-Channel Coding of Images with Feedback
2 | 
3 | Code used in the paper [DeepJSCC-f: Deep Joint Source-Channel Coding of Images with Feedback](https://arxiv.org/abs/1911.11174), published in the IEEE Journal on Selected Areas in Information Theory (JSAIT).
4 | 
5 | 
6 | - [Arxiv](https://arxiv.org/abs/1911.11174)
7 | - [IEEE JSAIT](https://ieeexplore.ieee.org/document/9066966)
8 | 
9 | 
10 | Authors: David Burth Kurka and Deniz Gündüz
11 | 
12 | ## Usage:
13 | 
14 | 
15 | ```
16 | python jscc.py --help
17 | ```
18 | 
19 | 
20 | 
--------------------------------------------------------------------------------
/conda.yml:
--------------------------------------------------------------------------------
1 | name: deepjsccf
2 | channels:
3 |   - defaults
4 | dependencies:
5 |   - _libgcc_mutex=0.1=main
6 |   - ca-certificates=2020.1.1=0
7 |   - certifi=2019.11.28=py36_0
8 |   - cudatoolkit=10.0.130=0
9 |   - cudnn=7.6.5=cuda10.0_0
10 |   - ld_impl_linux-64=2.33.1=h53a641e_7
11 |   - libedit=3.1.20181209=hc058e9b_0
12 |   - libffi=3.2.1=hd88cf55_4
13 |   - libgcc-ng=9.1.0=hdf63c60_0
14 |   - libstdcxx-ng=9.1.0=hdf63c60_0
15 |   - ncurses=6.2=he6710b0_0
16 |   - openssl=1.1.1d=h7b6447c_4
17 |   - pip=20.0.2=py36_1
18 |   - python=3.6.10=h0371630_0
19 |   - readline=7.0=h7b6447c_5
20 |   - setuptools=45.2.0=py36_0
21 |   - sqlite=3.31.1=h7b6447c_0
22 |   - tk=8.6.8=hbc83047_0
23 |   - wheel=0.34.2=py36_0
24 |   - xz=5.2.4=h14c3975_4
25 |   - zlib=1.2.11=h7b6447c_3
26 |   - pip:
27 |     - absl-py==0.9.0
28 |     - alembic==1.4.0
29 |     - astor==0.8.1
30 |     - attrs==19.3.0
31 |     - chardet==3.0.4
32 |     - click==7.0
33 |     - cloudpickle==1.3.0
34 |     - configargparse==1.0
35 |     - configparser==4.0.2
36 |     - databricks-cli==0.9.1
37 |     - dill==0.3.1.1
38 |     - docker==4.2.0
39 |     - entrypoints==0.3
40 |     - flask==1.1.1
41 |     - future==0.18.2
42 |     - gast==0.2.2
43 |     - gitdb==4.0.2
44 |     - gitpython==3.1.0
45 |     - google-pasta==0.1.8
46 |     - googleapis-common-protos==1.51.0
47 |     - gorilla==0.3.0
48 |     - grpcio==1.27.2
49 |     - gunicorn==20.0.4
50 |     - h5py==2.10.0
51 |     - idna==2.9
52 |     - itsdangerous==1.1.0
53 |     - jinja2==2.11.1
54 |     - keras-applications==1.0.8
55 |     - keras-preprocessing==1.1.0
56 |     - mako==1.1.1
57 |     - markdown==3.2.1
58 |     - markupsafe==1.1.1
59 |     - mlflow==1.6.0
60 |     - numpy==1.18.1
61 |     - opt-einsum==3.1.0
62 |     - pandas==1.0.1
63 |     - prometheus-client==0.7.1
64 |     - prometheus-flask-exporter==0.12.2
65 |     - promise==2.3
66 |     - protobuf==3.11.3
67 |     - python-dateutil==2.8.1
68 |     - python-editor==1.0.4
69 |     - pytz==2019.3
70 |     - pyyaml==5.3
71 |     - querystring-parser==1.2.4
72 |     - requests==2.23.0
73 |     - scipy==1.4.1
74 |     - simplejson==3.17.0
75 |     - six==1.14.0
76 |     - smmap==3.0.1
77 |     - sqlalchemy==1.3.13
78 |     - sqlparse==0.3.0
79 |     - tabulate==0.8.6
80 |     - tensorboard==1.15.0
81 |     - tensorflow-compression==1.3
82 |     - tensorflow-datasets==2.1.0
83 |     - tensorflow-estimator==1.15.1
84 |     - tensorflow-gpu==1.15.2
85 |     - tensorflow-metadata==0.21.1
86 |     - termcolor==1.1.0
87 |     - tqdm==4.43.0
88 |     - urllib3==1.25.8
89 |     - websocket-client==0.57.0
90 |     - werkzeug==1.0.0
91 |     - wrapt==1.12.0
92 | 
93 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ipc-lab/deepJSCC-feedback/c3187073697406148f65287670a24f8ca5fefcc7/data/__init__.py -------------------------------------------------------------------------------- /data/dataset_cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tarfile 4 | from six.moves import urllib 5 | import tensorflow as tf 6 | 7 | DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz' 8 | _HEIGHT = 32 9 | _WIDTH = 32 10 | _NUM_CHANNELS = 3 11 | _DEFAULT_IMAGE_BYTES = _HEIGHT * _WIDTH * _NUM_CHANNELS 12 | # The record is the image plus a one-byte label 13 | _RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1 14 | _NUM_CLASSES = 10 15 | _NUM_DATA_FILES = 5 16 | _NUM_IMAGES = { 17 | 'train': 45000, 18 | 'validation': 5000, 19 | 'test': 10000, 20 | } 21 | 22 | SHUFFLE_BUFFER = _NUM_IMAGES['train'] 23 | SHAPE = [_HEIGHT, _WIDTH, _NUM_CHANNELS] 24 | 25 | 26 | def get_dataset(is_training, data_dir): 27 | """Returns a dataset object""" 28 | filenames = get_filenames(is_training, data_dir) 29 | return tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES) 30 | 31 | 32 | def get_filenames(is_training, data_dir): 33 | """Returns a list of filenames.""" 34 | maybe_download_and_extract(data_dir) 35 | 36 | data_dir = os.path.join(data_dir, 'cifar-10-batches-bin') 37 | if is_training: 38 | return [ 39 | os.path.join(data_dir, 'data_batch_%d.bin' % i) 40 | for i in range(1, _NUM_DATA_FILES + 1) 41 | ] 42 | else: 43 | return [os.path.join(data_dir, 'test_batch.bin')] 44 | 45 | 46 | def parse_record(raw_record, _mode, dtype): 47 | """Parse CIFAR-10 image and label from a raw record.""" 48 | # Convert bytes to a vector of uint8 that is record_bytes long. 49 | record_vector = tf.io.decode_raw(raw_record, tf.uint8) 50 | 51 | # The first byte represents the label, which we convert from uint8 to int32 52 | # and then to one-hot. 53 | label = tf.cast(record_vector[0], tf.int32) 54 | 55 | # The remaining bytes after the label represent the image, which we reshape 56 | # from [depth * height * width] to [depth, height, width]. 57 | depth_major = tf.reshape(record_vector[1:_RECORD_BYTES], 58 | [_NUM_CHANNELS, _HEIGHT, _WIDTH]) 59 | 60 | # Convert from [depth, height, width] to [height, width, depth], and cast 61 | # as float32. 
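The fixed-length record layout that `parse_record` unpacks (one label byte followed by a 3072-byte depth-major image) can be checked outside TensorFlow. A minimal NumPy sketch of the same decoding, with a hypothetical `decode_cifar_record` helper that is not part of this repo:

```python
import numpy as np

_RECORD_BYTES = 1 + 3 * 32 * 32  # label byte + depth-major image bytes

def decode_cifar_record(raw_record):
    """Decode one CIFAR-10 record, mirroring parse_record above."""
    record = np.frombuffer(raw_record, dtype=np.uint8)
    assert record.size == _RECORD_BYTES
    label = int(record[0])                       # first byte is the label
    depth_major = record[1:].reshape(3, 32, 32)  # [depth, height, width]
    image = depth_major.transpose(1, 2, 0)       # -> [height, width, depth]
    return image.astype(np.float32) / 255.0, label

# e.g. with open('data_batch_1.bin', 'rb') as f:
#     image, label = decode_cifar_record(f.read(_RECORD_BYTES))
```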
62 | image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32) 63 | 64 | # normalise images to range 0-1 65 | image = image/255.0 66 | 67 | image = tf.cast(image, dtype) 68 | 69 | return image, image 70 | 71 | 72 | def maybe_download_and_extract(data_dir): 73 | """Download and extract the tarball from Alex's website.""" 74 | if not os.path.exists(data_dir): 75 | os.makedirs(data_dir) 76 | 77 | filename = DATA_URL.split('/')[-1] 78 | filepath = os.path.join(data_dir, filename) 79 | 80 | if not os.path.exists(filepath): 81 | def _progress(count, block_size, total_size): 82 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 83 | filename, 100.0 * count * block_size / total_size)) 84 | sys.stdout.flush() 85 | 86 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 87 | print() 88 | statinfo = os.stat(filepath) 89 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 90 | extracted_dir_path = os.path.join(data_dir, 'cifar-10-batches-bin') 91 | if not os.path.exists(extracted_dir_path): 92 | tarfile.open(filepath, 'r:gz').extractall(data_dir) 93 | -------------------------------------------------------------------------------- /data/dataset_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import data.imagenet_preprocessing as imgnet_preprocessing 4 | 5 | _DEFAULT_IMAGE_SIZE = 128#224 6 | _NUM_CHANNELS = 3 7 | _NUM_CLASSES = 1001 8 | 9 | # total 'train' images: 1281167 10 | _NUM_IMAGES = { 11 | 'train': 1231167, 12 | 'validation': 49920, # I thought it was 50k, but its just finding 49k 13 | 'test': 50000, 14 | } 15 | 16 | _NUM_TRAIN_FILES = 1024 # number of tfrecords files 17 | SHUFFLE_BUFFER = 10000 18 | SHAPE = [_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS] 19 | 20 | 21 | def get_filenames(is_training, data_dir): 22 | """Return filenames for dataset.""" 23 | if is_training: 24 | return [ 25 | os.path.join(data_dir, 'train', 'train-%05d-of-01024' % i) 26 | for i in range(_NUM_TRAIN_FILES)] 27 | else: 28 | return [ 29 | os.path.join(data_dir, 'validation', 'validation-%05d-of-00128' % i) 30 | for i in range(128)] 31 | 32 | 33 | def _parse_example_proto(example_serialized): 34 | """Parses an Example proto containing a training example of an image. 35 | 36 | The output of the build_image_data.py image preprocessing script is a 37 | dataset containing serialized Example protocol buffers. Each Example proto 38 | contains the following fields (values are included as examples): 39 | 40 | image/height: 462 41 | image/width: 581 42 | image/colorspace: 'RGB' 43 | image/channels: 3 44 | image/class/label: 615 45 | image/class/synset: 'n03623198' 46 | image/class/text: 'knee pad' 47 | image/object/bbox/xmin: 0.1 48 | image/object/bbox/xmax: 0.9 49 | image/object/bbox/ymin: 0.2 50 | image/object/bbox/ymax: 0.6 51 | image/object/bbox/label: 615 52 | image/format: 'JPEG' 53 | image/filename: 'ILSVRC2012_val_00041207.JPEG' 54 | image/encoded: 55 | 56 | Args: 57 | example_serialized: scalar Tensor tf.string containing a serialized 58 | Example protocol buffer. 59 | 60 | Returns: 61 | image_buffer: Tensor tf.string containing the contents of a JPEG file. 62 | label: Tensor tf.int32 containing the label. 63 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 64 | where each coordinate is [0, 1) and the coordinates are arranged as 65 | [ymin, xmin, ymax, xmax]. 66 | """ 67 | # Dense features in Example proto. 
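A round-trip sketch may help make the Example schema listed above concrete: it serializes a toy record carrying the dense fields used here and parses it back. This is illustrative only (the `make_toy_example` helper and its dummy values are not part of the repo), written against the TF 1.x API pinned in conda.yml:

```python
import tensorflow as tf

def make_toy_example(jpeg_bytes, label, text):
    """Serialize a minimal Example with the dense features parsed below."""
    feature = {
        'image/encoded': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
        'image/class/label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
        'image/class/text': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[text])),
    }
    return tf.train.Example(
        features=tf.train.Features(feature=feature)).SerializeToString()

serialized = make_toy_example(b'<jpeg bytes>', 615, b'knee pad')
parsed = tf.parse_single_example(serialized, {
    'image/encoded': tf.FixedLenFeature([], tf.string, default_value=''),
    'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=-1),
    'image/class/text': tf.FixedLenFeature([], tf.string, default_value=''),
})
label = tf.cast(parsed['image/class/label'], tf.int32)
```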
68 | feature_map = { 69 | 'image/encoded': tf.FixedLenFeature([], dtype=tf.string, 70 | default_value=''), 71 | 'image/class/label': tf.FixedLenFeature([], dtype=tf.int64, 72 | default_value=-1), 73 | 'image/class/text': tf.FixedLenFeature([], dtype=tf.string, 74 | default_value=''), 75 | } 76 | sparse_float32 = tf.VarLenFeature(dtype=tf.float32) 77 | # Sparse features in Example proto. 78 | feature_map.update( 79 | {k: sparse_float32 for k in ['image/object/bbox/xmin', 80 | 'image/object/bbox/ymin', 81 | 'image/object/bbox/xmax', 82 | 'image/object/bbox/ymax']}) 83 | 84 | features = tf.parse_single_example(example_serialized, feature_map) 85 | label = tf.cast(features['image/class/label'], dtype=tf.int32) 86 | 87 | xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0) 88 | ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0) 89 | xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0) 90 | ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0) 91 | 92 | # Note that we impose an ordering of (y, x) just to make life difficult. 93 | bbox = tf.concat([ymin, xmin, ymax, xmax], 0) 94 | 95 | # Force the variable number of bounding boxes into the shape 96 | # [1, num_boxes, coords]. 97 | bbox = tf.expand_dims(bbox, 0) 98 | bbox = tf.transpose(bbox, [0, 2, 1]) 99 | 100 | return features['image/encoded'], label, bbox 101 | 102 | 103 | def parse_record(raw_record, _mode, dtype): 104 | """Parses a record containing a training example of an image. 105 | 106 | The input record is parsed into a label and image, and the image is passed 107 | through preprocessing steps (cropping, flipping, and so on). 108 | 109 | Args: 110 | raw_record: scalar Tensor tf.string containing a serialized 111 | Example protocol buffer. 112 | is_training: A boolean denoting whether the input is for training. 113 | dtype: data type to use for images/features. 114 | 115 | Returns: 116 | Tuple with processed image tensor and one-hot-encoded label tensor. 117 | """ 118 | image_buffer, label, bbox = _parse_example_proto(raw_record) 119 | 120 | image = imgnet_preprocessing.preprocess_image( 121 | image_buffer=image_buffer, 122 | bbox=bbox, 123 | output_height=_DEFAULT_IMAGE_SIZE, 124 | output_width=_DEFAULT_IMAGE_SIZE, 125 | num_channels=_NUM_CHANNELS, 126 | is_training=False) # as we are not classifying, do minimal processing 127 | image = tf.cast(image, dtype) 128 | 129 | return image, image 130 | 131 | 132 | def get_dataset(is_training, data_dir): 133 | """Returns a dataset object 134 | 135 | Args: 136 | is_training: A boolean denoting whether the input is for training. 137 | data_dir: The directory containing the input data. 138 | Returns: 139 | A dataset that can be used for iteration. 140 | """ 141 | filenames = get_filenames(is_training, data_dir) 142 | dataset = tf.data.Dataset.from_tensor_slices(filenames) 143 | 144 | 145 | # if is_training: 146 | # # Shuffle the input files 147 | # dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES) 148 | 149 | # Convert to individual records. 150 | # cycle_length = 20 means 20 files will be read and deserialized in 151 | # parallel. This number is low enough to not cause too much contention on 152 | # small systems but high enough to provide the benefits of parallelization. 153 | # You may want to increase this number if you have a large number of CPU 154 | # cores. 
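Since `tf.contrib` was removed after TF 1.x, it is worth noting the non-contrib equivalent of the transformation applied on the next line. An untested sketch against the TF 1.15 API pinned in conda.yml (the toy shard list is hypothetical):

```python
import tensorflow as tf

filenames = tf.data.Dataset.from_tensor_slices(
    ['train-00000-of-01024', 'train-00001-of-01024'])  # toy shard list

# tf.data.experimental.parallel_interleave is the drop-in replacement
# for the tf.contrib.data symbol:
dataset = filenames.apply(tf.data.experimental.parallel_interleave(
    tf.data.TFRecordDataset, cycle_length=20))

# Or the plain Dataset method (close in behaviour, though
# parallel_interleave also exposes a `sloppy` flag for
# non-deterministic ordering):
# dataset = filenames.interleave(
#     tf.data.TFRecordDataset,
#     cycle_length=20,
#     num_parallel_calls=tf.data.experimental.AUTOTUNE,
# )
```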
155 |     dataset = dataset.apply(tf.contrib.data.parallel_interleave(
156 |         tf.data.TFRecordDataset, cycle_length=20))
157 |     return dataset
158 | 
--------------------------------------------------------------------------------
/data/dataset_kodak.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tarfile
4 | from six.moves import urllib
5 | import tensorflow as tf
6 | 
7 | _HEIGHT = 512
8 | _WIDTH = 768
9 | _NUM_CHANNELS = 3
10 | _NUM_IMAGES = {
11 |     'train': 24,
12 |     'validation': 24,
13 |     'test': 24,
14 | }
15 | 
16 | SHUFFLE_BUFFER = _NUM_IMAGES['train']
17 | SHAPE = [_HEIGHT, _WIDTH, _NUM_CHANNELS]
18 | 
19 | 
20 | def get_dataset(is_training, data_dir):
21 |     """Returns a dataset object"""
22 |     maybe_download_and_extract(data_dir)
23 | 
24 |     file_pattern = os.path.join(data_dir, "kodim*.png")
25 |     filename_dataset = tf.data.Dataset.list_files(file_pattern)
26 |     return filename_dataset.map(lambda x: tf.image.decode_png(tf.read_file(x)))
27 | 
28 | 
29 | def parse_record(raw_record, _mode, dtype):
30 |     """Parse a Kodak image from a raw record."""
31 |     image = tf.reshape(raw_record, [_HEIGHT, _WIDTH, _NUM_CHANNELS])
32 |     # normalise images to range 0-1
33 |     image = tf.cast(image, dtype)
34 |     image = tf.divide(image, 255.0)
35 | 
36 | 
37 |     return image, image
38 | 
39 | 
40 | def preprocess_image(image, is_training):
41 |     """Preprocess a single image of layout [height, width, depth]."""
42 |     if is_training:
43 |         # Resize the image to add four extra pixels on each side.
44 |         image = tf.image.resize_image_with_crop_or_pad(
45 |             image, _HEIGHT + 8, _WIDTH + 8)
46 | 
47 |         # Randomly crop a [_HEIGHT, _WIDTH] section of the image.
48 |         image = tf.random_crop(image, [_HEIGHT, _WIDTH, _NUM_CHANNELS])
49 | 
50 |         # Randomly flip the image horizontally.
51 |         image = tf.image.random_flip_left_right(image)
52 | 
53 |     # Subtract off the mean and divide by the variance of the pixels.
54 |     image = tf.image.per_image_standardization(image)
55 |     return image
56 | 
57 | 
58 | def maybe_download_and_extract(data_dir):
59 |     """Download the 24 Kodak PNG images if not already present."""
60 |     if os.path.exists(data_dir):
61 |         return
62 |     else:
63 |         os.makedirs(data_dir)
64 | 
65 |     url = "http://www.cs.albany.edu/~xypan/research/img/Kodak/kodim{}.png"
66 | 
67 |     def _progress(count, block_size, total_size):
68 |         sys.stdout.write('\r>> Downloading %.1f%%' % (
69 |             100.0 * count * block_size / total_size))
70 |         sys.stdout.flush()
71 | 
72 |     # the Kodak set contains 24 images, kodim1.png to kodim24.png
73 |     for i in range(24):
74 |         print(url.format(i+1))
75 |         filepath = os.path.join(data_dir, 'kodim{}.png'.format(i+1))
76 |         filepath, _ = urllib.request.urlretrieve(url.format(i+1), filepath, _progress)
77 |         print()
78 |         statinfo = os.stat(filepath)
79 |         print('Successfully downloaded', filepath, statinfo.st_size, 'bytes.')
--------------------------------------------------------------------------------
/data/dataset_mnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """tf.data.Dataset interface to the MNIST dataset.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import gzip 21 | import os 22 | import shutil 23 | import tempfile 24 | 25 | import numpy as np 26 | from six.moves import urllib 27 | import tensorflow as tf 28 | 29 | 30 | def read32(bytestream): 31 | """Read 4 bytes from bytestream as an unsigned 32-bit integer.""" 32 | dt = np.dtype(np.uint32).newbyteorder('>') 33 | return np.frombuffer(bytestream.read(4), dtype=dt)[0] 34 | 35 | 36 | def check_image_file_header(filename): 37 | """Validate that filename corresponds to images for the MNIST dataset.""" 38 | with tf.gfile.Open(filename, 'rb') as f: 39 | magic = read32(f) 40 | read32(f) # num_images, unused 41 | rows = read32(f) 42 | cols = read32(f) 43 | if magic != 2051: 44 | raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, 45 | f.name)) 46 | if rows != 28 or cols != 28: 47 | raise ValueError( 48 | 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' % 49 | (f.name, rows, cols)) 50 | 51 | 52 | def check_labels_file_header(filename): 53 | """Validate that filename corresponds to labels for the MNIST dataset.""" 54 | with tf.gfile.Open(filename, 'rb') as f: 55 | magic = read32(f) 56 | read32(f) # num_items, unused 57 | if magic != 2049: 58 | raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, 59 | f.name)) 60 | 61 | 62 | def download(directory, filename): 63 | """Download (and unzip) a file from the MNIST dataset if not already done.""" 64 | filepath = os.path.join(directory, filename) 65 | if tf.gfile.Exists(filepath): 66 | return filepath 67 | if not tf.gfile.Exists(directory): 68 | tf.gfile.MakeDirs(directory) 69 | # CVDF mirror of http://yann.lecun.com/exdb/mnist/ 70 | url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz' 71 | _, zipped_filepath = tempfile.mkstemp(suffix='.gz') 72 | print('Downloading %s to %s' % (url, zipped_filepath)) 73 | urllib.request.urlretrieve(url, zipped_filepath) 74 | with gzip.open(zipped_filepath, 'rb') as f_in, \ 75 | tf.gfile.Open(filepath, 'wb') as f_out: 76 | shutil.copyfileobj(f_in, f_out) 77 | os.remove(zipped_filepath) 78 | return filepath 79 | 80 | 81 | def dataset(directory, images_file, labels_file): 82 | """Download and parse MNIST dataset.""" 83 | 84 | images_file = download(directory, images_file) 85 | labels_file = download(directory, labels_file) 86 | 87 | check_image_file_header(images_file) 88 | check_labels_file_header(labels_file) 89 | 90 | def decode_image(image): 91 | # Normalize from [0, 255] to [0.0, 1.0] 92 | image = tf.decode_raw(image, tf.uint8) 93 | image = tf.cast(image, tf.float32) 94 | image = tf.reshape(image, [28, 28, 1]) 95 | return image / 255.0 96 | 97 | def decode_label(label): 98 | label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8] 99 | label = tf.reshape(label, []) # label is a scalar 100 | return tf.to_int32(label) 101 | 102 | images = tf.data.FixedLengthRecordDataset( 103 | images_file, 28 * 28, header_bytes=16).map(decode_image) 104 | labels = tf.data.FixedLengthRecordDataset( 105 | labels_file, 1, header_bytes=8).map(decode_label) 106 | return tf.data.Dataset.zip((images, images)) 107 | 108 | 109 | def train(directory): 110 | """tf.data.Dataset object for MNIST training data.""" 111 | return dataset(directory, 
'train-images-idx3-ubyte', 112 | 'train-labels-idx1-ubyte') 113 | 114 | 115 | def test(directory): 116 | """tf.data.Dataset object for MNIST test data.""" 117 | return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte') 118 | -------------------------------------------------------------------------------- /data/imagenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images. 16 | 17 | Training images are sampled using the provided bounding boxes, and subsequently 18 | cropped to the sampled bounding box. Images are additionally flipped randomly, 19 | then resized to the target output size (without aspect-ratio preservation). 20 | 21 | Images used during evaluation are resized (with aspect-ratio preservation) and 22 | centrally cropped. 23 | 24 | All images undergo mean color subtraction. 25 | 26 | Note that these steps are colloquially referred to as "ResNet preprocessing," 27 | and they differ from "VGG preprocessing," which does not use bounding boxes 28 | and instead does an aspect-preserving resize followed by random crop during 29 | training. (These both differ from "Inception preprocessing," which introduces 30 | color distortion steps.) 31 | 32 | """ 33 | 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | 38 | import tensorflow as tf 39 | 40 | _R_MEAN = 123.68 41 | _G_MEAN = 116.78 42 | _B_MEAN = 103.94 43 | _CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN] 44 | 45 | # The lower bound for the smallest side of the image for aspect-preserving 46 | # resizing. For example, if an image is 500 x 1000, it will be resized to 47 | # _RESIZE_MIN x (_RESIZE_MIN * 2). 48 | _RESIZE_MIN = 256 49 | 50 | 51 | def _decode_crop_and_flip(image_buffer, bbox, num_channels): 52 | """Crops the given image to a random part of the image, and randomly flips. 53 | 54 | We use the fused decode_and_crop op, which performs better than the two ops 55 | used separately in series, but note that this requires that the image be 56 | passed in as an un-decoded string Tensor. 57 | 58 | Args: 59 | image_buffer: scalar string Tensor representing the raw JPEG image buffer. 60 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 61 | where each coordinate is [0, 1) and the coordinates are arranged as 62 | [ymin, xmin, ymax, xmax]. 63 | num_channels: Integer depth of the image buffer for decoding. 64 | 65 | Returns: 66 | 3-D tensor with cropped image. 67 | 68 | """ 69 | # A large fraction of image datasets contain a human-annotated bounding box 70 | # delineating the region of the image containing the object of interest. 
We 71 | # choose to create a new bounding box for the object which is a randomly 72 | # distorted version of the human-annotated bounding box that obeys an 73 | # allowed range of aspect ratios, sizes and overlap with the human-annotated 74 | # bounding box. If no box is supplied, then we assume the bounding box is 75 | # the entire image. 76 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 77 | tf.image.extract_jpeg_shape(image_buffer), 78 | bounding_boxes=bbox, 79 | min_object_covered=0.1, 80 | aspect_ratio_range=[0.75, 1.33], 81 | area_range=[0.05, 1.0], 82 | max_attempts=100, 83 | use_image_if_no_bounding_boxes=True) 84 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box 85 | 86 | # Reassemble the bounding box in the format the crop op requires. 87 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 88 | target_height, target_width, _ = tf.unstack(bbox_size) 89 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) 90 | 91 | # Use the fused decode and crop op here, which is faster than each in series. 92 | cropped = tf.image.decode_and_crop_jpeg( 93 | image_buffer, crop_window, channels=num_channels) 94 | 95 | # Flip to add a little more random distortion in. 96 | cropped = tf.image.random_flip_left_right(cropped) 97 | return cropped 98 | 99 | 100 | def _central_crop(image, crop_height, crop_width): 101 | """Performs central crops of the given image list. 102 | 103 | Args: 104 | image: a 3-D image tensor 105 | crop_height: the height of the image following the crop. 106 | crop_width: the width of the image following the crop. 107 | 108 | Returns: 109 | 3-D tensor with cropped image. 110 | """ 111 | shape = tf.shape(image) 112 | height, width = shape[0], shape[1] 113 | 114 | amount_to_be_cropped_h = (height - crop_height) 115 | crop_top = amount_to_be_cropped_h // 2 116 | amount_to_be_cropped_w = (width - crop_width) 117 | crop_left = amount_to_be_cropped_w // 2 118 | return tf.slice( 119 | image, [crop_top, crop_left, 0], [crop_height, crop_width, -1]) 120 | 121 | 122 | def _mean_image_subtraction(image, means, num_channels): 123 | """Subtracts the given means from each image channel. 124 | 125 | For example: 126 | means = [123.68, 116.779, 103.939] 127 | image = _mean_image_subtraction(image, means) 128 | 129 | Note that the rank of `image` must be known. 130 | 131 | Args: 132 | image: a tensor of size [height, width, C]. 133 | means: a C-vector of values to subtract from each channel. 134 | num_channels: number of color channels in the image that will be distorted. 135 | 136 | Returns: 137 | the centered image. 138 | 139 | Raises: 140 | ValueError: If the rank of `image` is unknown, if `image` has a rank other 141 | than three or if the number of channels in `image` doesn't match the 142 | number of values in `means`. 143 | """ 144 | if image.get_shape().ndims != 3: 145 | raise ValueError('Input must be of size [height, width, C>0]') 146 | 147 | if len(means) != num_channels: 148 | raise ValueError('len(means) must match the number of channels') 149 | 150 | # We have a 1-D tensor of means; convert to 3-D. 151 | means = tf.expand_dims(tf.expand_dims(means, 0), 0) 152 | 153 | return image - means 154 | 155 | 156 | def _smallest_size_at_least(height, width, resize_min): 157 | """Computes new shape with the smallest side equal to `smallest_side`. 158 | 159 | Computes new shape with the smallest side equal to `smallest_side` while 160 | preserving the original aspect ratio. 
161 | 162 | Args: 163 | height: an int32 scalar tensor indicating the current height. 164 | width: an int32 scalar tensor indicating the current width. 165 | resize_min: A python integer or scalar `Tensor` indicating the size of 166 | the smallest side after resize. 167 | 168 | Returns: 169 | new_height: an int32 scalar tensor indicating the new height. 170 | new_width: an int32 scalar tensor indicating the new width. 171 | """ 172 | resize_min = tf.cast(resize_min, tf.float32) 173 | 174 | # Convert to floats to make subsequent calculations go smoothly. 175 | height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32) 176 | 177 | smaller_dim = tf.minimum(height, width) 178 | scale_ratio = resize_min / smaller_dim 179 | 180 | # Convert back to ints to make heights and widths that TF ops will accept. 181 | new_height = tf.cast(height * scale_ratio, tf.int32) 182 | new_width = tf.cast(width * scale_ratio, tf.int32) 183 | 184 | return new_height, new_width 185 | 186 | 187 | def _aspect_preserving_resize(image, resize_min): 188 | """Resize images preserving the original aspect ratio. 189 | 190 | Args: 191 | image: A 3-D image `Tensor`. 192 | resize_min: A python integer or scalar `Tensor` indicating the size of 193 | the smallest side after resize. 194 | 195 | Returns: 196 | resized_image: A 3-D tensor containing the resized image. 197 | """ 198 | shape = tf.shape(image) 199 | height, width = shape[0], shape[1] 200 | 201 | new_height, new_width = _smallest_size_at_least(height, width, resize_min) 202 | 203 | return _resize_image(image, new_height, new_width) 204 | 205 | 206 | def _resize_image(image, height, width): 207 | """Simple wrapper around tf.resize_images. 208 | 209 | This is primarily to make sure we use the same `ResizeMethod` and other 210 | details each time. 211 | 212 | Args: 213 | image: A 3-D image `Tensor`. 214 | height: The target height for the resized image. 215 | width: The target width for the resized image. 216 | 217 | Returns: 218 | resized_image: A 3-D tensor containing the resized image. The first two 219 | dimensions have the shape [height, width]. 220 | """ 221 | return tf.image.resize_images( 222 | image, [height, width], method=tf.image.ResizeMethod.BILINEAR, 223 | align_corners=False) 224 | 225 | 226 | def preprocess_image(image_buffer, bbox, output_height, output_width, 227 | num_channels, is_training=False): 228 | """Preprocesses the given image. 229 | 230 | Preprocessing includes decoding, cropping, and resizing for both training 231 | and eval images. Training preprocessing, however, introduces some random 232 | distortion of the image to improve accuracy. 233 | 234 | Args: 235 | image_buffer: scalar string Tensor representing the raw JPEG image buffer. 236 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 237 | where each coordinate is [0, 1) and the coordinates are arranged as 238 | [ymin, xmin, ymax, xmax]. 239 | output_height: The height of the image after preprocessing. 240 | output_width: The width of the image after preprocessing. 241 | num_channels: Integer depth of the image buffer for decoding. 242 | is_training: `True` if we're preprocessing the image for training and 243 | `False` otherwise. 244 | 245 | Returns: 246 | A preprocessed image. 247 | """ 248 | if is_training: 249 | # For training, we want to randomize some of the distortions. 
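For the eval branch below (decode, aspect-preserving resize, central crop), the shape arithmetic is worth a worked example: with `_RESIZE_MIN = 256` and a 128 x 128 output, a 500 x 1000 image is first resized to 256 x 512 (the example given near the top of this file) and then centrally cropped. A small pure-Python sketch of that arithmetic (hypothetical helper name, not part of the repo):

```python
def eval_resize_and_crop_geometry(height, width, resize_min=256, out_size=128):
    """Trace the shapes produced by _aspect_preserving_resize + _central_crop."""
    # _smallest_size_at_least: scale so the smaller side equals resize_min
    scale = resize_min / min(height, width)
    new_h, new_w = int(height * scale), int(width * scale)
    # _central_crop: offsets of the centered out_size x out_size window
    top = (new_h - out_size) // 2
    left = (new_w - out_size) // 2
    return (new_h, new_w), (top, left)

print(eval_resize_and_crop_geometry(500, 1000))  # ((256, 512), (64, 192))
```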
250 | image = _decode_crop_and_flip(image_buffer, bbox, num_channels) 251 | image = _resize_image(image, output_height, output_width) 252 | else: 253 | # For validation, we want to decode, resize, then just crop the middle. 254 | image = tf.image.decode_jpeg(image_buffer, channels=num_channels) 255 | image = _aspect_preserving_resize(image, _RESIZE_MIN) 256 | image = _central_crop(image, output_height, output_width) 257 | 258 | image.set_shape([output_height, output_width, num_channels]) 259 | 260 | # return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels) 261 | return image/255.0 # normalize image (instead of subtracting the norm) 262 | -------------------------------------------------------------------------------- /jscc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | from datetime import datetime 5 | import tensorflow as tf 6 | import numpy as np 7 | import configargparse 8 | from tensorflow.keras import layers 9 | import tensorflow_compression as tfc 10 | import data.dataset_cifar10 11 | import data.dataset_imagenet 12 | import data.dataset_kodak 13 | 14 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" 15 | 16 | DATASETS = { 17 | "cifar": data.dataset_cifar10, 18 | "imagenet": data.dataset_imagenet, 19 | "kodak": data.dataset_kodak, 20 | } 21 | 22 | 23 | class NBatchLogger(tf.keras.callbacks.Callback): 24 | """ 25 | A Logger that log average performance per `display` steps. 26 | """ 27 | 28 | def __init__(self, display): 29 | super(NBatchLogger, self).__init__() 30 | self.step = 0 31 | self.display = display 32 | self.metric_cache = {} 33 | self._start_time = time.time() 34 | 35 | def on_batch_end(self, batch, logs={}): 36 | self.step += 1 37 | for k in self.params["metrics"]: 38 | if k in logs: 39 | self.metric_cache[k] = self.metric_cache.get(k, 0) + logs[k] 40 | 41 | if self.step % self.display == 0: 42 | cur_time = time.time() 43 | duration = cur_time - self._start_time 44 | self._start_time = cur_time 45 | sec_per_step = duration / self.display 46 | 47 | metrics_log = "" 48 | for (k, v) in self.metric_cache.items(): 49 | val = v / self.display 50 | if abs(val) > 1e-3: 51 | metrics_log += " - %s: %.4f" % (k, val) 52 | else: 53 | metrics_log += " - %s: %.4e" % (k, val) 54 | print( 55 | "{} step: {}/{} {} - {:3f} sec/step".format( 56 | datetime.now(), 57 | self.step, 58 | self.params["steps"], 59 | metrics_log, 60 | sec_per_step, 61 | ) 62 | ) 63 | self.metric_cache.clear() 64 | 65 | 66 | class PSNRsVar(tf.keras.metrics.Metric): 67 | """Calculate the variance of a distribution of PSNRs across batches 68 | 69 | """ 70 | 71 | def __init__(self, name="variance", **kwargs): 72 | super(PSNRsVar, self).__init__(name=name, **kwargs) 73 | self.count = self.add_weight(name="count", shape=(), initializer="zeros") 74 | self.mean = self.add_weight(name="mean", shape=(), initializer="zeros") 75 | self.var = self.add_weight(name="M2", shape=(), initializer="zeros") 76 | 77 | def update_state(self, y_true, y_pred, sample_weight=None): 78 | psnrs = tf.image.psnr(y_true, y_pred, max_val=1.0) 79 | samples = tf.cast(psnrs, self.dtype) 80 | batch_count = tf.size(samples) 81 | batch_count = tf.cast(batch_count, self.dtype) 82 | batch_mean = tf.math.reduce_mean(samples) 83 | batch_var = tf.math.reduce_variance(samples) 84 | 85 | # compute new values for variables 86 | new_count = self.count + batch_count 87 | new_mean = (self.count * self.mean + batch_count * batch_mean) / ( 88 | self.count + batch_count 
89 | ) 90 | new_var = ( 91 | (self.count * (self.var + tf.square(self.mean - new_mean))) 92 | + (batch_count * (batch_var + tf.square(batch_mean - new_mean))) 93 | ) / (self.count + batch_count) 94 | 95 | self.count.assign(new_count) 96 | self.mean.assign(new_mean) 97 | self.var.assign(new_var) 98 | 99 | def result(self): 100 | return self.var 101 | 102 | def reset_states(self): 103 | # The state of the metric will be reset at the start of each epoch. 104 | self.count.assign(np.zeros(self.count.shape)) 105 | self.mean.assign(np.zeros(self.mean.shape)) 106 | self.var.assign(np.zeros(self.var.shape)) 107 | 108 | 109 | class TargetPSNRsHistogram(tf.keras.metrics.Metric): 110 | def __init__(self, name="PSNR target", min_psnr=20, max_psnr=45, step=1, **kwargs): 111 | super(TargetPSNRsHistogram, self).__init__(name=name, **kwargs) 112 | self.bins_labels = np.arange(min_psnr, max_psnr + 1, step) 113 | self.bins = self.add_weight( 114 | name="bins", shape=self.bins_labels.shape, initializer="zeros" 115 | ) 116 | 117 | def update_state(self, y_true, y_pred, sample_weight=None): 118 | psnrs = tf.image.psnr(y_true, y_pred, max_val=1.0) 119 | counts = [] 120 | # count how many images fit in each psnr range 121 | for b, bin_label in enumerate(self.bins_labels): 122 | counts.append(tf.math.count_nonzero(tf.greater_equal(psnrs, bin_label))) 123 | 124 | self.bins.assign_add(counts) 125 | 126 | def result(self): 127 | return self.bins 128 | 129 | def reset_states(self): 130 | # The state of the metric will be reset at the start of each epoch. 131 | self.bins.assign(np.zeros(self.bins.shape)) 132 | 133 | 134 | def psnr_metric(x_in, x_out): 135 | if type(x_in) is list: 136 | img_in = x_in[0] 137 | else: 138 | img_in = x_in 139 | return tf.image.psnr(img_in, x_out, max_val=1.0) 140 | 141 | 142 | class Encoder(layers.Layer): 143 | """Build encoder from specified arch""" 144 | 145 | def __init__(self, conv_depth, name="encoder", **kwargs): 146 | super(Encoder, self).__init__(name=name, **kwargs) 147 | self.data_format = "channels_last" 148 | num_filters = 256 149 | self.sublayers = [ 150 | tfc.SignalConv2D( 151 | num_filters, 152 | (9, 9), 153 | name="layer_0", 154 | corr=True, 155 | strides_down=2, 156 | padding="same_zeros", 157 | use_bias=True, 158 | activation=tfc.GDN(name="gdn_0"), 159 | ), 160 | layers.PReLU(shared_axes=[1, 2]), 161 | tfc.SignalConv2D( 162 | num_filters, 163 | (5, 5), 164 | name="layer_1", 165 | corr=True, 166 | strides_down=2, 167 | padding="same_zeros", 168 | use_bias=True, 169 | activation=tfc.GDN(name="gdn_1"), 170 | ), 171 | layers.PReLU(shared_axes=[1, 2]), 172 | tfc.SignalConv2D( 173 | num_filters, 174 | (5, 5), 175 | name="layer_2", 176 | corr=True, 177 | strides_down=1, 178 | padding="same_zeros", 179 | use_bias=True, 180 | activation=tfc.GDN(name="gdn_2"), 181 | ), 182 | layers.PReLU(shared_axes=[1, 2]), 183 | tfc.SignalConv2D( 184 | num_filters, 185 | (5, 5), 186 | name="layer_3", 187 | corr=True, 188 | strides_down=1, 189 | padding="same_zeros", 190 | use_bias=True, 191 | activation=tfc.GDN(name="gdn_3"), 192 | ), 193 | layers.PReLU(shared_axes=[1, 2]), 194 | tfc.SignalConv2D( 195 | conv_depth, 196 | (5, 5), 197 | name="layer_out", 198 | corr=True, 199 | strides_down=1, 200 | padding="same_zeros", 201 | use_bias=True, 202 | activation=None, 203 | ), 204 | ] 205 | 206 | def call(self, x): 207 | for sublayer in self.sublayers: 208 | x = sublayer(x) 209 | return x 210 | 211 | 212 | class Decoder(layers.Layer): 213 | """Build encoder from specified arch""" 214 | 215 | def 
__init__(self, n_channels, name="decoder", **kwargs): 216 | super(Decoder, self).__init__(name=name, **kwargs) 217 | self.data_format = "channels_last" 218 | num_filters = 256 219 | self.sublayers = [ 220 | tfc.SignalConv2D( 221 | num_filters, 222 | (5, 5), 223 | name="layer_out", 224 | corr=False, 225 | strides_up=1, 226 | padding="same_zeros", 227 | use_bias=True, 228 | activation=tfc.GDN(name="igdn_out", inverse=True), 229 | ), 230 | layers.PReLU(shared_axes=[1, 2]), 231 | tfc.SignalConv2D( 232 | num_filters, 233 | (5, 5), 234 | name="layer_0", 235 | corr=False, 236 | strides_up=1, 237 | padding="same_zeros", 238 | use_bias=True, 239 | activation=tfc.GDN(name="igdn_0", inverse=True), 240 | ), 241 | layers.PReLU(shared_axes=[1, 2]), 242 | tfc.SignalConv2D( 243 | num_filters, 244 | (5, 5), 245 | name="layer_1", 246 | corr=False, 247 | strides_up=1, 248 | padding="same_zeros", 249 | use_bias=True, 250 | activation=tfc.GDN(name="igdn_1", inverse=True), 251 | ), 252 | layers.PReLU(shared_axes=[1, 2]), 253 | tfc.SignalConv2D( 254 | num_filters, 255 | (5, 5), 256 | name="layer_2", 257 | corr=False, 258 | strides_up=2, 259 | padding="same_zeros", 260 | use_bias=True, 261 | activation=tfc.GDN(name="igdn_2", inverse=True), 262 | ), 263 | layers.PReLU(shared_axes=[1, 2]), 264 | tfc.SignalConv2D( 265 | n_channels, 266 | (9, 9), 267 | name="layer_3", 268 | corr=False, 269 | strides_up=2, 270 | padding="same_zeros", 271 | use_bias=True, 272 | activation=tf.nn.sigmoid, 273 | ), 274 | ] 275 | 276 | def call(self, x): 277 | for sublayer in self.sublayers: 278 | x = sublayer(x) 279 | return x 280 | 281 | 282 | def real_awgn(x, stddev): 283 | """Implements the real additive white gaussian noise channel. 284 | Args: 285 | x: channel input symbols 286 | stddev: standard deviation of noise 287 | Returns: 288 | y: noisy channel output symbols 289 | """ 290 | # additive white gaussian noise 291 | awgn = tf.random.normal(tf.shape(x), 0, stddev, dtype=tf.float32) 292 | y = x + awgn 293 | 294 | return y 295 | 296 | 297 | def fading(x, stddev, h=None): 298 | """Implements the fading channel with multiplicative fading and 299 | additive white gaussian noise. 300 | Args: 301 | x: channel input symbols 302 | stddev: standard deviation of noise 303 | Returns: 304 | y: noisy channel output symbols 305 | """ 306 | # channel gain 307 | if h is None: 308 | h = tf.complex( 309 | tf.random.normal([tf.shape(x)[0], 1], 0, 1 / np.sqrt(2)), 310 | tf.random.normal([tf.shape(x)[0], 1], 0, 1 / np.sqrt(2)), 311 | ) 312 | 313 | # additive white gaussian noise 314 | awgn = tf.complex( 315 | tf.random.normal(tf.shape(x), 0, 1 / np.sqrt(2)), 316 | tf.random.normal(tf.shape(x), 0, 1 / np.sqrt(2)), 317 | ) 318 | 319 | return (h * x + stddev * awgn), h 320 | 321 | 322 | def phase_invariant_fading(x, stddev, h=None): 323 | """Implements the fading channel with multiplicative fading and 324 | additive white gaussian noise. Also assumes that phase shift 325 | introduced by the fading channel is known at the receiver, making 326 | the model equivalent to a real slow fading channel. 
327 | 328 | Args: 329 | x: channel input symbols 330 | stddev: standard deviation of noise 331 | Returns: 332 | y: noisy channel output symbols 333 | """ 334 | # channel gain 335 | if h is None: 336 | n1 = tf.random.normal([tf.shape(x)[0], 1], 0, 1 / np.sqrt(2), dtype=tf.float32) 337 | n2 = tf.random.normal([tf.shape(x)[0], 1], 0, 1 / np.sqrt(2), dtype=tf.float32) 338 | 339 | h = tf.sqrt(tf.square(n1) + tf.square(n2)) 340 | 341 | # additive white gaussian noise 342 | awgn = tf.random.normal(tf.shape(x), 0, stddev / np.sqrt(2), dtype=tf.float32) 343 | 344 | return (h * x + awgn), h 345 | 346 | 347 | class Channel(layers.Layer): 348 | def __init__(self, channel_type, channel_snr, name="channel", **kwargs): 349 | super(Channel, self).__init__(name=name, **kwargs) 350 | self.channel_type = channel_type 351 | self.channel_snr = channel_snr 352 | 353 | def call(self, inputs): 354 | (encoded_img, prev_h) = inputs 355 | inter_shape = tf.shape(encoded_img) 356 | # reshape array to [-1, dim_z] 357 | z = layers.Flatten()(encoded_img) 358 | # convert from snr to std 359 | print("channel_snr: {}".format(self.channel_snr)) 360 | noise_stddev = np.sqrt(10 ** (-self.channel_snr / 10)) 361 | 362 | # Add channel noise 363 | if self.channel_type == "awgn": 364 | dim_z = tf.shape(z)[1] 365 | # normalize latent vector so that the average power is 1 366 | z_in = tf.sqrt(tf.cast(dim_z, dtype=tf.float32)) * tf.nn.l2_normalize( 367 | z, axis=1 368 | ) 369 | z_out = real_awgn(z_in, noise_stddev) 370 | h = tf.ones_like(z_in) # h just makes sense on fading channels 371 | 372 | elif self.channel_type == "fading": 373 | dim_z = tf.shape(z)[1] // 2 374 | # convert z to complex representation 375 | z_in = tf.complex(z[:, :dim_z], z[:, dim_z:]) 376 | # normalize the latent vector so that the average power is 1 377 | z_norm = tf.reduce_sum( 378 | tf.math.real(z_in * tf.math.conj(z_in)), axis=1, keepdims=True 379 | ) 380 | z_in = z_in * tf.complex( 381 | tf.sqrt(tf.cast(dim_z, dtype=tf.float32) / z_norm), 0.0 382 | ) 383 | z_out, h = fading(z_in, noise_stddev, prev_h) 384 | # convert back to real 385 | z_out = tf.concat([tf.math.real(z_out), tf.math.imag(z_out)], 1) 386 | 387 | elif self.channel_type == "fading-real": 388 | # half of the channels are I component and half Q 389 | dim_z = tf.shape(z)[1] // 2 390 | # normalization 391 | z_in = tf.sqrt(tf.cast(dim_z, dtype=tf.float32)) * tf.nn.l2_normalize( 392 | z, axis=1 393 | ) 394 | z_out, h = phase_invariant_fading(z_in, noise_stddev, prev_h) 395 | 396 | else: 397 | raise Exception("This option shouldn't be an option!") 398 | 399 | # convert signal back to intermediate shape 400 | z_out = tf.reshape(z_out, inter_shape) 401 | # compute average power 402 | avg_power = tf.reduce_mean(tf.math.real(z_in * tf.math.conj(z_in))) 403 | # add avg_power as layer's metric 404 | return z_out, avg_power, h 405 | 406 | 407 | class OutputsCombiner(layers.Layer): 408 | def __init__(self, name="out_combiner", **kwargs): 409 | super(OutputsCombiner, self).__init__(name=name, **kwargs) 410 | self.conv1 = layers.Conv2D(48, 3, 1, padding="same") 411 | self.prelu1 = layers.PReLU(shared_axes=[1, 2]) 412 | self.conv2 = layers.Conv2D(3, 3, 1, padding="same", activation=tf.nn.sigmoid) 413 | 414 | def call(self, inputs): 415 | img_prev, residual = inputs 416 | 417 | reconst = tf.concat([img_prev, residual], axis=-1) 418 | reconst = self.conv1(reconst) 419 | reconst = self.prelu1(reconst) 420 | reconst = self.conv2(reconst) 421 | 422 | return reconst 423 | 424 | 425 | class DeepJSCCF(layers.Layer): 
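    """A single transmission layer of the DeepJSCC-f scheme.

    Wires Encoder -> Channel -> Decoder for one transmission round. The
    base layer (layer_id == 0) encodes the image itself; a refinement
    layer encodes the image concatenated with the previous reconstruction
    obtained through feedback, decodes a residual from the accumulated
    channel outputs, and merges it with the previous reconstruction via
    OutputsCombiner. When feedback_snr is not None, the feedback link is
    modelled as an additional AWGN channel.
    """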
426 |     def __init__(
427 |         self,
428 |         channel_snr,
429 |         conv_depth,
430 |         channel_type,
431 |         feedback_snr,
432 |         refinement_layer,
433 |         layer_id,
434 |         target_analysis=False,
435 |         name="deep_jscc_f",
436 |         **kwargs
437 |     ):
438 |         super(DeepJSCCF, self).__init__(name=name, **kwargs)
439 | 
440 |         n_channels = 3  # change this if working with BW images
441 |         self.refinement_layer = refinement_layer
442 |         self.feedback_snr = feedback_snr
443 |         self.layer = layer_id
444 |         self.encoder = Encoder(conv_depth)
445 |         self.decoder = Decoder(n_channels, name="decoder_output")
446 |         self.channel = Channel(channel_type, channel_snr, name="channel_output")
447 |         if self.refinement_layer:
448 |             self.image_combiner = OutputsCombiner(name="out_comb")
449 |         self.target_analysis = target_analysis
450 | 
451 |     def call(self, inputs):
452 |         if self.refinement_layer:
453 |             (
454 |                 img,
455 |                 prev_img_out_fb,
456 |                 prev_chn_out_fb,
457 |                 prev_img_out_dec,
458 |                 prev_chn_out_dec,
459 |                 prev_chn_gain,
460 |             ) = inputs
461 | 
462 |             img_in = tf.concat([prev_img_out_fb, img], axis=-1)
463 | 
464 |         else:  # base layer
465 |             # inputs is just the original image
466 |             img_in = img = inputs
467 |             prev_chn_gain = None
468 | 
469 |         chn_in = self.encoder(img_in)
470 |         chn_out, avg_power, chn_gain = self.channel((chn_in, prev_chn_gain))
471 | 
472 |         # add feedback noise to chn_output
473 |         if self.feedback_snr is None:  # No feedback noise
474 |             chn_out_fb = chn_out
475 |         else:
476 |             fb_noise_stddev = np.sqrt(10 ** (-self.feedback_snr / 10))
477 |             chn_out_fb = real_awgn(chn_out, fb_noise_stddev)
478 | 
479 |         if self.refinement_layer:
480 |             # combine chn_output with previous stored chn_outs
481 |             chn_out_exp = tf.concat([chn_out, prev_chn_out_dec], axis=-1)
482 |             residual_img = self.decoder(chn_out_exp)
483 |             # combine residual with previous stored image reconstruction
484 |             decoded_img = self.image_combiner((prev_img_out_dec, residual_img))
485 | 
486 |             # feedback estimation
487 |             # Note: the ops below are only computed when this is not the last
488 |             # layer (as they are not included in the loss function when this
489 |             # is the output), so the decoder is trained only with actual
490 |             # chn_outs, and the ops below only run when trainable=False
491 |             chn_out_exp_fb = tf.concat([chn_out_fb, prev_chn_out_fb], axis=-1)
492 |             residual_img_fb = self.decoder(chn_out_exp_fb)
493 |             decoded_img_fb = self.image_combiner([prev_img_out_fb, residual_img_fb])
494 |         else:
495 |             chn_out_exp = chn_out
496 |             decoded_img = self.decoder(chn_out_exp)
497 | 
498 |             chn_out_exp_fb = chn_out_fb
499 |             decoded_img_fb = self.decoder(chn_out_exp_fb)
500 | 
501 |         # keep track of some metrics
502 |         self.add_metric(
503 |             tf.image.psnr(img, decoded_img, max_val=1.0),
504 |             aggregation="mean",
505 |             name="psnr{}".format(self.layer),
506 |         )
507 |         self.add_metric(
508 |             tf.image.psnr(img, decoded_img_fb, max_val=1.0),
509 |             aggregation="mean",
510 |             name="psnr_fb{}".format(self.layer),
511 |         )
512 |         self.add_metric(
513 |             tf.reduce_mean(tf.math.square(img - decoded_img)),
514 |             aggregation="mean",
515 |             name="mse{}".format(self.layer),
516 |         )
517 |         self.add_metric(
518 |             avg_power, aggregation="mean", name="avg_pwr{}".format(self.layer)
519 |         )
520 | 
521 |         return (decoded_img, decoded_img_fb, chn_out_exp, chn_out_exp_fb, chn_gain)
522 | 
523 |     def change_channel_snr(self, channel_snr):
524 |         self.channel.channel_snr = channel_snr
525 | 
526 |     def change_feedback_snr(self, feedback_snr):
527 |         self.feedback_snr = feedback_snr
528 | 
529 | 
530 | def main(args):
531 |     # get
dataset 532 | x_train, x_val, x_tst = get_dataset(args) 533 | 534 | if args.delete_previous_model and tf.io.gfile.exists(args.model_dir): 535 | print("Deleting previous model files at {}".format(args.model_dir)) 536 | tf.io.gfile.rmtree(args.model_dir) 537 | tf.io.gfile.makedirs(args.model_dir) 538 | else: 539 | print("Starting new model at {}".format(args.model_dir)) 540 | tf.io.gfile.makedirs(args.model_dir) 541 | 542 | # load model 543 | prev_layer_out = None 544 | # add input placeholder to please keras 545 | img = tf.keras.Input(shape=(None, None, 3)) 546 | 547 | if not args.run_eval_once: 548 | feedback_snr = None if not args.feedback_noise else args.feedback_snr_train 549 | channel_snr = args.channel_snr_train 550 | else: 551 | feedback_snr = None if not args.feedback_noise else args.feedback_snr_eval 552 | channel_snr = args.channel_snr_eval 553 | 554 | all_models = [] 555 | for layer in range(args.n_layers): 556 | ckpt_file = os.path.join(args.model_dir, "ckpt_layer{}".format(layer)) 557 | layer_name = "layer{}".format(layer) 558 | ae_layer = DeepJSCCF( 559 | channel_snr, 560 | int(args.conv_depth), 561 | args.channel, 562 | feedback_snr, 563 | layer > 0, # refinement or base? 564 | layer, 565 | args.target_analysis, 566 | name=layer_name, 567 | ) 568 | 569 | # connect ae_layer to previous model, (if any) 570 | if layer == 0: # base layer 571 | # model returns img and channel outputs 572 | layer_output = ae_layer(img) 573 | else: 574 | # add prev layer outputs as input for cur layer 575 | ( 576 | prev_img_out_dec, 577 | prev_img_out_fb, 578 | prev_chn_out_dec, 579 | prev_chn_out_fb, 580 | prev_chn_gain, 581 | ) = prev_layer_out 582 | layer_output = ae_layer( 583 | ( 584 | img, 585 | prev_img_out_fb, 586 | prev_chn_out_fb, 587 | prev_img_out_dec, 588 | prev_chn_out_dec, 589 | prev_chn_gain, 590 | ) 591 | ) 592 | 593 | ( 594 | decoded_img, 595 | _decoded_img_fb, 596 | _chn_out_exp, 597 | _chn_out_exp_fb, 598 | _chn_gain, 599 | ) = layer_output 600 | model = tf.keras.Model(inputs=img, outputs=decoded_img) 601 | 602 | model_metrics = [ 603 | tf.keras.metrics.MeanSquaredError(), 604 | psnr_metric, 605 | PSNRsVar(name="psnr_var{}".format(layer)), 606 | ] 607 | if args.target_analysis: 608 | model_metrics.append(TargetPSNRsHistogram(name="target{}".format(layer))) 609 | model.compile( 610 | optimizer=tf.keras.optimizers.Adam(learning_rate=args.learn_rate), 611 | loss="mse", 612 | metrics=model_metrics, 613 | ) 614 | 615 | # check if checkpoint already exists and load it 616 | if (layer == 0 and args.pretrained_base_layer) or glob.glob(ckpt_file + "*"): 617 | # trick to restore metrics too (see tensorflow guide on saving and 618 | # serializing subclassed models) 619 | model.train_on_batch(x_train) 620 | if layer == 0 and args.pretrained_base_layer: 621 | print("Using pre-trained base layer!") 622 | model.load_weights( 623 | os.path.join( 624 | args.pretrained_base_layer, "ckpt_layer{}".format(layer) 625 | ) 626 | ) 627 | else: 628 | print("Restoring weights from checkpoint!") 629 | model.load_weights(ckpt_file) 630 | 631 | print(model.summary()) 632 | 633 | # skip training if just running eval or if loading first layer from 634 | # pretrained ckpt 635 | if not (args.run_eval_once or (layer == 0 and args.pretrained_base_layer)): 636 | train_patience = 3 if args.dataset_train != "imagenet" else 2 637 | callbacks = [ 638 | tf.keras.callbacks.EarlyStopping( 639 | patience=train_patience, 640 | monitor="val_psnr_metric", 641 | min_delta=10e-3, 642 | verbose=1, 643 | mode="max", 644 | 
restore_best_weights=True,
645 |             ),
646 |             tf.keras.callbacks.TensorBoard(log_dir=args.eval_dir),
647 |             # just save a single checkpoint with best. If more is wanted,
648 |             # create a new callback
649 |             tf.keras.callbacks.ModelCheckpoint(
650 |                 filepath=ckpt_file,
651 |                 monitor="val_psnr_metric",
652 |                 mode="max",
653 |                 save_best_only=True,
654 |                 verbose=1,
655 |                 save_weights_only=True,
656 |             ),
657 |             tf.keras.callbacks.TerminateOnNaN(),
658 |         ]
659 | 
660 |         if args.dataset_train == "imagenet":
661 |             callbacks.append(NBatchLogger(100))
662 | 
663 |         model.fit(
664 |             x_train,
665 |             epochs=args.train_epochs,
666 |             validation_data=x_val,
667 |             callbacks=callbacks,
668 |             verbose=2,
669 |             validation_freq=args.epochs_between_evals,
670 |             validation_steps=(
671 |                 DATASETS[args.dataset_train]._NUM_IMAGES["validation"]
672 |                 // args.batch_size_train
673 |             ),
674 |         )
675 | 
676 |     # freeze weights of already trained layers
677 |     model.trainable = False
678 |     # define model as prev_model
679 |     prev_layer_out = layer_output
680 |     all_models.append(model)
681 | 
682 |     print("EVALUATION!!!")
683 |     # normally we just eval the complete model, unless we are doing target_analysis
684 |     models = [model] if not args.target_analysis else all_models
685 |     for eval_model in models:
686 |         out_eval = eval_model.evaluate(x_tst, verbose=2)
687 |         for m, v in zip(eval_model.metrics_names, out_eval):
688 |             met_name = "_".join(["eval", m])
689 |             print("{}={}".format(met_name, v), end=" ")
690 |         print()
691 |     print()
692 | 
693 | 
694 | def get_dataset(args):
695 |     data_options = tf.data.Options()
696 |     data_options.experimental_deterministic = False
697 |     data_options.experimental_optimization.apply_default_optimizations = True
698 |     data_options.experimental_optimization.map_parallelization = True
699 |     data_options.experimental_optimization.parallel_batch = True
700 |     data_options.experimental_optimization.autotune_buffers = True
701 | 
702 |     def prepare_dataset(dataset, mode, parse_record_fn, bs):
703 |         dataset = dataset.with_options(data_options)
704 |         if mode == "train":
705 |             dataset = dataset.shuffle(buffer_size=dataset_obj.SHUFFLE_BUFFER)
706 |         dataset = dataset.map(
707 |             lambda v: parse_record_fn(v, mode, tf.float32),
708 |             num_parallel_calls=tf.data.experimental.AUTOTUNE,
709 |         )
710 |         return dataset.batch(bs)
711 | 
712 |     dataset_obj = DATASETS[args.dataset_train]
713 |     parse_record_fn = dataset_obj.parse_record
714 |     if args.dataset_train != "imagenet":
715 |         tr_val_dataset = dataset_obj.get_dataset(True, args.data_dir_train)
716 |         tr_dataset = tr_val_dataset.take(dataset_obj._NUM_IMAGES["train"])
717 |         val_dataset = tr_val_dataset.skip(dataset_obj._NUM_IMAGES["train"])
718 |     else:  # treat imagenet differently, as we usually don't use it for training
719 |         tr_dataset = dataset_obj.get_dataset(True, args.data_dir_train)
720 |         val_dataset = dataset_obj.get_dataset(False, args.data_dir_train)
721 |     # Train
722 |     x_train = prepare_dataset(
723 |         tr_dataset, "train", parse_record_fn, args.batch_size_train
724 |     )
725 |     # Validation
726 |     x_val = prepare_dataset(val_dataset, "val", parse_record_fn, args.batch_size_train)
727 | 
728 |     # Test
729 |     dataset_obj = DATASETS[args.dataset_eval]
730 |     parse_record_fn = dataset_obj.parse_record
731 |     tst_dataset = dataset_obj.get_dataset(False, args.data_dir_eval)
732 |     x_tst = prepare_dataset(tst_dataset, "test", parse_record_fn, args.batch_size_eval)
733 |     x_tst = x_tst.repeat(10)  # number of realisations per image on evaluation
734 | 
735 |     return x_train, x_val, x_tst
736 | 
737 | 
738 | if __name__ == "__main__":
739 |     # parse args
740 |     p = configargparse.ArgParser()
741 |     p.add(
742 |         "-c",
743 |         "--my-config",
744 |         required=False,
745 |         is_config_file=True,
746 |         help="config file path",
747 |     )
748 |     p.add(
749 |         "--conv_depth",
750 |         type=float,
751 |         default=16,
752 |         help=(
753 |             "Number of channels of last conv layer, used to define the "
754 |             "compression rate: k/n=c_out/(16*3)"
755 |         ),
756 |         required=True,
757 |     )
758 |     p.add(
759 |         "--n_layers",
760 |         type=int,
761 |         default=3,
762 |         help=("Number of layers/rounds used in the transmission"),
763 |         required=True,
764 |     )
765 |     p.add(
766 |         "--channel",
767 |         type=str,
768 |         default="awgn",
769 |         choices=["awgn", "fading", "fading-real"],
770 |         help="Model of channel used (awgn, fading, fading-real)",
771 |     )
772 |     p.add(
773 |         "--model_dir",
774 |         type=str,
775 |         default="/tmp/train_logs",
776 |         help=("The location of the model checkpoint files."),
777 |     )
778 |     p.add(
779 |         "--eval_dir",
780 |         type=str,
781 |         default="/tmp/train_logs/eval",
782 |         help=("The location of eval files (tensorboard, etc)."),
783 |     )
784 |     p.add(
785 |         "--delete_previous_model",
786 |         action="store_true",
787 |         default=False,
788 |         help=("If model_dir has checkpoints, delete it before " "starting a new run"),
789 |     )
790 |     p.add(
791 |         "--channel_snr_train",
792 |         type=float,
793 |         default=1,
794 |         help="target SNR of channel during training (dB)",
795 |     )
796 |     p.add(
797 |         "--channel_snr_eval",
798 |         type=float,
799 |         default=1,
800 |         help="target SNR of channel during evaluation (dB)",
801 |     )
802 |     p.add(
803 |         "--feedback_noise",
804 |         action="store_true",
805 |         default=False,
806 |         help=("Apply (AWGN) noise to feedback channel"),
807 |     )
808 |     p.add(
809 |         "--feedback_snr_train",
810 |         type=float,
811 |         default=20,
812 |         help=(
813 |             "SNR (dB) of the feedback channel "
814 |             "(only applies when feedback_noise=True)"
815 |         ),
816 |     )
817 |     p.add(
818 |         "--feedback_snr_eval",
819 |         type=float,
820 |         default=20,
821 |         help=(
822 |             "SNR (dB) of the feedback channel (only applies when feedback_noise=True)"
823 |         ),
824 |     )
825 |     p.add(
826 |         "--learn_rate",
827 |         type=float,
828 |         default=0.0001,
829 |         help="Learning rate for Adam optimizer",
830 |     )
831 |     p.add(
832 |         "--run_eval_once",
833 |         action="store_true",
834 |         default=False,
835 |         help="Skip train, run only eval and exit",
836 |     )
837 |     p.add(
838 |         "--train_epochs",
839 |         type=int,
840 |         default=10000,
841 |         help=(
842 |             "The number of epochs used to train (each epoch goes over the whole dataset)"
843 |         ),
844 |     )
845 |     p.add("--batch_size_train", type=int, default=128, help="Batch size for training")
846 |     p.add("--batch_size_eval", type=int, default=128, help="Batch size for evaluation")
847 |     p.add(
848 |         "--epochs_between_evals",
849 |         type=int,
850 |         default=30,
851 |         help=("The number of training epochs to run between evaluations."),
852 |     )
853 |     p.add(
854 |         "--dataset_train",
855 |         type=str,
856 |         default="cifar",
857 |         choices=DATASETS.keys(),
858 |         help=("Choose image dataset. Options: {}".format(DATASETS.keys())),
859 |     )
860 |     p.add(
861 |         "--dataset_eval",
862 |         type=str,
863 |         default="cifar",
864 |         choices=DATASETS.keys(),
865 |         help=("Choose image dataset. Options: {}".format(DATASETS.keys())),
866 |     )
867 |     p.add(
868 |         "--data_dir_train",
869 |         type=str,
870 |         default="/tmp/train_data",
871 |         help="Directory where to store the training data set",
872 |     )
873 |     p.add(
874 |         "--data_dir_eval",
875 |         type=str,
876 |         default="/tmp/train_data",
877 |         help="Directory where to store the eval data set",
878 |     )
879 |     p.add(
880 |         "--pretrained_base_layer",
881 |         type=str,
882 |         help="Use existing checkpoints for base layer",
883 |     )
884 |     p.add(
885 |         "--target_analysis",
886 |         action="store_true",
887 |         default=False,
888 |         help="perform PSNR target analysis",
889 |     )
890 | 
891 |     args = p.parse_args()
892 | 
893 |     print("#######################################")
894 |     print("Current execution parameters:")
895 |     for arg, value in sorted(vars(args).items()):
896 |         print("{}: {}".format(arg, value))
897 |     print("#######################################")
898 |     main(args)
899 | 
--------------------------------------------------------------------------------
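A closing note on the SNR flags defined above: `channel_snr_train`/`channel_snr_eval` are in dB, and `Channel.call` converts them to a noise standard deviation as `noise_stddev = sqrt(10 ** (-snr / 10))` after normalizing the latent vector to unit average power. A small NumPy check of that relationship (illustrative only, not part of the repo):

```python
import numpy as np

def snr_db_to_noise_std(snr_db):
    # with unit-power symbols, snr = signal_power / noise_power = 1 / stddev**2
    return np.sqrt(10 ** (-snr_db / 10))

rng = np.random.default_rng(0)
dim_z = 4096
z = rng.normal(size=dim_z)
z_in = np.sqrt(dim_z) * z / np.linalg.norm(z)  # average power == 1, as in Channel.call
noise = rng.normal(0.0, snr_db_to_noise_std(10.0), size=dim_z)
measured_snr_db = 10 * np.log10(np.mean(z_in ** 2) / np.mean(noise ** 2))
print(measured_snr_db)  # close to 10 dB
```

A typical invocation, using only flags defined above, would look like `python jscc.py --conv_depth 16 --n_layers 3 --channel awgn --channel_snr_train 10 --channel_snr_eval 10 --dataset_train cifar --dataset_eval cifar`.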