├── .gitignore
├── README.md
├── config
│   ├── csinet_indoor_cost2100_pow.json
│   ├── csinet_indoor_cost2100_xsede.json
│   ├── csinet_lstm_outdoor_cost2100.json
│   ├── csinet_lstm_outdoor_cost2100_xsede.json
│   ├── csinet_outdoor_cost2100.json
│   └── csinet_outdoor_cost2100_pow.json
├── csinet-lstm
│   ├── csinet.py
│   ├── csinet_lstm.py
│   ├── csinet_lstm_eval_cosine.py
│   ├── csinet_lstm_quant.py
│   ├── csinet_lstm_train.py
│   ├── csinet_lstm_train.sh
│   ├── csinet_quant.py
│   ├── csinet_resid.py
│   ├── csinet_train.py
│   └── csinet_train.sh
└── torch
    └── csinet_torch.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *log
2 | data/
3 | *__pycache__*
4 | *slurm*
5 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CsiNet-LSTM
2 | 
3 | Repository for reproducing CsiNet-LSTM [2], as used in the MarkovNet paper [1] (currently a preprint).
4 | 
5 | - [1] Z. Liu, M. del Rosario, and Z. Ding, "A Markovian Model-Driven Deep Learning Framework for Massive MIMO CSI Feedback," arXiv preprint arXiv:2009.09468, 2020.
6 | - [2] T. Wang, C. Wen, S. Jin, and G. Y. Li, "Deep Learning-Based CSI Feedback Approach for Time-Varying Massive MIMO Channels," IEEE Wireless Communications Letters, vol. 8, no. 2, pp. 416–419, April 2019.
7 | 
8 | ## Data
9 | 
10 | **TODO**: Add a link to the 20-timeslot COST2100 data.
11 | 
12 | ## Dependencies
13 | 
14 | **TODO**: Add a) an exhaustive list of dependencies for this repo and/or b) a .yml/.dockerfile for setting up a working environment.
15 | 
16 | This repository relies heavily on the [`brat` repository](https://github.com/mdelrosa/brat). In particular, the `brat/utils` directory handles parsing .json config files, loading data, and evaluating network performance.
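17 | 
18 | For reference, each training/evaluation script in this repo pulls its hyperparameters from one of the .json files under `config/` via `brat`'s `get_keys_from_json` helper. A minimal sketch of that pattern (the config path and `sys.path` entry are illustrative; point them at your own clones):
19 | 
20 | ```python
21 | import sys
22 | sys.path.append("/home/your_user/git/brat")  # hypothetical clone location; see hierarchy below
23 | from utils.unpack_json import get_keys_from_json
24 | 
25 | json_config = "config/csinet_lstm_outdoor_cost2100.json"
26 | # string/numeric keys come back in the order they are requested
27 | model_dir, network_name, T = get_keys_from_json(json_config, keys=["model_dir", "network_name", "T"])
28 | # boolean-valued keys are requested separately with is_bool=True
29 | aux_bool, quant_bool = get_keys_from_json(json_config, keys=["aux_bool", "quant_bool"], is_bool=True)
30 | ```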
31 | 
32 | The typical directory hierarchy which this repository assumes is as follows:
33 | 
34 | - `home` (your home directory)
35 |   - `git`
36 |     - `brat` (repo available [here](https://github.com/mdelrosa/brat))
37 |     - `csinet-lstm` (this repo)
38 |       - `csinet-lstm`
39 |         - `csinet_train.py` (training script for csinet at a single timeslot)
40 |         - `csinet_lstm_train.py` (training script for csinet_lstm; typically ten timeslots)
41 | 
--------------------------------------------------------------------------------
/config/csinet_indoor_cost2100_pow.json:
--------------------------------------------------------------------------------
1 | {
2 |     "model_dir": "/pylon5/ec5phkp/mdelrosa/model/csinet-lstm",
3 |     "aux_bool": 1,
4 |     "df":"channels_first",
5 |     "dataset_spec": ["/pylon5/ec5phkp/mdelrosa/data/indoor_pow/H_user_t", "32all.mat", "Hur_down_t1", 0.75],
6 |     "diff_spec": ["/pylon5/ec5phkp/mdelrosa/data/indoor_pow/P_diff_T"],
7 |     "batch_num": 20,
8 |     "learning_rate": 0.001,
9 |     "batch_size": 200,
10 |     "network_name": "csinet",
11 |     "minmax_file": "/pylon5/ec5phkp/mdelrosa/data/indoor_pow/power/Data100_minmax.csv",
12 |     "norm_range": "norm_H4",
13 |     "T":1,
14 |     "subsample_prop": 1,
15 |     "n_delay": 32,
16 |     "img_channels": 2,
17 |     "thresh_idx_path": "/pylon5/ec5phkp/mdelrosa/data/indoor_pow/power/H_pow_idx.mat"
18 | }
19 | 
--------------------------------------------------------------------------------
/config/csinet_indoor_cost2100_xsede.json:
--------------------------------------------------------------------------------
1 | {
2 |     "model_dir": "/pylon5/ec5phkp/mdelrosa/model/csinet-lstm",
3 |     "aux_bool": 1,
4 |     "M_1": 512,
5 |     "df":"channels_first",
6 |     "epochs": 600,
7 |     "gpu_num": 0,
8 |     "dataset_spec": ["/pylon5/ec5phkp/mdelrosa/data/indoor_pow/Data100_Htrainin_down_FDD_H4_32ant_T","/pylon5/ec5phkp/mdelrosa/data/indoor_pow/Data100_Hval_down_FDD_H4_32ant_T","/pylon5/ec5phkp/mdelrosa/data/indoor_pow/Data100_Htest_down_FDD_H4_32ant_T"],
9 |     "diff_spec": ["/pylon5/ec5phkp/mdelrosa/data/indoor_pow/P_diff_Hval_down_T","/pylon5/ec5phkp/mdelrosa/data/indoor_pow/P_diff_Htest_down_T"],
10 |     "batch_num": 20,
11 |     "lrs": [0.001],
12 |     "batch_sizes": [200],
13 |     "network_name": "csinet",
14 |     "minmax_file": "/pylon5/ec5phkp/mdelrosa/data/indoor_pow/Data100_minmax.csv",
15 |     "norm_range": "norm_H4",
16 |     "T":1
17 | }
18 | 
--------------------------------------------------------------------------------
/config/csinet_lstm_outdoor_cost2100.json:
--------------------------------------------------------------------------------
1 | {
2 |     "model_dir": "/home/mdelrosa/models/csinet-lstm",
3 |     "aux_bool": 1,
4 |     "share_bool": 0,
5 |     "M_1": 512,
6 |     "df":"channels_first",
7 |     "epochs": 500,
8 |     "t1_train": true,
9 |     "t2_train": true,
10 |     "gpu_num": 1,
11 |     "lstm_latent_bool": 0,
12 |     "conv_lstm_bool": 0,
13 |     "dataset_spec": ["/home/zyuliu/data_sharing/outdoor_20slots/H_user_t", "32all.mat", "Hur_down_t1", 0.75],
14 |     "diff_spec": ["../data/outdoor/P_diff_T"],
15 |     "batch_num": 20,
16 |     "lr": 0.001,
17 |     "batch_size": 100,
18 |     "network_name": "csinet_lstm",
19 |     "subnetwork_name": "csinet",
20 |     "minmax_file": "../data/outdoor/Data100_minmax.csv",
21 |     "norm_range": "norm_H4",
22 |     "envir": "outdoor",
23 |     "T": 10,
24 |     "subsample_prop": 1,
25 |     "pass_through_bool": 0,
26 |     "load_bool": 1,
27 |     "pretrained_bool": 1,
28 |     "LSTM_only_bool": false,
29 |     "quant_bool": false
30 | }
31 | 
--------------------------------------------------------------------------------
/config/csinet_lstm_outdoor_cost2100_xsede.json:
-------------------------------------------------------------------------------- 1 | { 2 | "model_dir": "/ocean/projects/ecs190004p/mdelrosa/model/csinet-lstm", 3 | "aux_bool": 1, 4 | "share_bool": 0, 5 | "M_1": 512, 6 | "df":"channels_first", 7 | "epochs": 500, 8 | "t1_train": true, 9 | "t2_train": true, 10 | "gpu_num": 1, 11 | "lstm_latent_bool": 0, 12 | "conv_lstm_bool": 0, 13 | "dataset_spec": ["/ocean/projects/ecs190004p/mdelrosa/data/outdoor_pow/H_user_t", "32all.mat", "Hur_down_t1", 0.75], 14 | "diff_spec": ["/ocean/projects/ecs190004p/mdelrosa/data/outdoor_pow/data/outdoor/P_diff_T"], 15 | "batch_num": 20, 16 | "lr": 0.001, 17 | "batch_size": 100, 18 | "network_name": "csinet_lstm", 19 | "subnetwork_name": "csinet", 20 | "minmax_file": "/ocean/projects/ecs190004p/mdelrosa/data/outdoor_pow//data/outdoor/Data100_minmax.csv", 21 | "norm_range": "norm_H4", 22 | "envir": "outdoor", 23 | "T": 10, 24 | "subsample_prop": 1, 25 | "pass_through_bool": 0, 26 | "load_bool": 1, 27 | "pretrained_bool": 1, 28 | "LSTM_only_bool": false, 29 | "quant_bool": false, 30 | "thresh_idx_path": false 31 | } 32 | -------------------------------------------------------------------------------- /config/csinet_outdoor_cost2100.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_dir": "/home/mdelrosa/models/csinet-lstm", 3 | "aux_bool": 1, 4 | "M_1": 512, 5 | "df":"channels_first", 6 | "epochs": 600, 7 | "gpu_num": 0, 8 | "dataset_spec": ["/home/mdelrosa/csi/LSTM/data/outdoor300/replication/Data100_Htrainin_down_FDD_H4_32ant","/home/mdelrosa/csi/LSTM/data/outdoor300/replication/Data100_Hvalin_down_FDD_H4_32ant","/home/mdelrosa/csi/LSTM/data/outdoor300/replication/Data100_Htestin_down_FDD_H4_32ant"], 9 | "batch_num": 20, 10 | "lrs": [0.001], 11 | "batch_sizes": [100], 12 | "network_name": "csinet", 13 | "minmax_file": "/home/mdelrosa/csi/LSTM/data/outdoor300/replication/mat_outdoor300_bw20MHz_slow_minmax.csv", 14 | "norm_range": "norm_H4", 15 | "T":1 16 | } 17 | -------------------------------------------------------------------------------- /config/csinet_outdoor_cost2100_pow.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_dir": "/home/mdelrosa/models/csinet-lstm", 3 | "aux_bool": 1, 4 | "df":"channels_first", 5 | "dataset_spec": ["/home/zyuliu/data_sharing/outdoor_20slots/H_user_t", "32all.mat", "Hur_down_t1", 0.75], 6 | "diff_spec": ["../data/outdoor/P_diff_T"], 7 | "batch_num": 20, 8 | "learning_rate": 0.001, 9 | "batch_size": 200, 10 | "network_name": "csinet", 11 | "minmax_file": "../data/outdoor/Data100_minmax.csv", 12 | "norm_range": "norm_H4", 13 | "T":1, 14 | "subsample_prop": 1, 15 | "n_delay": 32, 16 | "img_channels": 2 17 | } -------------------------------------------------------------------------------- /csinet-lstm/csinet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import concatenate, Dense, BatchNormalization, Reshape, add, LeakyReLU 3 | from tensorflow.keras import Input 4 | from tensorflow.keras.models import Model 5 | import numpy as np 6 | 7 | Conv2D = tf.keras.layers.Conv2D 8 | 9 | class CosineSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): 10 | def __init__(self, warmup_steps=200, epochs=600, max_lr=1e-3, min_lr=1e-4): 11 | super(CosineSchedule, self).__init__() 12 | 13 | self.warmup_steps = tf.cast(warmup_steps, tf.float32) 14 | self.epochs = tf.cast(epochs, 
tf.float32)
15 |         self.max_lr = tf.cast(max_lr, tf.float32)
16 |         self.min_lr = tf.cast(min_lr, tf.float32)
17 |         self.diff_lr = max_lr - min_lr
18 | 
19 |     def warmup_rate(self):
20 |         return self.diff_lr * self.step / self.warmup_steps + self.min_lr
21 | 
22 |     def cosine_rate(self):
23 |         return self.diff_lr * ((tf.math.cos((self.step - self.warmup_steps) * np.pi / (self.epochs - self.warmup_steps)) + 1) / 2) + self.min_lr # cosine anneal from max_lr (end of warmup) down to min_lr at step == epochs
24 | 
25 |     def get_config(self):
26 |         config = {
27 |             # 'epochs': self.epochs,
28 |             # 'warmup_steps': self.warmup_steps,
29 |             # 'max_lr': self.max_lr,
30 |             # 'min_lr': self.min_lr,
31 |             # 'diff_lr': self.diff_lr,
32 |             # 'rate': self.rate
33 |         }
34 |         return config
35 | 
36 | 
37 |     def __call__(self, step):
38 |         self.step = step
39 |         rate = tf.cond(step < self.warmup_steps, self.warmup_rate, self.cosine_rate)
40 |         self.rate = rate
41 |         return rate
42 | 
43 | def CsiNet(img_channels, img_height, img_width, encoded_dim, encoder_in=None, residual_num=2, aux=None, encoded_in=None, data_format="channels_last",name=None,out_activation='tanh'):
44 | 
45 |     # Build the autoencoder model of CsiNet
46 |     def residual_network(x, residual_num, encoded_dim, aux):
47 |         img_total = img_channels*img_height*img_width
48 | 
49 |         def add_common_layers(y):
50 |             y = BatchNormalization(axis=1)(y)
51 |             y = LeakyReLU()(y)
52 |             return y
53 | 
54 |         def residual_block_decoded(y):
55 |             y = Conv2D(128, kernel_size=(1, 1), padding='same',data_format='channels_first',name="deconv1")(y)
56 |             y = add_common_layers(y)
57 |             y = Conv2D(64, kernel_size=(1, 1), padding='same',data_format='channels_first',name="deconv2")(y)
58 |             y = add_common_layers(y)
59 |             y = Conv2D(32, kernel_size=(3, 3), padding='same',data_format='channels_first',name="deconv3")(y)
60 |             y = add_common_layers(y)
61 |             y = Conv2D(32, kernel_size=(3, 3), padding='same',data_format='channels_first',name="deconv4")(y)
62 |             y = add_common_layers(y)
63 |             y = Conv2D(16, kernel_size=(3, 3), padding='same',data_format='channels_first',name="deconv5")(y)
64 |             y = add_common_layers(y)
65 |             y = Conv2D(16, kernel_size=(3, 3), padding='same',data_format='channels_first',name="deconv6")(y)
66 |             y = add_common_layers(y)
67 |             y = Conv2D(2, (3, 3), activation=out_activation, padding='same',data_format='channels_first',name="predict")(y)
68 |             return y
69 | 
70 |         # if encoder_in:
71 |         x = Conv2D(8, (3, 3), padding='same', data_format=data_format, name='CR2_conv2d_1')(x)
72 |         x = add_common_layers(x)
73 |         x = Conv2D(16, (3, 3), padding='same', data_format=data_format, name='CR2_conv2d_2')(x)
74 |         x = add_common_layers(x)
75 |         x = Conv2D(2, (3, 3), padding='same', data_format=data_format, name='CR2_conv2d_3')(x)
76 |         x = add_common_layers(x)
77 | 
78 |         x = Reshape((img_total,), name='CR2_reshape')(x)
79 |         encoded = Dense(encoded_dim, activation='linear', name='CR2_dense')(x)
80 |         # else:
81 |         #     x = Conv2D(2, (3, 3), padding='same', data_format=data_format)(x)
82 |         #     x = add_common_layers(x)
83 | 
84 |         #     x = Reshape((img_total,))(x)
85 |         #     encoded = Dense(encoded_dim, activation='linear')(x)
86 |         print("Aux check: {}".format(aux))
87 |         tens_type = type(x)
88 |         if type(aux) == tens_type:
89 |             x = Dense(img_total, activation='linear')(concatenate([aux,encoded]))
90 |         else:
91 |             x = Dense(img_total, activation='linear')(encoded)
92 |         # reshape based on data_format
93 |         if(data_format == "channels_first"):
94 |             x = Reshape((img_channels, img_height, img_width,))(x)
95 |         elif(data_format == "channels_last"):
96 |             x = Reshape((img_height, img_width, img_channels,))(x)
97 | 
98 |         x = residual_block_decoded(x)
99 | 
100 |         return [x, encoded]
101 | 
102 |     if(data_format == "channels_last"):
103 |         image_tensor = Input((img_height, img_width, img_channels))
104 |     elif(data_format == "channels_first"):
105 |         image_tensor = Input((img_channels, img_height, img_width))
106 |     else:
107 |         print("Unexpected tensor_shape param in CsiNet input.")
108 |         # raise Exception
109 |         # image_tensor = Input((img_channels, img_height, img_width))
110 |     [network_output, encoded] = residual_network(image_tensor, residual_num, encoded_dim, aux)
111 |     print('network_output: {} - encoded: {} - aux: {}'.format(network_output, encoded, aux))
112 |     tens_type = type(image_tensor)
113 |     print('image_tensor.dtype: {}'.format(tens_type))
114 |     print('type(aux): {}'.format(type(aux)))
115 |     if type(aux) == tens_type:
116 |         autoencoder = Model(inputs=[aux,image_tensor], outputs=[network_output,encoded])
117 |     else:
118 |         autoencoder = Model(inputs=[image_tensor], outputs=[network_output, encoded])
119 |     if encoder_in:
120 |         autoencoder.load_weights(encoder_in, by_name=True) # assumes encoder_in is a path to pretrained weights; load_weights requires a filepath argument
121 |     return [autoencoder, encoded]
122 | 
--------------------------------------------------------------------------------
/csinet-lstm/csinet_lstm.py:
--------------------------------------------------------------------------------
1 | # CsiNet_LSTM.py
2 | 
3 | import tensorflow as tf
4 | # try:
5 | from tensorflow.keras.layers import concatenate, Lambda, Dense, BatchNormalization, Reshape, Conv2D, add, LeakyReLU, LSTM, CuDNNLSTM, ConvLSTM2D # CuDNNLSTM is only available in the TF1.x-era tf.keras API
6 | from tensorflow.keras import Input
7 | from tensorflow.keras.models import Model, model_from_json
8 | from tensorflow.keras.callbacks import TensorBoard, Callback
9 | from tensorflow.keras.utils import plot_model
10 | from tensorflow.keras import initializers
11 | from tensorflow.keras.optimizers import Adam
12 | import scipy.io as sio
13 | import numpy as np
14 | import math
15 | import time
16 | from csinet import *
17 | 
18 | # image params
19 | img_height = 32
20 | img_width = 32
21 | img_channels = 2
22 | img_total = img_height*img_width*img_channels
23 | # network params
24 | residual_num = 2
25 | # encoded_dim = 512 #compress rate=1/4->dim.=512, compress rate=1/16->dim.=128, compress rate=1/32->dim.=64, compress rate=1/64->dim.=32
26 | 
27 | def get_file(envir,encoded_dim,train_date):
28 |     file = 'CsiNet_'+(envir)+'_dim'+str(encoded_dim)+'_'+train_date
29 |     return "result/model_%s.h5"%file
30 | 
31 | def make_CsiNet(aux_bool,M_1, img_channels, img_height, img_width, encoded_dim, data_format, lo_bool=False):
32 |     if aux_bool:
33 |         aux = Input((M_1,))
34 |     else:
35 |         aux = None
36 |     # build CsiNet
37 |     out_activation = 'tanh'
38 |     autoencoder, encoded = CsiNet(img_channels, img_height, img_width, encoded_dim, aux=aux, data_format=data_format, out_activation=out_activation) # CsiNet with encoded_dim-dimensional latent space
39 |     # autoencoder = Model(inputs=autoencoder.inputs,outputs=autoencoder.outputs[0])
40 |     prediction = autoencoder.outputs[0]
41 |     encoded = autoencoder.outputs[1]
42 |     if (lo_bool):
43 |         model = Model(inputs=autoencoder.inputs,outputs=prediction)
44 |     else:
45 |         model = Model(inputs=autoencoder.inputs,outputs=[encoded,prediction])
46 |     # return [autoencoder, encoded]
47 |     # optimizer = Adam()
48 |     # model.compile(optimizer=optimizer, loss='mse')
49 |     return model
50 | 
51 | def CsiNet_LSTM(img_channels, img_height, img_width, T, M_1, M_2, envir="indoor", LSTM_depth=3,data_format='channels_first',t1_trainable=False,t2_trainable=True,pre_t1_bool=True,pre_t2_bool=True,aux_bool=True,
share_bool=True, pass_through_bool=False, lstm_latent_bool=False, pre_lstm_bool=True, conv_lstm_bool=False, subnetwork_spec=".", pretrained_bool=True, LSTM_only_bool=False): 52 | # base CSINet models 53 | aux=Input((M_1,)) 54 | if(data_format == "channels_last"): 55 | x = Input((T, img_height, img_width, img_channels)) 56 | elif(data_format == "channels_first"): 57 | x = Input((T, img_channels, img_height, img_width)) 58 | else: 59 | print("Unexpected data_format param in CsiNet input.") # raise an exception eventually. For now, print a complaint 60 | if (not LSTM_only_bool): 61 | CsiNet_hi = make_CsiNet(aux_bool, M_1, img_channels, img_height, img_width, M_1, data_format) 62 | if (pretrained_bool): 63 | # if envir == "indoor": 64 | # config_hi = 'config/indoor0001/v2/angular/csinet_cr512.json' 65 | # elif envir == "outdoor": 66 | # config_hi = 'config/outdoor300/v2/csinet_cr512.json' 67 | # else: 68 | # print("Invalid environment variable.") 69 | # dim, date, model_dir = unpack_compact_json(config_hi) 70 | # network_name = get_network_name(config_hi) 71 | # CsiNet_hi = load_weights_into_model(network_name,model_dir,CsiNet_hi) 72 | weights_file = f"{subnetwork_spec[0]}/cr{M_1}/{subnetwork_spec[1]}.h5" 73 | CsiNet_hi.load_weights(weights_file) 74 | CsiNet_hi._name = "CsiNet_hi" 75 | CsiNet_hi.trainable = t1_trainable 76 | print("--- High Dimensional (M_1) Latent Space CsiNet ---") 77 | CsiNet_hi.summary() 78 | print('CsiNet_hi.inputs: {}'.format(CsiNet_hi.inputs)) 79 | print('CsiNet_hi.outputs: {}'.format(CsiNet_hi.outputs)) 80 | # TO-DO: split large input tensor to use as inputs to 1:T CSINets 81 | CsiOut = [] 82 | CsiOut_temp = [] 83 | for i in range(T): 84 | CsiIn = Lambda( lambda x: x[:,i,:,:,:])(x) 85 | print('#{}: CsiIn: {}'.format(i,CsiIn)) 86 | if i == 0: 87 | # use CsiNet_hi for t=1 88 | EncodedLayer, OutLayer = CsiNet_hi([aux,CsiIn]) 89 | print('EncodedLayer: {}'.format(EncodedLayer)) 90 | else: 91 | # choose whether or not to share parameters between low-dimensional timeslots 92 | if (i==1 or not share_bool): 93 | CsiNet_lo = make_CsiNet(aux_bool, M_1, img_channels, img_height, img_width, M_2, data_format, lo_bool=True) 94 | if (pretrained_bool): 95 | # if envir == "indoor": 96 | # config_lo = 'config/indoor0001/v2/angular/csinet_cr{}.json'.format(M_2) 97 | # elif envir == "outdoor": 98 | # config_lo = 'config/outdoor300/v2/csinet_cr{}.json'.format(M_2) 99 | # else: 100 | # print("Invalid environment variable.") 101 | weights_file = f"{subnetwork_spec[0]}/cr{M_2}/{subnetwork_spec[1]}.h5" 102 | CsiNet_lo.load_weights(weights_file) 103 | # dim, date, model_dir = unpack_compact_json(config_lo) 104 | # network_name = get_network_name(config_lo) 105 | # CsiNet_lo = load_weights_into_model(network_name,model_dir,CsiNet_lo) 106 | CsiNet_lo.trainable = t2_trainable 107 | CsiNet_lo._name = "CsiNet_lo_{}".format(i) 108 | print('CsiNet_lo.inputs: {}'.format(CsiNet_lo.inputs)) 109 | print('CsiNet_lo.outputs: {}'.format(CsiNet_lo.outputs)) 110 | if i==1: 111 | print("--- Low Dimensional (M_2) Latent Space CsiNet ---") 112 | CsiNet_lo.summary() 113 | if aux_bool: 114 | OutLayer = CsiNet_lo([EncodedLayer,CsiIn]) 115 | else: 116 | # use CsiNet_lo for t in [2:T] 117 | OutLayer = CsiNet_lo(CsiIn) 118 | print('#{} - OutLayer: {}'.format(i, OutLayer)) 119 | if data_format == "channels_last": 120 | CsiOut.append(Reshape((1,img_height,img_width,img_channels))(OutLayer)) 121 | if data_format == "channels_first": 122 | CsiOut.append(Reshape((1,img_channels,img_height,img_width))(OutLayer)) 123 | # for 
the moment, we don't handle separate case of loading convLSTM
124 |         LSTM_in = concatenate(CsiOut,axis=1)
125 |     else:
126 |         LSTM_in = x # skip CsiNets
127 |     # lstm_config = 'config/indoor0001/lstm_depth3_opt.json'
128 |     print('--- Non-convolutional recurrent activations ---')
129 |     LSTM_model = stacked_LSTM(img_channels, img_height, img_width, T, lstm_latent_bool, LSTM_depth=LSTM_depth,data_format=data_format)
130 |     # comment back in to load in weights
131 |     # if (pretrained_bool):
132 |     #     lstm_config = 'config/outdoor300/lstm_depth3_opt.json'
133 |     #     dim, date, model_dir = unpack_compact_json(lstm_config)
134 |     #     network_name = get_network_name(lstm_config)
135 |     #     LSTM_model = load_weights_into_model(network_name,model_dir,LSTM_model) # option to load weights; try random initialization for the network
136 |     print(LSTM_model.summary())
137 | 
138 |     print('LSTM_in.shape: {}'.format(LSTM_in.shape))
139 |     LSTM_out = LSTM_model(LSTM_in)
140 | 
141 |     # compile full model with large 4D tensor as input and LSTM 4D tensor as output
142 |     if LSTM_only_bool:
143 |         full_model = Model(inputs=[LSTM_in], outputs=[LSTM_out])
144 |     else:
145 |         if pass_through_bool:
146 |             full_model = Model(inputs=[aux,x], outputs=[LSTM_in])
147 |         else:
148 |             full_model = Model(inputs=[aux,x], outputs=[LSTM_out])
149 |     full_model.compile(optimizer='adam', loss='mse')
150 |     full_model.summary()
151 |     return full_model
152 | 
153 | def split_CsiNet(model, CR):
154 |     # split model into encoder and decoder
155 |     layers = []
156 |     for layer in model.layers:
157 |         # print("layer.name: {} - type(layer): {}".format(layer.name, type(layer)))
158 |         # layers.append(layer)
159 |         # if 'dense' in layer.name:
160 |         #     print('Dense layer "{}" has output shape {}'.format(layer.name,layer.output_shape))
161 |         if layer.output_shape == (None,CR):
162 |             print('Feedback layer "{}"'.format(layer.name))
163 |             feedback_layer_output = layer.output # take feedback layer as output of the encoder
164 |         elif 'dense' in layer.name:
165 |             dec_input = layer.input # the dense layer fed by the feedback code is the decoder's input
166 |     enc_input = model.input
167 |     # dec_input = Input((dec_in_dim))
168 |     enc_model = Model(inputs=[enc_input],outputs=[feedback_layer_output]) # CSI in -> feedback code out
169 |     dec_model = Model(inputs=[dec_input],outputs=[model.output]) # feedback code in -> CSI estimate out
170 |     return enc_model, dec_model
171 | def stacked_LSTM(img_channels, img_height, img_width, T, lstm_latent_bool, LSTM_depth=3, plot_bool=False, data_format="channels_first",kernel_initializer=initializers.glorot_uniform(seed=100), recurrent_initializer=initializers.orthogonal(seed=100)):
172 |     # assume entire time-series of CSI from 1:T is concatenated
173 |     LSTM_dim = img_channels*img_height*img_width
174 |     if(data_format == "channels_last"):
175 |         orig_shape = (T, img_height, img_width, img_channels)
176 |     elif(data_format == "channels_first"):
177 |         orig_shape = (T, img_channels, img_height, img_width)
178 |     x = Input(shape=orig_shape)
179 |     LSTM_tup = (T,LSTM_dim)
180 |     recurrent_out = Reshape(LSTM_tup)(x)
181 |     for i in range(LSTM_depth):
182 |         # By default, orthogonal/glorot_uniform initializers for recurrent/kernel
183 |         # recurrent_out = LSTM(LSTM_dim, return_sequences=True, kernel_initializer=kernel_initializer, recurrent_initializer=recurrent_initializer, stateful=False)(recurrent_out)
184 |         recurrent_out = CuDNNLSTM(LSTM_dim, return_sequences=True, kernel_initializer=kernel_initializer, recurrent_initializer=recurrent_initializer, stateful=False)(recurrent_out) # CuDNNLSTM does not support recurrent activations; switch back to vanilla LSTM in 
meantime 185 | # print("Dim of LSTM #{} - {}".format(i+1,recurrent_out.shape)) 186 | out = Reshape(orig_shape)(recurrent_out) 187 | LSTM_model = Model(inputs=[x], outputs=[out]) 188 | return LSTM_model 189 | 190 | def add_common_layers(y): 191 | y = BatchNormalization()(y) 192 | y = LeakyReLU()(y) 193 | return y 194 | -------------------------------------------------------------------------------- /csinet-lstm/csinet_lstm_eval_cosine.py: -------------------------------------------------------------------------------- 1 | # CsiNet_LSTM_eval_cosine.py 2 | # script for evaluating cosine similarity of CsiNet-LSTM 3 | 4 | if __name__ == "__main__": 5 | import argparse 6 | import os 7 | import copy 8 | import sys 9 | import pickle 10 | sys.path.append("/home/mdelrosa/git/brat") 11 | from utils.NMSE_performance import calc_NMSE, get_NMSE, denorm_H3, renorm_H4, denorm_H4, denorm_sphH4 12 | from utils.cosine_sim_performance import cosine_similarity, cosine_similarity_mat 13 | from utils.data_tools import dataset_pipeline_full, subsample_batches, split_complex 14 | from utils.parsing import str2bool 15 | from utils.timing import Timer 16 | from utils.unpack_json import get_keys_from_json 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("-d", "--debug_flag", type=int, default=0, help="flag for toggling debugging mode") 19 | parser.add_argument("-b", "--n_batch", type=int, default=20, help="number of batches to fit on (ignored during debug mode)") 20 | parser.add_argument("-l", "--dir", type=str, default=None, help="subdirectory for saving model, checkpoint, history") 21 | parser.add_argument("-e", "--env", type=str, default="indoor", help="environment (either indoor or outdoor)") 22 | parser.add_argument("-ep", "--epochs", type=int, default=10, help="number of epochs to train for") 23 | parser.add_argument("-t", "--train_argv", type=str2bool, default=True, help="flag for toggling training") 24 | parser.add_argument("-g", "--n_gpu", type=int, default=1, help="index of gpu for training") 25 | parser.add_argument("-r", "--rate", type=int, default=512, help="number of elements in latent code (i.e., encoding rate)") 26 | parser.add_argument("-de", "--depth", type=int, default=3, help="depth of lstm") 27 | parser.add_argument("-p", "--pretrained_bool", type=str2bool, default=True, help="bool for using pretrained CsiNet for each timeslot") 28 | parser.add_argument("-lo", "--load_bool", type=str2bool, default=False, help="bool for loading weights into CsiNet-LSTM network") 29 | parser.add_argument("-a", "--aux_bool", type=str2bool, default=True, help="bool for building CsiNet with auxiliary input") 30 | parser.add_argument("-m", "--aux_size", type=int, default=512, help="integer for auxiliary input's latent rate") 31 | parser.add_argument("-sr", "--stride", type=int, default=1, help="space between timeslots for each step (default 1); controls feedback interval") 32 | parser.add_argument("-v", "--viz_batch", type=int, default=-1, help="index of element to save for visualization") 33 | parser.add_argument("-nc", "--n_carriers", type=int, default=128, help="num carriers to test cosine similarity against") 34 | opt = parser.parse_args() 35 | 36 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"; # The GPU id to use, usually either "0" or "1"; 37 | os.environ["CUDA_VISIBLE_DEVICES"]="{}".format(opt.n_gpu); # Do other imports now... 
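    # setting CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES must happen before the
    # tensorflow import below: once TF has initialized, changing these
    # environment variables no longer affects which GPU it uses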
38 | print("debug_flag: {} -- train_argv: {}".format(opt.debug_flag, opt.train_argv)) 39 | 40 | if opt.env == "indoor": 41 | json_config = '../config/csinet_lstm_indoor_cost2100_full.json' # 0 epochs 42 | elif opt.env == "outdoor": 43 | json_config = '../config/csinet_lstm_outdoor_cost2100_full.json' # 0 epochs 44 | # quant_config = "../config/quant/10bits.json 45 | 46 | M_1, data_format, network_name, subnetwork_name, model_dir, norm_range, minmax_file, share_bool, T, dataset_spec, diff_spec, batch_num, lr, batch_size, subsample_prop, thresh_idx_path = get_keys_from_json(json_config, keys=['M_1', 'df', 'network_name', 'subnetwork_name', 'model_dir', 'norm_range', 'minmax_file', 'share_bool', 'T', 'dataset_spec', 'diff_spec', 'batch_num', 'lr', 'batch_size', 'subsample_prop', 'thresh_idx_path']) 47 | aux_bool, quant_bool, LSTM_only_bool, pass_through_bool, t1_train, t2_train, lstm_latent_bool = get_keys_from_json(json_config, keys=['aux_bool', 'quant_bool', 'LSTM_only_bool', 'pass_through_bool', 't1_train', 't2_train', 'lstm_latent_bool'],is_bool=True) # import these as booleans rather than int, str 48 | 49 | import scipy.io as sio 50 | import numpy as np 51 | import math 52 | import time 53 | import sys 54 | # import os 55 | try: 56 | from tensorflow.keras.optimizers import Adam 57 | from tensorflow.keras.callbacks import TensorBoard, Callback, ModelCheckpoint, EarlyStopping 58 | except: 59 | import keras 60 | from keras.optimizers import Adam 61 | from keras.callbacks import TensorBoard, Callback, ModelCheckpoint 62 | from tensorflow.core.protobuf import rewriter_config_pb2 63 | from csinet_lstm import * 64 | # from QuantizeData import quantize, get_minmax 65 | 66 | def reset_keras(): 67 | sess = tf.keras.backend.get_session() 68 | tf.keras.backend.clear_session() 69 | sess.close() 70 | # limit gpu resource allocation 71 | try: 72 | config = tf.compat.v1.ConfigProto() 73 | except: 74 | config = tf.ConfigProto() 75 | # config.gpu_options.visible_device_list = '1' 76 | config.gpu_options.per_process_gpu_memory_fraction = 1.0 77 | 78 | # disable arithmetic optimizer 79 | off = rewriter_config_pb2.RewriterConfig.OFF 80 | config.graph_options.rewrite_options.arithmetic_optimization = off 81 | 82 | try: 83 | session = tf.compat.v1.Session(config=config) 84 | tf.compat.v1.keras.backend.set_session(session) 85 | except: 86 | session = tf.Session(config=config) 87 | keras.backend.set_session(session) 88 | # tf.global_variables_initializer() 89 | 90 | reset_keras() 91 | 92 | # fit params 93 | # image params 94 | img_height = 32 95 | img_width = 32 96 | img_channels = 2 97 | img_total = img_height*img_width*img_channels 98 | 99 | # Data loading 100 | batch_num = 1 if opt.debug_flag else batch_num # we'll use batch_num-1 for training and 1 for validation 101 | epochs = 1 if opt.debug_flag else opt.epochs 102 | 103 | pow_diff, data_train, data_val = dataset_pipeline_full(batch_num, opt.debug_flag, aux_bool, dataset_spec, diff_spec, M_1, T = T, img_channels = img_channels, img_height = img_height, img_width = img_width, train_argv = True, mode = "truncate") 104 | 105 | print(f"--- Load full CSI matrices truncated to {opt.n_carriers} carriers ---") 106 | pow_diff, data_train_full, data_val_full = dataset_pipeline_full(batch_num, opt.debug_flag, aux_bool, dataset_spec, diff_spec, M_1, T = T, img_channels = img_channels, img_height = img_height, img_width = img_width, train_argv = True, mode = "full", n_truncate=opt.n_carriers) 107 | # make freq-spatial domain from angular-spatial doman 108 | print(f"--- 
Convert angular-spatial to freq-spatial {opt.n_carriers} carriers ---") 109 | x_val_full_freq = np.fft.fft(data_val_full.view("complex"), axis=2) # delay-spatial -> freq-spatial 110 | 111 | # loading directly from unnormalized data; normalize data 112 | aux_val, x_val = data_val 113 | x_val = split_complex(x_val.view("complex"),T=T) 114 | print(f"--- x_val min={np.min(x_val)}, max={np.max(x_val)} ---") 115 | x_val = renorm_H4(x_val,minmax_file) 116 | data_val = aux_val, x_val 117 | print('-> post-renorm: x_val range is from {} to {}'.format(np.min(x_val),np.max(x_val))) 118 | 119 | aux_train, x_train = data_train 120 | if opt.train_argv: 121 | x_train = split_complex(x_train.view("complex"),T=T) 122 | x_train = renorm_H4(x_train,minmax_file) 123 | print(f"--- x_train min={np.min(x_train)}, max={np.max(x_train)} ---") 124 | data_train = [aux_train, x_train] 125 | print('-> post-renorm: x_train range is from {} to {}'.format(np.min(x_train),np.max(x_train))) 126 | 127 | CR_list = [128] 128 | # CR_list = [512, 256, 128, 64, 32] 129 | for M_2 in CR_list: 130 | # M_1, M_2 = opt.aux_size, opt.rate 131 | M_1 = opt.aux_size 132 | reset_keras() 133 | try: 134 | optimizer = Adam(learning_rate=lr, beta_1=0.9, beta_2=0.999) 135 | except: 136 | optimizer = Adam(lr=lr, beta_1=0.9, beta_2=0.999) 137 | print('-------------------------------------') 138 | print("Build CsiNet-LSTM for CR2={}".format(M_2)) 139 | print('-------------------------------------') 140 | 141 | # def callbacks 142 | class LossHistory(Callback): 143 | def on_train_begin(self, logs={}): 144 | self.losses_train = [] 145 | self.losses_val = [] 146 | 147 | def on_batch_end(self, batch, logs={}): 148 | self.losses_train.append(logs.get('loss')) 149 | 150 | def on_epoch_end(self, epoch, logs={}): 151 | self.losses_val.append(logs.get('val_loss')) 152 | 153 | outpath_base = f"{model_dir}/{opt.env}" 154 | if opt.dir != None: 155 | outpath_base += "/" + opt.dir 156 | outfile_base = f"{outpath_base}/cr{M_2}/{network_name}" 157 | subnetwork_spec = [outpath_base, subnetwork_name] 158 | 159 | # if opt.load_bool: 160 | # CsiNet_LSTM_model = CsiNet_LSTM(img_channels, img_height, img_width, T, M_1, M_2, envir=opt.env, LSTM_depth=opt.depth, data_format=data_format, t1_trainable=t1_train, t2_trainable=t2_train, share_bool=share_bool, pass_through_bool=pass_through_bool, LSTM_only_bool=LSTM_only_bool, subnetwork_spec=subnetwork_spec, pretrained_bool=opt.pretrained_bool) 161 | # CsiNet_LSTM_model.load_weights(f"{outfile_base}.h5") 162 | # print ("--- Pre-loaded network performance is... 
---") 163 | # x_hat = CsiNet_LSTM_model.predict(data_val) 164 | 165 | # print("For Adam with lr={:1.1e} // batch_size={} // norm_range={}".format(lr,batch_size,norm_range)) 166 | # print("x_hat.dtype: {}".format(x_hat.dtype)) # sanity check on output datatype 167 | # if norm_range == "norm_H3": 168 | # x_hat_denorm = denorm_H3(x_hat,minmax_file) 169 | # x_val_denorm = denorm_H3(x_val,minmax_file) 170 | # if norm_range == "norm_H4": 171 | # x_hat_denorm = denorm_H4(x_hat,minmax_file) 172 | # x_val_denorm = denorm_H4(x_val,minmax_file) 173 | # print('-> x_hat range is from {} to {}'.format(np.min(x_hat_denorm),np.max(x_hat_denorm))) 174 | # print('-> x_val range is from {} to {} '.format(np.min(x_val_denorm),np.max(x_val_denorm))) 175 | # calc_NMSE(x_hat_denorm,x_val_denorm,T=T) 176 | # else: 177 | CsiNet_LSTM_model = CsiNet_LSTM(img_channels, img_height, img_width, T, M_1, M_2, LSTM_depth=opt.depth, data_format=data_format, t1_trainable=t1_train, t2_trainable=t2_train, share_bool=share_bool, pass_through_bool=pass_through_bool, LSTM_only_bool=LSTM_only_bool, subnetwork_spec=subnetwork_spec, pretrained_bool=opt.pretrained_bool) 178 | if opt.load_bool: 179 | print(f"--- Loading weights from {outfile_base}.h5 ---") 180 | CsiNet_LSTM_model.load_weights(f"{outfile_base}.h5") 181 | CsiNet_LSTM_model.compile(optimizer=optimizer, loss='mse') 182 | 183 | if (opt.train_argv): 184 | # save+serialize model to JSON 185 | model_json = CsiNet_LSTM_model.to_json() 186 | # outfile = f"{model_dir}/{opt.dir}/{network_name}_{opt.env}.json" 187 | with open(f"{outfile_base}.json", "w") as json_file: 188 | json_file.write(model_json) 189 | # serialize weights to HDF5 190 | # outfile = f"{model_dir}/model_{file}.h5" 191 | checkpoint = ModelCheckpoint(f"{outfile_base}_full.h5", monitor="val_loss",verbose=1,save_best_only=True,mode="min") 192 | early = EarlyStopping(monitor="val_loss", patience=50,verbose=1) 193 | 194 | history = LossHistory() 195 | 196 | CsiNet_LSTM_model.fit(data_train, x_train, 197 | epochs=epochs, 198 | batch_size=batch_size, 199 | shuffle=True, 200 | validation_data=(data_val, x_val), 201 | callbacks=[checkpoint, 202 | # early, 203 | history]) 204 | # TensorBoard(log_dir = path), 205 | 206 | filename = f'{outfile_base}_trainloss.csv' 207 | loss_history = np.array(history.losses_train) 208 | np.savetxt(filename, loss_history, delimiter=",") 209 | 210 | filename = f'{outfile_base}_valloss.csv' 211 | loss_history = np.array(history.losses_val) 212 | np.savetxt(filename, loss_history, delimiter=",") 213 | else: 214 | CsiNet_LSTM_model.load_weights(f"{outfile_base}_full.h5") 215 | 216 | #Testing data 217 | tStart = time.time() 218 | x_hat = CsiNet_LSTM_model.predict(data_val) 219 | tEnd = time.time() 220 | print ("It cost %f sec per sample (%f samples)" % ((tEnd - tStart)/x_val.shape[0],x_val.shape[0])) 221 | 222 | print("For Adam with lr={:1.1e} // batch_size={} // norm_range={}".format(lr,batch_size,norm_range)) 223 | if norm_range == "norm_H3": 224 | x_hat_denorm = denorm_H3(x_hat,minmax_file) 225 | x_val_denorm = denorm_H3(x_val,minmax_file) 226 | elif norm_range == "norm_H4": 227 | x_hat_denorm = denorm_H4(x_hat,minmax_file) 228 | x_val_denorm = denorm_H4(x_val,minmax_file) 229 | print('-> x_hat range is from {} to {}'.format(np.min(x_hat_denorm),np.max(x_hat_denorm))) 230 | print('-> x_val range is from {} to {} '.format(np.min(x_val_denorm),np.max(x_val_denorm))) 231 | 232 | #TODO: validate this pow_diff behavior 233 | pow_val = pow_diff[x_train.shape[0]:,:,:] 234 | 
calc_NMSE(x_hat_denorm,x_val_denorm,T=T,pow_diff=pow_val) 235 | 236 | #TODO: change this to load gt freq domain data 237 | #TODO: change this to append zeros to estimate delay domain data 238 | 239 | x_zeros = np.zeros((x_hat.shape[0], T, 2, opt.n_carriers-img_height, img_width)) 240 | x_hat_denorm = np.concatenate((x_hat_denorm, x_zeros), axis=3) 241 | x_val_denorm = np.concatenate((x_val_denorm, x_zeros), axis=3) 242 | x_hat_denorm = x_hat_denorm[:,:,0,:,:] + 1j*x_hat_denorm[:,:,1,:,:] 243 | x_val_denorm = x_val_denorm[:,:,0,:,:] + 1j*x_val_denorm[:,:,1,:,:] 244 | x_hat_freq = np.fft.fft(np.fft.fft(x_hat_denorm, axis=2), axis=3) 245 | x_val_freq = np.fft.fft(np.fft.fft(x_val_denorm, axis=2), axis=3) 246 | # rho_truncate, rho_all = cosine_similarity(x_hat_freq, x_val_freq, pow_diff_T=pow_val) 247 | # print(f"--- rho_truncate = {rho_truncate:6.5f}, rho_all = {rho_all:6.5f} ---") 248 | rho_truncate = cosine_similarity_mat(x_hat_freq, x_val_freq) 249 | print(f"--- rho_truncate = {rho_truncate:6.5f} ---") 250 | 251 | # now, report cosine similarity with gt 252 | rho_truncate_full = cosine_similarity_mat(x_hat_freq, x_val_full_freq) 253 | print(f"--- rho_truncate_full = {rho_truncate_full:6.5f} ---") 254 | 255 | if opt.viz_batch > -1 and not opt.train_argv: 256 | print(f"=== Saving input/output batch {opt.viz_batch} from validation set ===") 257 | # save input/output of validation batch for visualization 258 | viz_dict = { 259 | "input": x_val_denorm[opt.viz_batch, :, :, :, :], 260 | "output": x_hat_denorm[opt.viz_batch, :, :, :, :] 261 | } 262 | 263 | with open(f"{outfile_base}_batch{opt.viz_batch}.pkl", "wb") as f: 264 | pickle.dump(viz_dict, f) 265 | f.close() -------------------------------------------------------------------------------- /csinet-lstm/csinet_lstm_quant.py: -------------------------------------------------------------------------------- 1 | 2 | if __name__ == "__main__": 3 | import argparse 4 | import os 5 | import copy 6 | import sys 7 | sys.path.append("/home/mdelrosa/git/brat") 8 | from utils.NMSE_performance import calc_NMSE, get_NMSE, denorm_H3, renorm_H4, denorm_H4, denorm_sphH4 9 | from utils.data_tools import dataset_pipeline_col, subsample_batches 10 | from utils.parsing import str2bool 11 | from utils.timing import Timer 12 | from utils.unpack_json import get_keys_from_json 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("-d", "--debug_flag", type=int, default=0, help="flag for toggling debugging mode") 15 | parser.add_argument("-b", "--n_batch", type=int, default=20, help="number of batches to fit on (ignored during debug mode)") 16 | parser.add_argument("-l", "--dir", type=str, default=None, help="subdirectory for saving model, checkpoint, history") 17 | parser.add_argument("-e", "--env", type=str, default="indoor", help="environment (either indoor or outdoor)") 18 | parser.add_argument("-ep", "--epochs", type=int, default=10, help="number of epochs to train for") 19 | parser.add_argument("-t", "--train_argv", type=str2bool, default=True, help="flag for toggling training") 20 | parser.add_argument("-g", "--n_gpu", type=int, default=1, help="index of gpu for training") 21 | parser.add_argument("-r", "--rate", type=int, default=512, help="number of elements in latent code (i.e., encoding rate)") 22 | parser.add_argument("-de", "--depth", type=int, default=3, help="depth of lstm") 23 | parser.add_argument("-p", "--pretrained_bool", type=str2bool, default=True, help="bool for using pretrained CsiNet for each timeslot") 24 | 
parser.add_argument("-lo", "--load_bool", type=str2bool, default=False, help="bool for loading weights into CsiNet-LSTM network") 25 | parser.add_argument("-a", "--aux_bool", type=str2bool, default=True, help="bool for building CsiNet with auxiliary input") 26 | parser.add_argument("-m", "--aux_size", type=int, default=512, help="integer for auxiliary input's latent rate") 27 | parser.add_argument("-sr", "--stride", type=int, default=1, help="space between timeslots for each step (default 1); controls feedback interval") 28 | parser.add_argument("-q", "--quantization_bits", type=int, default=5, help="quantization bits per value") 29 | parser.add_argument("-i", "--t1_bits", type=int, default=8, help="quantization bits for first timeslot") 30 | parser.add_argument("-ql", "--quan_lam", type=float, default=1e-9, help="quantization regularizer") 31 | opt = parser.parse_args() 32 | 33 | quan_lam=opt.quan_lam 34 | dynamic_range_i = 2**(opt.quantization_bits - 1) 35 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"; # The GPU id to use, usually either "0" or "1"; 36 | os.environ["CUDA_VISIBLE_DEVICES"]="{}".format(opt.n_gpu); # Do other imports now... 37 | print("debug_flag: {} -- train_argv: {}".format(opt.debug_flag, opt.train_argv)) 38 | 39 | if opt.env == "indoor": 40 | # if opt.rate != 512: 41 | # json_config = 'config/indoor0001/T10/replication/angular/csinet_lstm_v2_CR{}.json'.format(CR_arg) # 500 epochs 42 | json_config = '../config/csinet_lstm_indoor_cost2100.json' 43 | temp_config = lstm_config = None 44 | # else: 45 | # # json_config = 'config/indoor0001/T10/replication/angular/csinet_lstm_v2_CR{}_best.json'.format(CR_arg) # 500 epochs 46 | # json_config = '../config/csinet_lstm_outdoor_cost2100.json' # 0 epochs 47 | # temp_config, lstm_config = get_keys_from_json(json_config, keys=['temp_config', 'lstm_config']) 48 | elif opt.env == "outdoor": 49 | json_config = '../config/csinet_lstm_outdoor_cost2100.json' 50 | # json_config = 'config/outdoor300/T10/csinet_lstm_v2_CR{}.json'.format(CR_arg) # Depth 3 51 | temp_config = lstm_config = None 52 | 53 | M_1, data_format, network_name, subnetwork_name, model_dir, norm_range, minmax_file, share_bool, T, dataset_spec, diff_spec, batch_num, lr, batch_size, subsample_prop, thresh_idx_path = get_keys_from_json(json_config, keys=['M_1', 'df', 'network_name', 'subnetwork_name', 'model_dir', 'norm_range', 'minmax_file', 'share_bool', 'T', 'dataset_spec', 'diff_spec', 'batch_num', 'lr', 'batch_size', 'subsample_prop', 'thresh_idx_path']) 54 | aux_bool, quant_bool, LSTM_only_bool, pass_through_bool, t1_train, t2_train, lstm_latent_bool = get_keys_from_json(json_config, keys=['aux_bool', 'quant_bool', 'LSTM_only_bool', 'pass_through_bool', 't1_train', 't2_train', 'lstm_latent_bool'],is_bool=True) # import these as booleans rather than int, str 55 | 56 | import scipy.io as sio 57 | import numpy as np 58 | import math 59 | import time 60 | import sys 61 | # import os 62 | try: 63 | from tensorflow.keras.optimizers import Adam 64 | from tensorflow.keras.callbacks import TensorBoard, Callback, ModelCheckpoint, EarlyStopping 65 | except: 66 | import keras 67 | from keras.optimizers import Adam 68 | from keras.callbacks import TensorBoard, Callback, ModelCheckpoint 69 | from tensorflow.core.protobuf import rewriter_config_pb2 70 | from csinet_lstm import * 71 | from csinet_quant import CsiNet_quant 72 | # from QuantizeData import quantize, get_minmax 73 | 74 | def reset_keras(): 75 | sess = tf.keras.backend.get_session() 76 | tf.keras.backend.clear_session() 77 
|         sess.close()
78 |         # limit gpu resource allocation
79 |         try:
80 |             config = tf.compat.v1.ConfigProto()
81 |         except:
82 |             config = tf.ConfigProto()
83 |         # config.gpu_options.visible_device_list = '1'
84 |         config.gpu_options.per_process_gpu_memory_fraction = 1.0
85 | 
86 |         # disable arithmetic optimizer
87 |         off = rewriter_config_pb2.RewriterConfig.OFF
88 |         config.graph_options.rewrite_options.arithmetic_optimization = off
89 | 
90 |         try:
91 |             session = tf.compat.v1.Session(config=config)
92 |             tf.compat.v1.keras.backend.set_session(session)
93 |         except:
94 |             session = tf.Session(config=config)
95 |             keras.backend.set_session(session)
96 |         # tf.global_variables_initializer()
97 | 
98 |     reset_keras()
99 | 
100 |     # fit params
101 |     # image params
102 |     img_height = 32
103 |     img_width = 32
104 |     img_channels = 2
105 |     img_total = img_height*img_width*img_channels
106 | 
107 |     # Data loading
108 |     batch_num = 1 if opt.debug_flag else batch_num # we'll use batch_num-1 for training and 1 for validation
109 |     epochs = 1 if opt.debug_flag else opt.epochs
110 | 
111 |     # T_dummy = 1 # while debugging, we'll repeat the first timeslot
112 |     pow_diff, data_train, data_val = dataset_pipeline_col(opt.debug_flag, opt.aux_bool, dataset_spec, diff_spec, opt.aux_size, T = T, img_channels = img_channels, img_height = img_height, img_width = img_width, data_format = data_format, train_argv = opt.train_argv, subsample_prop=subsample_prop, thresh_idx_path=thresh_idx_path, stride=opt.stride)
113 | 
114 |     print(f"pow_diff.shape: {pow_diff.shape}")
115 | 
116 |     # loading directly from unnormalized data; normalize data
117 |     aux_val, x_val = data_val
118 |     x_val = renorm_H4(x_val,minmax_file)
119 |     data_val = aux_val, x_val
120 |     print(f"-> aux_val.shape: {aux_val.shape} - x_val.shape: {x_val.shape}")
121 |     print('-> post-renorm: x_val range is from {} to {}'.format(np.min(x_val),np.max(x_val)))
122 | 
123 |     aux_train, x_train = data_train
124 |     x_train = renorm_H4(x_train,minmax_file)
125 |     data_train = [aux_train, x_train]
126 |     print(f"-> aux_train.shape: {aux_train.shape} - x_train.shape: {x_train.shape}")
127 |     print('-> post-renorm: x_train range is from {} to {}'.format(np.min(x_train),np.max(x_train)))
128 | 
129 |     outpath_base = f"{model_dir}/{opt.env}"
130 |     if opt.dir != None:
131 |         outpath_base += "/" + opt.dir
132 | 
133 |     rates = [512, 256, 128, 64, 32]
134 |     for rate in rates:
135 |         outfile_base = f"{outpath_base}/cr{rate}/{network_name}"
136 |         subnetwork_spec = [outpath_base, subnetwork_name]
137 | 
138 |         # load model
139 |         if temp_config == None:
140 |             # outfile = f"{model_dir}/model_CsiNet_LSTM_{envir}_dim{M_2}_{dates[0]}"
141 |             outfile_base = f"{outpath_base}/cr{rate}/{network_name}"
142 |             CsiNet_LSTM_model = CsiNet_LSTM(img_channels, img_height, img_width, T, M_1, rate, envir=opt.env, LSTM_depth=opt.depth, data_format=data_format, t1_trainable=t1_train, t2_trainable=t2_train, share_bool=share_bool, pass_through_bool=pass_through_bool,LSTM_only_bool=LSTM_only_bool, subnetwork_spec=subnetwork_spec)
143 |             # tf.keras.models.model_from_json("{}.json".format(outfile))
144 |             # template_model.load_weights("{}.h5".format(outfile))
145 |             template_model = tf.keras.models.load_model("{}.h5".format(outfile_base))
146 |             n_layers = len(template_model.layers)
147 |             if n_layers == 2:
148 |                 # handle the combined model
149 |                 print("Combined model -- load each layer by name.")
150 |                 CsiNet_T10_model = template_model.layers[0] # index into .layers; indexing the Model object itself would raise a TypeError
151 |                 LSTM_model = template_model.layers[1]
152 |                 print("CsiNet_T10_model.summary()")
153 |                 CsiNet_T10_model.summary()
154 | print("LSTM_model.summary()") 155 | LSTM_model.summary() 156 | else: 157 | CsiNet_LSTM_model.load_weights("{}.h5".format(outfile_base)) 158 | else: 159 | LSTM_model, CsiNet_LSTM_model = combine_model(temp_config, lstm_config, json_config, data_train, data_val, data_test, debug_flag=opt.debug_flag) 160 | 161 | # preloaded performance 162 | if (opt.debug_flag): 163 | print ("--- Pre-loaded network performance is... ---") 164 | x_hat = CsiNet_LSTM_model.predict(data_val) 165 | 166 | print("For Adam with lr={:1.1e} // batch_size={} // norm_range={}".format(lr,batch_size,norm_range)) 167 | print("x_hat.dtype: {}".format(x_hat.dtype)) # sanity check on output datatype 168 | if norm_range == "norm_H3": 169 | x_hat_denorm = denorm_H3(x_hat,minmax_file) 170 | x_val_denorm = denorm_H3(data_val[1],minmax_file) 171 | if norm_range == "norm_H4": 172 | x_hat_denorm = denorm_H4(x_hat,minmax_file) 173 | x_val_denorm = denorm_H4(data_val[1],minmax_file) 174 | print('-> x_hat range is from {} to {}'.format(np.min(x_hat_denorm),np.max(x_hat_denorm))) 175 | print('-> x_val range is from {} to {} '.format(np.min(x_val_denorm),np.max(x_val_denorm))) 176 | calc_NMSE(x_hat_denorm,x_val_denorm,T=T) 177 | 178 | CsiNet_models = [] 179 | CsiNet_names = ["CsiNet_hi"] + [f"CsiNet_lo_{i}" for i in range(1,10)] 180 | 181 | # autoencoder_models = [] 182 | CsiOut = [] 183 | 184 | aux = tf.keras.Input((M_1)) 185 | if(data_format == "channels_last"): 186 | x = tf.keras.Input((T, img_height, img_width, img_channels)) 187 | elif(data_format == "channels_first"): 188 | x = tf.keras.Input((T, img_channels, img_height, img_width)) 189 | else: 190 | print("Unexpected data_format param in CsiNet input.") # raise an exception eventually. For now, print a complaint 191 | 192 | print("aux (full network aux input): {}".format(aux)) 193 | print("x (full network input): {}".format(x)) 194 | 195 | side_max = 1 196 | side_min = -1 197 | 198 | # iterate through CsiNet models, insert quantizer in between encoder and decoder 199 | for i, model_name in enumerate(CsiNet_names): 200 | CsiIn = Lambda( lambda x: x[:,i,:,:,:])(x) 201 | # CsiIn = CsiNet_LSTM_model.layers[lambda_idx[i] 202 | x_all = np.squeeze(np.vstack((data_train[1][:,i,:,:,:], data_val[1][:,i,:,:,:]))).astype('float32') 203 | aux_all = np.vstack((data_train[0], data_val[0])).astype('float32') if i == 0 else aux_t1 204 | data_all = [aux_all, x_all] 205 | model_name_adj = model_name if temp_config == None else f"{model_name}_new" 206 | CsiNet_model = CsiNet_LSTM_model.get_layer(f"{model_name_adj}") 207 | 208 | # use CsiNet_aux_quant to load in template model weights 209 | encoded_dim = M_1 if i == 0 else rate 210 | encoded_in = aux if i == 0 else t1_encoded 211 | if i == 0: 212 | enc_hat, x_hat = CsiNet_model.predict(data_all) 213 | dynamic_range = 2 ** (opt.t1_bits - 1) 214 | else: 215 | CsiNet_enc_model = Model(inputs=CsiNet_model.input, 216 | outputs=CsiNet_model.get_layer('CR2_dense').output) 217 | enc_hat = CsiNet_enc_model.predict(data_all) 218 | dynamic_range = dynamic_range_i 219 | print(f"#{i} x_hat.shape: {x_hat.shape}, enc_hat.shape: {enc_hat.shape}") 220 | enc_min = np.min(enc_hat) 221 | enc_max = np.max(enc_hat) 222 | print(f"enc_min: {enc_min}, enc_max: {enc_max}") 223 | print(f"side_min: {side_min}, side_max: {side_max}") 224 | CsiNet_quant_model = CsiNet_quant(encoded_dim, 225 | code_min = enc_min, 226 | code_max = enc_max, 227 | side_min = side_min, 228 | side_max = side_max, 229 | dynamic_range = dynamic_range) 230 | CsiNet_quant_model.build_full_network() 
231 | CsiNet_quant_model.load_template_weights_v2(CsiNet_model) 232 | print(f"#{i} encoded_in.shape: {encoded_in.shape}") 233 | companded, encoded, autoencoder_out = CsiNet_quant_model.autoencoder([encoded_in,CsiIn]) 234 | print(f"#{i} encoded.shape: {encoded.shape}") 235 | if i == 0: 236 | t1_encoded = encoded 237 | aux_t1 = enc_hat # actual values to be fed in as side info at all t_i 238 | side_min = enc_min 239 | side_max = enc_max 240 | 241 | if data_format == "channels_last": 242 | CsiOut.append(Reshape((1,img_height,img_width,img_channels))(autoencoder_out)) 243 | if data_format == "channels_first": 244 | CsiOut.append(Reshape((1,img_channels,img_height,img_width))(autoencoder_out)) 245 | 246 | LSTM_in = concatenate(CsiOut,axis=1) 247 | 248 | # get pretrained LSTM model 249 | # TODO: Get rid of magic string -- rename LSTM model appropriately 250 | 251 | # option 1 252 | # LSTM_model = CsiNet_LSTM_model.get_layer("model_20") 253 | 254 | # option 2 255 | # LSTM_name_new = "model_22" if temp_config == None else "model_20" 256 | # LSTM_template_model = CsiNet_LSTM_model.get_layer(LSTM_name_new) 257 | 258 | # option 3 259 | LSTM_template_model = CsiNet_LSTM_model.layers[-1] # last layer is LSTM model 260 | 261 | LSTM_model = stacked_LSTM(img_channels, img_height, img_width, T, lstm_latent_bool, LSTM_depth=opt.depth,data_format=data_format) 262 | LSTM_model.set_weights(LSTM_template_model.get_weights()) 263 | # for i, temp_weights in tqdm(enumerate(LSTM_template_model.get_weights())): 264 | dummy_i = 0 265 | for target, temp in zip(LSTM_model.get_weights(), LSTM_template_model.get_weights()): 266 | print(f"Assertion #{dummy_i}") 267 | foo = target == temp 268 | assert(foo.all()) 269 | dummy_i += 1 270 | LSTM_out = LSTM_model(LSTM_in) 271 | 272 | # LSTM_out = LSTM_template_model(LSTM_in) 273 | 274 | # CsiNet_LSTM_code_quant_model = Model(inputs = [aux,x], outputs = [LSTM_in]) 275 | CsiNet_LSTM_code_quant_model = Model(inputs = [aux,x], outputs = [LSTM_out]) 276 | CsiNet_LSTM_code_quant_model.summary() 277 | CsiNet_LSTM_code_quant_model.compile(optimizer = "adam", loss = "mse") 278 | 279 | # evaluate model 280 | print (f"--- {opt.env} {rate} with quantized codewords ({opt.quantization_bits} bits) is... 
---") 281 | x_hat = CsiNet_LSTM_code_quant_model.predict(data_val) 282 | 283 | print("For Adam with lr={:1.1e} // batch_size={} // norm_range={}".format(lr,batch_size,norm_range)) 284 | if norm_range == "norm_H3": 285 | x_hat_denorm = denorm_H3(x_hat,minmax_file) 286 | x_val_denorm = denorm_H3(data_val[1],minmax_file) 287 | if norm_range == "norm_H4": 288 | x_hat_denorm = denorm_H4(x_hat,minmax_file) 289 | x_val_denorm = denorm_H4(data_val[1],minmax_file) 290 | print('-> x_hat range is from {} to {}'.format(np.min(x_hat_denorm),np.max(x_hat_denorm))) 291 | print('-> x_val range is from {} to {} '.format(np.min(x_val_denorm),np.max(x_val_denorm))) 292 | calc_NMSE(x_hat_denorm,x_val_denorm,T=T,pow_diff=pow_diff) -------------------------------------------------------------------------------- /csinet-lstm/csinet_lstm_train.py: -------------------------------------------------------------------------------- 1 | # CsiNet_LSTM_train.py 2 | 3 | if __name__ == "__main__": 4 | import argparse 5 | import os 6 | import copy 7 | import sys 8 | import pickle 9 | sys.path.append("/home/mdelrosa/git/brat") 10 | from utils.NMSE_performance import calc_NMSE, get_NMSE, denorm_H3, renorm_H4, denorm_H4, denorm_sphH4 11 | from utils.data_tools import dataset_pipeline_col, subsample_batches 12 | from utils.parsing import str2bool 13 | from utils.timing import Timer 14 | from utils.unpack_json import get_keys_from_json 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("-d", "--debug_flag", type=int, default=0, help="flag for toggling debugging mode") 17 | parser.add_argument("-b", "--n_batch", type=int, default=20, help="number of batches to fit on (ignored during debug mode)") 18 | parser.add_argument("-l", "--dir", type=str, default=None, help="subdirectory for saving model, checkpoint, history") 19 | parser.add_argument("-e", "--env", type=str, default="indoor", help="environment (either indoor or outdoor)") 20 | parser.add_argument("-ep", "--epochs", type=int, default=10, help="number of epochs to train for") 21 | parser.add_argument("-t", "--train_argv", type=str2bool, default=True, help="flag for toggling training") 22 | parser.add_argument("-g", "--n_gpu", type=int, default=1, help="index of gpu for training") 23 | parser.add_argument("-r", "--rate", type=int, default=512, help="number of elements in latent code (i.e., encoding rate)") 24 | parser.add_argument("-de", "--depth", type=int, default=3, help="depth of lstm") 25 | parser.add_argument("-p", "--pretrained_bool", type=str2bool, default=True, help="bool for using pretrained CsiNet for each timeslot") 26 | parser.add_argument("-lo", "--load_bool", type=str2bool, default=False, help="bool for loading weights into CsiNet-LSTM network") 27 | parser.add_argument("-a", "--aux_bool", type=str2bool, default=True, help="bool for building CsiNet with auxiliary input") 28 | parser.add_argument("-m", "--aux_size", type=int, default=512, help="integer for auxiliary input's latent rate") 29 | parser.add_argument("-sr", "--stride", type=int, default=1, help="space between timeslots for each step (default 1); controls feedback interval") 30 | parser.add_argument("-v", "--viz_batch", type=int, default=-1, help="index of element to save for visualization") 31 | opt = parser.parse_args() 32 | 33 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"; # The GPU id to use, usually either "0" or "1"; 34 | os.environ["CUDA_VISIBLE_DEVICES"]="{}".format(opt.n_gpu); # Do other imports now... 
35 | print("debug_flag: {} -- train_argv: {}".format(opt.debug_flag, opt.train_argv)) 36 | 37 | if opt.env == "indoor": 38 | json_config = '../config/csinet_lstm_indoor_cost2100.json' # 0 epochs 39 | elif opt.env == "outdoor": 40 | json_config = '../config/csinet_lstm_outdoor_cost2100.json' # 0 epochs 41 | # quant_config = "../config/quant/10bits.json" 42 | 43 | M_1, data_format, network_name, subnetwork_name, model_dir, norm_range, minmax_file, share_bool, T, dataset_spec, diff_spec, batch_num, lr, batch_size, subsample_prop, thresh_idx_path = get_keys_from_json(json_config, keys=['M_1', 'df', 'network_name', 'subnetwork_name', 'model_dir', 'norm_range', 'minmax_file', 'share_bool', 'T', 'dataset_spec', 'diff_spec', 'batch_num', 'lr', 'batch_size', 'subsample_prop', 'thresh_idx_path']) 44 | aux_bool, quant_bool, LSTM_only_bool, pass_through_bool, t1_train, t2_train, lstm_latent_bool = get_keys_from_json(json_config, keys=['aux_bool', 'quant_bool', 'LSTM_only_bool', 'pass_through_bool', 't1_train', 't2_train', 'lstm_latent_bool'],is_bool=True) # import these as booleans rather than int, str 45 | 46 | import scipy.io as sio 47 | import numpy as np 48 | import math 49 | import time 50 | import sys 51 | # import os 52 | try: 53 | from tensorflow.keras.optimizers import Adam 54 | from tensorflow.keras.callbacks import TensorBoard, Callback, ModelCheckpoint, EarlyStopping 55 | except: 56 | import keras 57 | from keras.optimizers import Adam 58 | from keras.callbacks import TensorBoard, Callback, ModelCheckpoint 59 | from tensorflow.core.protobuf import rewriter_config_pb2 60 | from csinet_lstm import * 61 | # from QuantizeData import quantize, get_minmax 62 | 63 | def reset_keras(): 64 | sess = tf.keras.backend.get_session() 65 | tf.keras.backend.clear_session() 66 | sess.close() 67 | # limit gpu resource allocation 68 | try: 69 | config = tf.compat.v1.ConfigProto() 70 | except: 71 | config = tf.ConfigProto() 72 | # config.gpu_options.visible_device_list = '1' 73 | config.gpu_options.per_process_gpu_memory_fraction = 1.0 74 | 75 | # disable arithmetic optimizer 76 | off = rewriter_config_pb2.RewriterConfig.OFF 77 | config.graph_options.rewrite_options.arithmetic_optimization = off 78 | 79 | try: 80 | session = tf.compat.v1.Session(config=config) 81 | tf.compat.v1.keras.backend.set_session(session) 82 | except: 83 | session = tf.Session(config=config) 84 | keras.backend.set_session(session) 85 | # tf.global_variables_initializer() 86 | 87 | reset_keras() 88 | 89 | # fit params 90 | # image params 91 | img_height = 32 92 | img_width = 32 93 | img_channels = 2 94 | img_total = img_height*img_width*img_channels 95 | 96 | # Data loading 97 | batch_num = 1 if opt.debug_flag else batch_num # we'll use batch_num-1 for training and 1 for validation 98 | epochs = 1 if opt.debug_flag else opt.epochs 99 | 100 | # data_train, data_val, data_test = dataset_pipeline(batch_num, opt.debug_flag, aux_bool, dataset_spec, M_1, T = T, img_channels = img_channels, img_height = img_height, img_width = img_width, data_format = data_format, train_argv = opt.train_argv, merge_val_test = True) 101 | pow_diff, data_train, data_val = dataset_pipeline_col(opt.debug_flag, opt.aux_bool, dataset_spec, diff_spec, opt.aux_size, T = T, img_channels = img_channels, img_height = img_height, img_width = img_width, data_format = data_format, train_argv = opt.train_argv, subsample_prop=subsample_prop, thresh_idx_path=thresh_idx_path, stride=opt.stride) 102 | 103 | # tf Dataset object 104 | # SHUFFLE_BUFFER_SIZE = batch_size*5 105 | 106 
| # def data_generator(data): 107 | # i = 0 108 | # while i < data.shape[0]: 109 | # yield data[i,:,:,:,:], data[i,:,:,:,:] 110 | # i += 1 111 | 112 | # loading directly from unnormalized data; normalize data 113 | aux_val, x_val = data_val 114 | x_val = renorm_H4(x_val,minmax_file) 115 | data_val = aux_val, x_val 116 | print(f"-> pre reshape: x_val.shape: {x_val.shape}") 117 | # x_val = np.reshape(x_val, (x_val.shape[0]*x_val.shape[1], x_val.shape[2], x_val.shape[3], x_val.shape[4])) 118 | # print(f"-> post reshape: x_val.shape: {x_val.shape}") 119 | # aux_val = np.tile(aux_val, (T,1)) 120 | print(f"-> aux_val.shape: {aux_val.shape} - x_val.shape: {x_val.shape}") 121 | print('-> post-renorm: x_val range is from {} to {}'.format(np.min(x_val),np.max(x_val))) 122 | # val_gen = tf.data.Dataset.from_tensor_slices(({"input_1": aux_val, "input_2": x_val}, x_val)).batch(batch_size).repeat() 123 | # val_gen = tf.data.Dataset.from_generator(data_generator, args=[x_val], output_types=(tf.float32, tf.float32), output_shapes=((None,)+x_val.shape[1:], (None,)+x_val.shape[1:])).batch(batch_size).repeat() 124 | 125 | if opt.train_argv: 126 | aux_train, x_train = data_train 127 | x_train = renorm_H4(x_train,minmax_file) 128 | data_train = [aux_train, x_train] 129 | # print(f"pre reshape: x_train.shape: {x_train.shape}") 130 | # x_train = np.reshape(x_train, (x_train.shape[0]*x_train.shape[1], x_train.shape[2], x_train.shape[3], x_train.shape[4])) 131 | # print(f"post reshape: x_train.shape: {x_train.shape}") 132 | # aux_train = np.tile(aux_train, (T,1)) 133 | print(f"-> aux_train.shape: {aux_train.shape} - x_train.shape: {x_train.shape}") 134 | print('-> post-renorm: x_train range is from {} to {}'.format(np.min(x_train),np.max(x_train))) 135 | # train_gen = tf.data.Dataset.from_tensor_slices(({"input_1": aux_train, "input_2": x_train}, x_train)).shuffle(SHUFFLE_BUFFER_SIZE).batch(batch_size).repeat() 136 | # train_gen = tf.data.Dataset.from_generator(data_generator, args=[x_train], output_types=(tf.float32, tf.float32), output_shapes=((None,)+x_train.shape[1:], (None,)+x_train.shape[1:])).shuffle(SHUFFLE_BUFFER_SIZE).batch(batch_size).repeat() 137 | 138 | # train_gen = tf.data.Dataset.from_tensor_slices((data_train, x_train)).shuffle(SHUFFLE_BUFFER_SIZE).batch(batch_size).repeat() 139 | # val_gen = tf.data.Dataset.from_tensor_slices((data_val, x_val)).shuffle(SHUFFLE_BUFFER_SIZE).batch(batch_size).repeat() 140 | 141 | # if (quant_bool): 142 | # if (val_min == "range"): # hacky -- gets minmax from our file 143 | # print("--- Get quantization range from minmax file ---") 144 | # val_min, val_max = get_minmax([x_train,x_val,x_test]) # gets downlink by default 145 | # print("val_min: {} -- val_max: {}".format(val_min,val_max)) 146 | # x_quant_test = quantize(x_test,val_min=val_min,val_max=val_max,bits=bits) 147 | # print("NMSE for quantization error...") 148 | # if norm_range == "norm_H3": 149 | # x_quant_test_denorm = denorm_H3(x_quant_test,minmax_file) 150 | # x_test_denorm = denorm_H3(x_test,minmax_file) 151 | # if norm_range == "norm_H4": 152 | # x_quant_test_denorm = denorm_H4(x_quant_test,minmax_file) 153 | # x_test_denorm = denorm_H4(x_test,minmax_file) 154 | # print('-> x_quant range is from {} to {}'.format(np.min(x_quant_test_denorm),np.max(x_quant_test_denorm))) 155 | # print('-> x_test range is from {} to {} '.format(np.min(x_test_denorm),np.max(x_test_denorm))) 156 | # print('test: {}'.format(np.mean(np.sum(x_quant_test_denorm-x_test_denorm)))) 157 | # 
calc_NMSE(x_quant_test_denorm,x_test_denorm,T=T) 158 | 159 | M_1, M_2 = opt.aux_size, opt.rate 160 | reset_keras() 161 | try: 162 | optimizer = Adam(learning_rate=lr, beta_1=0.9, beta_2=0.999) 163 | except: 164 | optimizer = Adam(lr=lr, beta_1=0.9, beta_2=0.999) 165 | print('-------------------------------------') 166 | print("Build CsiNet-LSTM for CR2={}".format(M_2)) 167 | print('-------------------------------------') 168 | 169 | # def callbacks 170 | class LossHistory(Callback): 171 | def on_train_begin(self, logs={}): 172 | self.losses_train = [] 173 | self.losses_val = [] 174 | 175 | def on_batch_end(self, batch, logs={}): 176 | self.losses_train.append(logs.get('loss')) 177 | 178 | def on_epoch_end(self, epoch, logs={}): 179 | self.losses_val.append(logs.get('val_loss')) 180 | 181 | outpath_base = f"{model_dir}/{opt.env}" 182 | if opt.dir != None: 183 | outpath_base += "/" + opt.dir 184 | outfile_base = f"{outpath_base}/cr{opt.rate}/{network_name}" 185 | subnetwork_spec = [outpath_base, subnetwork_name] 186 | # if opt.load_bool: 187 | # CsiNet_LSTM_model = CsiNet_LSTM(img_channels, img_height, img_width, T, M_1, M_2, envir=opt.env, LSTM_depth=opt.depth, data_format=data_format, t1_trainable=t1_train, t2_trainable=t2_train, share_bool=share_bool, pass_through_bool=pass_through_bool, LSTM_only_bool=LSTM_only_bool, subnetwork_spec=subnetwork_spec, pretrained_bool=opt.pretrained_bool) 188 | # CsiNet_LSTM_model.load_weights(f"{outfile_base}.h5") 189 | # print ("--- Pre-loaded network performance is... ---") 190 | # x_hat = CsiNet_LSTM_model.predict(data_val) 191 | 192 | # print("For Adam with lr={:1.1e} // batch_size={} // norm_range={}".format(lr,batch_size,norm_range)) 193 | # print("x_hat.dtype: {}".format(x_hat.dtype)) # sanity check on output datatype 194 | # if norm_range == "norm_H3": 195 | # x_hat_denorm = denorm_H3(x_hat,minmax_file) 196 | # x_val_denorm = denorm_H3(x_val,minmax_file) 197 | # if norm_range == "norm_H4": 198 | # x_hat_denorm = denorm_H4(x_hat,minmax_file) 199 | # x_val_denorm = denorm_H4(x_val,minmax_file) 200 | # print('-> x_hat range is from {} to {}'.format(np.min(x_hat_denorm),np.max(x_hat_denorm))) 201 | # print('-> x_val range is from {} to {} '.format(np.min(x_val_denorm),np.max(x_val_denorm))) 202 | # calc_NMSE(x_hat_denorm,x_val_denorm,T=T) 203 | # else: 204 | CsiNet_LSTM_model = CsiNet_LSTM(img_channels, img_height, img_width, T, M_1, M_2, LSTM_depth=opt.depth, data_format=data_format, t1_trainable=t1_train, t2_trainable=t2_train, share_bool=share_bool, pass_through_bool=pass_through_bool, LSTM_only_bool=LSTM_only_bool, subnetwork_spec=subnetwork_spec, pretrained_bool=opt.pretrained_bool) 205 | if opt.load_bool: 206 | CsiNet_LSTM_model.load_weights(f"{outfile_base}.h5") 207 | CsiNet_LSTM_model.compile(optimizer=optimizer, loss='mse') 208 | 209 | if (opt.train_argv): 210 | # save+serialize model to JSON 211 | model_json = CsiNet_LSTM_model.to_json() 212 | # outfile = f"{model_dir}/{opt.dir}/{network_name}_{opt.env}.json" 213 | with open(f"{outfile_base}.json", "w") as json_file: 214 | json_file.write(model_json) 215 | # serialize weights to HDF5 216 | # outfile = f"{model_dir}/model_{file}.h5" 217 | checkpoint = ModelCheckpoint(f"{outfile_base}.h5", monitor="val_loss",verbose=1,save_best_only=True,mode="min") 218 | early = EarlyStopping(monitor="val_loss", patience=50,verbose=1) 219 | 220 | history = LossHistory() 221 | 222 | CsiNet_LSTM_model.fit(data_train, x_train, 223 | epochs=epochs, 224 | batch_size=batch_size, 225 | shuffle=True, 226 | 
validation_data=(data_val, x_val), 227 | callbacks=[checkpoint, 228 | # early, 229 | history]) 230 | # TensorBoard(log_dir = path), 231 | 232 | filename = f'{outfile_base}_trainloss.csv' 233 | loss_history = np.array(history.losses_train) 234 | np.savetxt(filename, loss_history, delimiter=",") 235 | 236 | filename = f'{outfile_base}_valloss.csv' 237 | loss_history = np.array(history.losses_val) 238 | np.savetxt(filename, loss_history, delimiter=",") 239 | 240 | # Testing data 241 | tStart = time.time() 242 | x_hat = CsiNet_LSTM_model.predict(data_val) 243 | tEnd = time.time() 244 | print("It cost %f sec per sample (%d samples)" % ((tEnd - tStart)/x_val.shape[0], x_val.shape[0])) 245 | 246 | print("For Adam with lr={:1.1e} // batch_size={} // norm_range={}".format(lr,batch_size,norm_range)) 247 | if norm_range == "norm_H3": 248 | x_hat_denorm = denorm_H3(x_hat,minmax_file) 249 | x_val_denorm = denorm_H3(x_val,minmax_file) 250 | elif norm_range == "norm_H4": 251 | x_hat_denorm = denorm_H4(x_hat,minmax_file) 252 | x_val_denorm = denorm_H4(x_val,minmax_file) 253 | print('-> x_hat range is from {} to {}'.format(np.min(x_hat_denorm),np.max(x_hat_denorm))) 254 | print('-> x_val range is from {} to {} '.format(np.min(x_val_denorm),np.max(x_val_denorm))) 255 | 256 | calc_NMSE(x_hat_denorm,x_val_denorm,T=T,pow_diff=pow_diff) 257 | 258 | if opt.viz_batch > -1 and not opt.train_argv: 259 | print(f"=== Saving input/output batch {opt.viz_batch} from validation set ===") 260 | # save input/output of validation batch for visualization 261 | viz_dict = { 262 | "input": x_val_denorm[opt.viz_batch, :, :, :, :], 263 | "output": x_hat_denorm[opt.viz_batch, :, :, :, :] 264 | } 265 | 266 | with open(f"{outfile_base}_batch{opt.viz_batch}.pkl", "wb") as f: 267 | pickle.dump(viz_dict, f) 268 | f.close() -------------------------------------------------------------------------------- /csinet-lstm/csinet_lstm_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | 3 | #SBATCH 4 | #SBATCH --mail-type=ALL 5 | #SBATCH --mail-user=mdelrosa@ucdavis.edu 6 | #SBATCH -p GPU-shared 7 | #SBATCH --job-name=out80ms 8 | #SBATCH --time=2-0 9 | #SBATCH --gres=gpu:4 10 | #SBATCH --mem=64000 # memory required by job 11 | 12 | module load cuda 13 | conda activate /ocean/projects/ecs190004p/mdelrosa/.conda/tf114 14 | 15 | python csinet_lstm_train.py -d 0 -g -1 -e outdoor -r 512 -l csinet80ms -ep 500 -sr 2 16 | 17 | #wait 18 | -------------------------------------------------------------------------------- /csinet-lstm/csinet_quant.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import tensorflow as tf 3 | import numpy as np 4 | from tensorflow.keras import Input 5 | from tensorflow.keras import backend as K 6 | from tensorflow.keras.layers import Layer, Dense, BatchNormalization, Reshape, Conv2D, add, LeakyReLU, Lambda, concatenate 7 | from tensorflow.keras.models import Model, load_model 8 | from tensorflow.keras.models import model_from_json 9 | from tensorflow.keras import regularizers, initializers, activations 10 | import os 11 | 12 | class CsiNet_quant(): 13 | """ 14 | Wrapper that breaks out encoder/decoder. Allows for insertion of quantization layer.
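
    Mu-law companding (see encoder_mu / mu_law_pipeline below): a codeword x is
    first scaled to [-1, 1] by its absolute max, companded via
        f(x) = sign(x) * ln(1 + mu*|x|) / ln(1 + mu),   mu = 255 by default,
    then multiplied by (dynamic_range - 1) and rounded to the nearest integer.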
15 | """ 16 | def __init__(self, encoded_dim, dynamic_range=32, img_channels=2, img_height=32, img_width=32, encoder_in=None, residual_num=2, aux_shape=512, encoded_in=None, data_format="channels_first",name=None,out_activation='tanh', code_min = -1, code_max = 1, side_min = -1, side_max = 1): 17 | self.dynamic_range = dynamic_range 18 | self.img_channels = img_channels 19 | self.img_height = img_height 20 | self.img_width = img_width 21 | self.img_total = img_channels*img_height*img_width 22 | self.encoded_dim = encoded_dim 23 | self.encoder_in = encoder_in 24 | self.residual_num = residual_num 25 | self.aux_shape = aux_shape 26 | self.aux = tf.keras.Input(aux_shape) 27 | self.encoded_in = encoded_in 28 | self.data_format = data_format 29 | self.name = name 30 | self.out_activation = out_activation 31 | self.use_bias = True # not sure on this one 32 | self.code_min = code_min 33 | self.code_max = code_max 34 | self.side_min = side_min 35 | self.side_max = side_max 36 | 37 | def add_common_layers(self, y, enc_bool=False): 38 | if enc_bool: 39 | y = BatchNormalization(name='CR2_batch_normalization')(y) 40 | y = LeakyReLU(name='CR2_leaky_re_lu')(y) 41 | else: 42 | y = BatchNormalization(axis=1)(y) 43 | y = LeakyReLU()(y) 44 | return y 45 | 46 | def residual_block_decoder(self, y): 47 | 48 | y = Conv2D(128, kernel_size=(1, 1), padding='same',data_format=self.data_format,name="deconv1", use_bias=self.use_bias)(y) 49 | y = self.add_common_layers(y) 50 | y = Conv2D(64, kernel_size=(1, 1), padding='same',data_format=self.data_format,name="deconv2", use_bias=self.use_bias)(y) 51 | y = self.add_common_layers(y) 52 | y = Conv2D(32, kernel_size=(3, 3), padding='same',data_format=self.data_format,name="deconv3", use_bias=self.use_bias)(y) 53 | y = self.add_common_layers(y) 54 | y = Conv2D(32, kernel_size=(3, 3), padding='same',data_format=self.data_format,name="deconv4", use_bias=self.use_bias)(y) 55 | y = self.add_common_layers(y) 56 | y = Conv2D(16, kernel_size=(3, 3), padding='same',data_format=self.data_format,name="deconv5", use_bias=self.use_bias)(y) 57 | y = self.add_common_layers(y) 58 | y = Conv2D(16, kernel_size=(3, 3), padding='same',data_format=self.data_format,name="deconv6", use_bias=self.use_bias)(y) 59 | y = self.add_common_layers(y) 60 | y = Conv2D(2, (3, 3), activation=self.out_activation, padding='same',data_format=self.data_format,name="predict", use_bias=self.use_bias)(y) 61 | return y 62 | 63 | # Bulid the autoencoder model of CsiNet 64 | def encoder_network(self, x): 65 | 66 | x = Conv2D(8, (3, 3), padding='same', data_format=self.data_format, name='CR2_conv2d_1', use_bias=self.use_bias)(x) 67 | x = self.add_common_layers(x) 68 | x = Conv2D(16, (3, 3), padding='same', data_format=self.data_format, name='CR2_conv2d_2', use_bias=self.use_bias)(x) 69 | x = self.add_common_layers(x) 70 | x = Conv2D(2, (3, 3), padding='same', data_format=self.data_format, name='CR2_conv2d_3', use_bias=self.use_bias)(x) 71 | x = self.add_common_layers(x) 72 | 73 | x = Reshape((self.img_total,), name='CR2_reshape')(x) 74 | print(f"--- In encoder_network: self.encoded_dim={self.encoded_dim} ---") 75 | encoded = Dense(self.encoded_dim, activation='linear', name='CR2_dense')(x) 76 | 77 | return encoded 78 | 79 | def encoder_quantizer(self, encoded, x_min, x_max): 80 | return None 81 | 82 | def decoder_network(self, encoded): 83 | x = Dense(self.img_total, activation='linear')(encoded) 84 | if(self.data_format == "channels_first"): 85 | x = Reshape((self.img_channels, self.img_height, self.img_width,))(x) 
86 | elif(self.data_format == "channels_last"): 87 | x = Reshape((self.img_height, self.img_width, self.img_channels,))(x) 88 | 89 | x = self.residual_block_decoder(x) 90 | 91 | return x 92 | 93 | # --- quantizer layers --- 94 | # --> these require extrema from pre-trained network <-- 95 | 96 | # helper function -- mu-law encoding 97 | def encoder_mu(self, x, encoded_dim, img_total = 2048, dynamic_range_i=32, mu = 255., side_bool = False): 98 | code_max = self.side_max if side_bool else self.code_max 99 | code_min = self.side_min if side_bool else self.code_min 100 | code_abs_max = np.amax(np.absolute([code_max, code_min])) 101 | encoded_quan_norm = Lambda(lambda x: x / code_abs_max)(x) 102 | encoded_quan_norm_u = Lambda(lambda x: K.sign(x) * K.log(1 + mu * K.abs(x)) / np.log(1 + mu))(encoded_quan_norm) 103 | encoded = Lambda(lambda x: x*(dynamic_range_i-1))(encoded_quan_norm_u) 104 | return encoded 105 | 106 | def decoder_mu(self, encoded, side_encoded = None, dynamic_range_i=32, mu = 255.): 107 | x = Lambda(lambda x: x / (dynamic_range_i-1))(encoded) 108 | x = Lambda(lambda x: K.sign(x) * (1 / mu) *(K.exp(K.abs(x)*np.log(1 + mu)) - 1))(x) 109 | code_abs_max = np.amax(np.absolute([self.code_max, self.code_min])) 110 | x = Lambda(lambda x: x * code_abs_max)(x) 111 | 112 | if type(side_encoded) != type(None): 113 | # decode t1 codeword if present 114 | y = Lambda(lambda x: x / (dynamic_range_i-1))(side_encoded) 115 | y = Lambda(lambda x: K.sign(x) * (1 / mu) *(K.exp(K.abs(x)*np.log(1 + mu)) - 1))(y) 116 | code_abs_max = np.amax(np.absolute([self.side_max, self.side_min])) 117 | y = Lambda(lambda x: x * code_abs_max)(y) 118 | x = concatenate([x,y]) 119 | 120 | return x 121 | 122 | def quantizer_o(self, encoded, encoded_dim, dynamic_range_i=32, kernel_regularizer = None, name_list=["bequantization", "dequantization", "quantization"]): 123 | hada1 = Hadamard(encoded_dim=encoded_dim, kernel_regularizer=kernel_regularizer, name=name_list[0]) # kernel_initializer=keras.initializers.Constant(dynamic_range), 124 | # hada1 = Hadamard(encoded_dim=encoded_dim, kernel_regularizer=quan_reg(quan_lam)) 125 | 126 | hada2 = Hadamard_div(encoded_dim=encoded_dim, name=name_list[1]) 127 | 128 | self.create_inversed_weights(hada1, hada2, (None,) + (encoded_dim,)) 129 | encoded_bp = hada1(encoded) 130 | encoded_quan = Roundings(dynamic_range_i=dynamic_range_i, name=name_list[2])(encoded_bp) 131 | encoded_dequan = hada2(encoded_quan) 132 | 133 | return encoded_dequan 134 | 135 | def create_inversed_weights(self, hada1, hada2, input_shape): 136 | with K.name_scope(hada1.name): 137 | hada1.build(input_shape) 138 | with K.name_scope(hada2.name): 139 | hada2.build(input_shape) 140 | # hada2.kernel = K.variable(1)/hada1.kernel 141 | hada2.kernel = hada1.kernel 142 | hada2._trainable_weights = [] 143 | hada2._trainable_weights.append(hada2.kernel) 144 | 145 | def build_full_network(self): 146 | if(self.data_format == "channels_last"): 147 | image_tensor = Input((self.img_height, self.img_width, self.img_channels)) 148 | elif(self.data_format == "channels_first"): 149 | image_tensor = Input((self.img_channels, self.img_height, self.img_width)) 150 | else: 151 | print("Unexpected tensor_shape param in CsiNet input.") 152 | # raise Exception 153 | encoded = self.encoder_network(image_tensor) 154 | 155 | # mu-law companding (no trainable scalars) 156 | companded = self.mu_law_pipeline(encoded, side_bool=False, dynamic_range_i=self.dynamic_range) 157 | side_companded = self.mu_law_pipeline(self.aux, side_bool=True, 
dynamic_range_i=self.dynamic_range, compander_name="side_companded") 158 | x = concatenate([side_companded,companded]) 159 | 160 | # mu-law companding (trainable layers) 161 | # x = self.encoder_mu(encoded, self.encoded_dim, side_bool = False) 162 | # mu_aux = self.encoder_mu(self.aux, self.aux_shape, side_bool = True) 163 | # x = self.quantizer_o(x, self.encoded_dim, dynamic_range_i=self.dynamic_range) 164 | # y = self.quantizer_o(mu_aux, self.aux_shape, dynamic_range_i=self.dynamic_range, name_list=["side_bequantization", "side_dequantization", "side_quantization"]) 165 | # x = self.decoder_mu(x, side_encoded=y) # decoder handles quantization of encoded tensor and aux encoded tensor 166 | 167 | # x = concatenate([self.aux, encoded]) 168 | network_output = self.decoder_network(x) 169 | 170 | tens_type = type(image_tensor) 171 | if type(self.aux) == tens_type: 172 | autoencoder = Model(inputs=[self.aux,image_tensor], outputs=[companded, encoded, network_output]) 173 | else: 174 | autoencoder = Model(inputs=[image_tensor], outputs=[companded, encoded, network_output]) 175 | # if self.encoder_in: 176 | # autoencoder.load_weights(by_name=True) 177 | self.autoencoder = autoencoder 178 | self.encoded = encoded 179 | 180 | def mu_law_pipeline(self, x, img_total = 2048, dynamic_range_i=32, mu = 255., side_bool = False, compander_name="companded"): 181 | code_max = self.side_max if side_bool else self.code_max 182 | code_min = self.side_min if side_bool else self.code_min 183 | code_abs_max = np.amax(np.absolute([code_max, code_min])) 184 | encoded_quan_norm = Lambda(lambda x: x / code_abs_max)(x) 185 | encoded_quan_norm_u = Lambda(lambda x: K.sign(x) * K.log(1 + mu * K.abs(x)) / np.log(1 + mu))(encoded_quan_norm) 186 | encoded = Lambda(lambda x: tf.math.round(x*(dynamic_range_i-1)))(encoded_quan_norm_u) 187 | print(encoded) 188 | x = Lambda(lambda x: x / (dynamic_range_i-1))(encoded) 189 | x = Lambda(lambda x: K.sign(x) * (1 / mu) *(K.exp(K.abs(x)*np.log(1 + mu)) - 1))(x) 190 | x = Lambda(lambda x: x * code_abs_max, name = compander_name)(x) 191 | return x 192 | 193 | def load_template_weights(self, template_model): 194 | """ 195 | load weights from pretrained unquantized model into model with quantizer 196 | see here: https://stackoverflow.com/a/43702449 197 | """ 198 | 199 | # ideally something like this would work... 
200 | # temp_name = 'temp.h5' 201 | # template_model.save_weights(temp_name) 202 | # self.autoencoder.load_weights(temp_name, by_name=True) 203 | 204 | # but this might have to be done :facepalm: 205 | # magic numbers - models are the same up to enc_limit 206 | enc_limit = 12 207 | # dec_offset = 11 # this worked for single mu law quantizer/dequantizer 208 | dec_offset = 18 209 | temp_idx = 0 210 | quant_idx = 0 211 | template_weights = template_model.get_weights() 212 | print("--- Loading {} weights into quantized model ---".format(len(template_weights))) 213 | p_bar = tqdm(total = len(template_weights)) 214 | while (temp_idx < len(template_weights)): 215 | quant_layer_weights = self.autoencoder.layers[quant_idx].weights 216 | quant_num = len(quant_layer_weights) 217 | layer_weights = [] 218 | if (quant_num == 0): 219 | quant_idx += 1 220 | else: 221 | for i in range(quant_num): 222 | layer_weights.append(template_weights[temp_idx+i]) 223 | self.autoencoder.layers[quant_idx].set_weights(layer_weights) 224 | temp_idx += quant_num 225 | p_bar.update(quant_num) 226 | quant_idx += 1 227 | if quant_idx == enc_limit: 228 | quant_idx += dec_offset-1 # skip ahead by num layers in quant layers 229 | layer_weights = [] 230 | p_bar.close() 231 | print("--- Finished loading ---") 232 | 233 | def load_template_weights_v2(self, template_model): 234 | """ 235 | load weights from pretrained unquantized model into model with quantizer 236 | see here: https://stackoverflow.com/a/43702449 237 | """ 238 | 239 | enc_limit = 20 240 | dec_offset = 0 241 | template_weights = template_model.get_weights() 242 | quant_slice = self.autoencoder.get_weights()[enc_limit:enc_limit+dec_offset] 243 | template_weights_new = template_weights[:enc_limit] + quant_slice + template_weights[enc_limit:] 244 | self.autoencoder.set_weights(np.array(template_weights_new)) 245 | for i, temp_weights in tqdm(enumerate(template_weights_new)): 246 | for quant, temp in zip(self.autoencoder.get_weights()[i], temp_weights): 247 | foo = quant == temp 248 | assert(foo.all()) 249 | 250 | def load_template_weights_v3(self, template_model): 251 | """ 252 | load weights from pretrained unquantized model into model with quantizer 253 | see here: https://stackoverflow.com/a/43702449 254 | """ 255 | enc_limit = 12 256 | # dec_offset = 11 # this worked for single mu law quantizer/dequantizer 257 | dec_offset = 18 258 | temp_idx = 0 259 | quant_idx = 0 260 | template_weights = template_model.get_weights() 261 | print("--- Loading {} weights into quantized model ---".format(len(template_weights))) 262 | p_bar = tqdm(total = len(template_weights)) 263 | while (temp_idx < len(template_weights)): 264 | quant_layer_weights = self.autoencoder.layers[quant_idx].weights 265 | quant_num = len(quant_layer_weights) 266 | layer_weights = [] 267 | if (quant_num == 0): 268 | quant_idx += 1 269 | else: 270 | for i in range(quant_num): 271 | layer_weights.append(template_weights[temp_idx+i]) 272 | self.autoencoder.layers[quant_idx].set_weights(layer_weights) 273 | temp_idx += quant_num 274 | p_bar.update(quant_num) 275 | quant_idx += 1 276 | if quant_idx == enc_limit: 277 | quant_idx += dec_offset-1 # skip ahead by num layers in quant layers 278 | layer_weights = [] 279 | p_bar.close() 280 | print("--- Finished loading ---") 281 | 282 | class Roundings(Layer): 283 | 284 | def __init__(self, dynamic_range_i=16, **kwargs): 285 | super(Roundings, self).__init__(**kwargs) 286 | self.supports_masking = True 287 | self.dynamic_range = dynamic_range_i 288 | 289 | def 
sum_sigmoid(self, x): 290 | dynamic_range = self.dynamic_range 291 | r = 100. 292 | i = tf.constant(-dynamic_range + 0.5, dtype=tf.float32) 293 | j = tf.zeros(shape=tf.shape(x), dtype=tf.float32) 294 | [_, approx_round] = tf.while_loop(lambda i, j: tf.less(i, dynamic_range - 0.5), 295 | lambda i, j: [tf.add(i, 1), tf.add(j, K.sigmoid(r * (x - i)))], [i, j]) 296 | return tf.add(approx_round, -dynamic_range) 297 | 298 | def call(self, inputs, **kwargs): 299 | return self.sum_sigmoid(inputs) 300 | 301 | def get_config(self): 302 | config = {'dynamic_range': int(self.dynamic_range)} 303 | base_config = super(Roundings, self).get_config() 304 | return dict(list(base_config.items()) + list(config.items())) 305 | 306 | def compute_output_shape(self, input_shape): 307 | return input_shape 308 | 309 | class Hadamard_div(Layer): 310 | 311 | def __init__(self, kernel_initializer='ones', \ 312 | kernel_regularizer=None, activation='linear', encoded_dim = 32, trainable = True, **kwargs): 313 | super(Hadamard_div, self).__init__(**kwargs) 314 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 315 | self.activation = activations.get(activation) 316 | self.kernel_initializer = initializers.get(kernel_initializer) 317 | self.encoded_dim = encoded_dim 318 | 319 | def build(self, input_shape): 320 | # Create a trainable weight variable for this layer. 321 | self.kernel = self.add_weight(name='kernel', 322 | shape=(self.encoded_dim,), 323 | initializer=self.kernel_initializer, 324 | regularizer=self.kernel_regularizer, 325 | trainable=True) 326 | super(Hadamard_div, self).build(input_shape) # Be sure to call this somewhere! 327 | 328 | def call(self, x): 329 | print(x.shape, self.kernel.shape) 330 | outputs = x / self.kernel 331 | if self.activation is not None: 332 | return self.activation(outputs) 333 | return outputs 334 | 335 | def compute_output_shape(self, input_shape): 336 | # print(input_shape) 337 | return input_shape 338 | 339 | class Hadamard(Layer): 340 | 341 | def __init__(self, kernel_initializer='ones', \ 342 | kernel_regularizer=None, encoded_dim = 32, trainable = True, **kwargs): 343 | super(Hadamard, self).__init__(**kwargs) 344 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 345 | self.kernel_initializer = initializers.get(kernel_initializer) 346 | self.encoded_dim = encoded_dim 347 | 348 | def build(self, input_shape): 349 | # Create a trainable weight variable for this layer. 350 | self.kernel = self.add_weight(name='kernel', 351 | shape=(self.encoded_dim,), 352 | initializer=self.kernel_initializer, 353 | regularizer=self.kernel_regularizer, 354 | trainable=True) 355 | super(Hadamard, self).build(input_shape) # Be sure to call this somewhere! 
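    # Note on usage: CsiNet_quant.quantizer_o composes these layers as
    #   encoded -> Hadamard (elementwise scale by a trainable kernel)
    #           -> Roundings (differentiable rounding via a sum of shifted sigmoids)
    #           -> Hadamard_div (divide by the same kernel, tied in create_inversed_weights),
    # so the kernel effectively learns a per-element quantization step size.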
356 | 357 | def call(self, x): 358 | print(x.shape, self.kernel.shape) 359 | outputs = x * self.kernel 360 | return outputs 361 | 362 | def compute_output_shape(self, input_shape): 363 | return input_shape -------------------------------------------------------------------------------- /csinet-lstm/csinet_resid.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import concatenate, Dense, BatchNormalization, Reshape, add, LeakyReLU 3 | from tensorflow.keras import Input 4 | from tensorflow.keras.models import Model 5 | import numpy as np 6 | 7 | Conv2D = tf.keras.layers.Conv2D 8 | 9 | class CosineSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): 10 | def __init__(self, warmup_steps=200, epochs=600, max_lr=1e-3, min_lr=1e-4): 11 | super(CosineSchedule, self).__init__() 12 | 13 | self.warmup_steps = tf.cast(warmup_steps, tf.float32) 14 | self.epochs = tf.cast(epochs, tf.float32) 15 | self.max_lr = tf.cast(max_lr, tf.float32) 16 | self.min_lr = tf.cast(min_lr, tf.float32) 17 | self.diff_lr = max_lr - min_lr 18 | 19 | def warmup_rate(self): 20 | return self.diff_lr * self.step / self.warmup_steps + self.min_lr 21 | 22 | def cosine_rate(self): 23 | return self.diff_lr * ((tf.math.cos((self.step - self.warmup_steps) * np.pi / (self.epochs - self.warmup_steps)) + 1) / 2) + self.min_lr 24 | 25 | def get_config(self): 26 | config = { 27 | # 'epochs': self.epochs, 28 | # 'warmup_steps': self.warmup_steps, 29 | # 'max_lr': self.max_lr, 30 | # 'min_lr': self.min_lr, 31 | # 'diff_lr': self.diff_lr, 32 | # 'rate': self.rate 33 | } 34 | return config 35 | 36 | 37 | def __call__(self, step): 38 | self.step = step 39 | rate = tf.cond(step < self.warmup_steps, self.warmup_rate, self.cosine_rate) 40 | self.rate = rate 41 | return rate 42 | 43 | def CsiNet(img_channels, img_height, img_width, encoded_dim, encoder_in=None, residual_num=2, aux=None, encoded_in=None, data_format="channels_last",name=None,out_activation='sigmoid'): 44 | img_total = img_channels*img_height*img_width # total number of real-valued entries per CSI sample 45 | 46 | # Build the autoencoder model of CsiNet 47 | def residual_network(x, residual_num, encoded_dim, aux): 48 | def add_common_layers(y,enc_bool=False): 49 | if enc_bool: 50 | y = BatchNormalization(name='CR2_batch_normalization')(y) 51 | y = LeakyReLU(name='CR2_leaky_re_lu')(y) 52 | else: 53 | y = BatchNormalization()(y) 54 | y = LeakyReLU()(y) 55 | return y 56 | def residual_block_decoded(y): 57 | shortcut = y 58 | 59 | # according to CsiNet-LSTM paper Fig. 
1, residual network has 2-filter conv2D layers before other conv2D layers 60 | y = Conv2D(2, kernel_size=(3, 3), padding='same', data_format=data_format)(y) 61 | y = add_common_layers(y) 62 | 63 | y = Conv2D(8, kernel_size=(3, 3), padding='same', data_format=data_format)(y) 64 | y = add_common_layers(y) 65 | 66 | y = Conv2D(16, kernel_size=(3, 3), padding='same', data_format=data_format)(y) 67 | y = add_common_layers(y) 68 | 69 | y = Conv2D(2, kernel_size=(3, 3), padding='same', data_format=data_format)(y) 70 | y = BatchNormalization()(y) 71 | 72 | y = add([shortcut, y]) 73 | y = LeakyReLU()(y) 74 | 75 | return y 76 | 77 | # if encoder_in: 78 | x = Conv2D(2, (3, 3), padding='same', data_format=data_format, name='CR2_conv2d')(x) 79 | x = add_common_layers(x,enc_bool=True) 80 | 81 | x = Reshape((img_total,), name='CR2_reshape')(x) 82 | encoded = Dense(encoded_dim, activation='linear', name='CR2_dense')(x) 83 | 84 | print("Aux check: {}".format(aux)) 85 | tens_type = type(x) 86 | if type(aux) == tens_type: 87 | x = Dense(img_total, activation='linear')(concatenate([aux,encoded])) 88 | else: 89 | x = Dense(img_total, activation='linear')(encoded) 90 | # reshape based on data_format 91 | if(data_format == "channels_first"): 92 | x = Reshape((img_channels, img_height, img_width,))(x) 93 | elif(data_format == "channels_last"): 94 | x = Reshape((img_height, img_width, img_channels,))(x) 95 | 96 | for i in range(residual_num): 97 | x = residual_block_decoded(x) 98 | x = Conv2D(2, (3, 3), activation=out_activation, padding='same', data_format=data_format)(x) 99 | 100 | return [x, encoded] 101 | 102 | if(data_format == "channels_last"): 103 | image_tensor = Input((img_height, img_width, img_channels)) 104 | elif(data_format == "channels_first"): 105 | image_tensor = Input((img_channels, img_height, img_width)) 106 | else: 107 | print("Unexpected tensor_shape param in CsiNet input.") 108 | # raise Exception 109 | [network_output, encoded] = residual_network(image_tensor, residual_num, encoded_dim, aux) 110 | print('network_output: {} - encoded: {} - aux: {}'.format(network_output, encoded, aux)) 111 | tens_type = type(image_tensor) 112 | print('image_tensor.dtype: {}'.format(tens_type)) 113 | print('type(aux): {}'.format(type(aux))) 114 | if type(aux) == tens_type: 115 | autoencoder = Model(inputs=[aux,image_tensor], outputs=[network_output,encoded]) 116 | else: 117 | autoencoder = Model(inputs=[image_tensor], outputs=[network_output, encoded]) 118 | if encoder_in: 119 | autoencoder.load_weights(by_name=True) 120 | return [autoencoder, encoded] -------------------------------------------------------------------------------- /csinet-lstm/csinet_train.py: -------------------------------------------------------------------------------- 1 | 2 | if __name__ == "__main__": 3 | import argparse 4 | import pickle 5 | import os 6 | import copy 7 | import sys 8 | sys.path.append("/jet/home/mdelrosa/git/brat") 9 | from utils.NMSE_performance import calc_NMSE, get_NMSE, denorm_H3, renorm_H4, denorm_H4, renorm_tanh, denorm_tanh 10 | from utils.data_tools import dataset_pipeline_col, dataset_pipeline_complex, subsample_batches, load_pow_diff 11 | from utils.parsing import str2bool 12 | from utils.timing import Timer 13 | from utils.unpack_json import get_keys_from_json 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("-d", "--debug_flag", type=int, default=0, help="flag for toggling debugging mode") 16 | parser.add_argument("-l", "--dir", type=str, default=None, help="subdirectory for saving model, 
checkpoint, history") 17 | parser.add_argument("-e", "--env", type=str, default="indoor", help="environment (either indoor or outdoor)") 18 | parser.add_argument("-ep", "--epochs", type=int, default=10, help="number of epochs to train for") 19 | parser.add_argument("-tr", "--train_argv", type=str2bool, default=True, help="flag for toggling training") 20 | parser.add_argument("-g", "--n_gpu", type=int, default=1, help="index of gpu for training") 21 | parser.add_argument("-r", "--rate", type=int, default=512, help="number of elements in latent code (i.e., encoding rate)") 22 | parser.add_argument("-lo", "--load_bool", type=str2bool, default=False, help="bool for loading weights into CsiNet") 23 | parser.add_argument("-a", "--aux_bool", type=str2bool, default=True, help="bool for building CsiNet with auxiliary input") 24 | parser.add_argument("-m", "--aux_size", type=int, default=512, help="integer for auxiliary input's latent rate") 25 | opt = parser.parse_args() 26 | 27 | if opt.env == "outdoor": 28 | # json_config = '../config/csinet_outdoor_cost2100_pow.json' 29 | json_config = '../config/csinet_outdoor_cost2100_tanh.json' 30 | # json_config = '../config/csinet_outdoor_cost2100_pow_subsample.json' 31 | elif opt.env == "indoor": 32 | json_config = '../config/csinet_indoor_cost2100_pow.json' 33 | # json_config = '../config/csinet_indoor_cost2100_tanh.json' 34 | # json_config = '../config/csinet_indoor_cost2100_pow_subsample.json' 35 | # json_config = '../config/csinet_indoor_cost2100_old.json' # requires dataset_pipeline_complex 36 | 37 | model_dir, norm_range, minmax_file, dataset_spec, diff_spec, batch_num, lr, batch_size, network_name, T, data_format, subsample_prop, thresh_idx_path = get_keys_from_json(json_config, keys=['model_dir','norm_range','minmax_file','dataset_spec', 'diff_spec', 'batch_num', 'learning_rate', 'batch_size', 'network_name', 'T', 'df', 'subsample_prop', 'thresh_idx_path']) 38 | # lr = lrs[0] 39 | # batch_size = batch_sizes[0] 40 | 41 | # encoded_dims, dates, result_dir, aux_bool, opt.rate, data_format, epochs, t1_train, t2_train, gpu_num, lstm_latent_bool, conv_lstm_bool = unpack_json(json_config) 42 | 43 | import os 44 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"; # The GPU id to use, usually either "0" or "1"; 45 | os.environ["CUDA_VISIBLE_DEVICES"]="{}".format(opt.n_gpu); # Do other imports now... 46 | 47 | import tensorflow as tf 48 | from tensorflow.keras.optimizers import Adam 49 | from tensorflow.keras import Input 50 | from tensorflow.keras.models import Model 51 | from tensorflow.keras.callbacks import TensorBoard, Callback, ModelCheckpoint, EarlyStopping 52 | import scipy.io as sio 53 | import numpy as np 54 | import math 55 | import time 56 | import sys 57 | 58 | from csinet import * 59 | from tensorflow.core.protobuf import rewriter_config_pb2 60 | # from NMSE_performance import calc_NMSE, denorm_H3, denorm_H4 61 | 62 | # norm_range = get_norm_range(json_config) 63 | 64 | def reset_keras(): 65 | sess = tf.keras.backend.get_session() 66 | tf.keras.backend.clear_session() 67 | sess.close() 68 | # limit gpu resource allocation 69 | config = tf.compat.v1.ConfigProto() 70 | # config.gpu_options.visible_device_list = '1' 71 | config.gpu_options.per_process_gpu_memory_fraction = 1.0 72 | # physical_devices = tf.config.list_physical_devices('GPU') 73 | # try: 74 | # tf.config.experimental.set_memory_growth(physical_devices[0], True) 75 | # except: 76 | # # Invalid device or cannot modify virtual devices once initialized. 
77 | # print("Cannot access 'set_memory_growth' - skipping.") 78 | # pass 79 | 80 | # disable arithmetic optimizer 81 | off = rewriter_config_pb2.RewriterConfig.OFF 82 | config.graph_options.rewrite_options.arithmetic_optimization = off 83 | 84 | session = tf.compat.v1.Session(config=config) 85 | tf.compat.v1.keras.backend.set_session(session) 86 | 87 | reset_keras() 88 | 89 | # config = tf.ConfigProto() 90 | # config.gpu_options.per_process_gpu_memory_fraction = 1.0 91 | # session = tf.Session(config=config) 92 | 93 | # image params 94 | img_height = 32 95 | img_width = 32 96 | img_channels = 2 97 | img_total = img_height*img_width*img_channels 98 | # network params 99 | residual_num = 2 100 | # encoded_dim = 512 #compress rate=1/4->dim.=512, compress rate=1/16->dim.=128, compress rate=1/32->dim.=64, compress rate=1/64->dim.=32 101 | 102 | epochs = 1 if opt.debug_flag else opt.epochs 103 | batch_num = 1 if opt.debug_flag else batch_num 104 | 105 | pow_diff, data_train, data_val = dataset_pipeline_col(opt.debug_flag, opt.aux_bool, dataset_spec, diff_spec, opt.aux_size, T = T, img_channels = img_channels, img_height = img_height, img_width = img_width, data_format = data_format, train_argv = opt.train_argv, subsample_prop=subsample_prop, thresh_idx_path=thresh_idx_path) 106 | # data_train, data_val = dataset_pipeline_complex(opt.debug_flag, opt.aux_bool, dataset_spec, diff_spec, opt.aux_size, T = T, img_channels = img_channels, img_height = img_height, img_width = img_width, data_format = data_format, train_argv = opt.train_argv, subsample_prop=subsample_prop) 107 | 108 | # print(f"-> pre reshape: pow_diff.shape: {pow_diff.shape}") 109 | # pow_diff = np.reshape(np.real(pow_diff), (pow_diff.shape[0]*pow_diff.shape[1], -1)) 110 | # print(f"-> post reshape: pow_diff.shape: {pow_diff.shape}") 111 | 112 | # SHUFFLE_BUFFER_SIZE = batch_size*5 113 | 114 | # loading directly from unnormalized data; normalize data 115 | aux_val, x_val = data_val 116 | print('-> pre-renorm: x_val range is from {} to {}'.format(np.min(x_val),np.max(x_val))) 117 | if norm_range == "norm_H4": 118 | x_val = renorm_H4(x_val,minmax_file) 119 | elif norm_range == "tanh": 120 | x_val = renorm_tanh(x_val,minmax_file) 121 | data_val = aux_val, x_val 122 | print(f"-> pre reshape: x_val.shape: {x_val.shape}") 123 | # x_val = np.reshape(x_val, (x_val.shape[0]*x_val.shape[1], x_val.shape[2], x_val.shape[3], x_val.shape[4])) 124 | # print(f"-> post reshape: x_val.shape: {x_val.shape}") 125 | # aux_val = np.tile(aux_val, (T,1)) 126 | print(f"-> aux_val.shape: {aux_val.shape} - x_val.shape: {x_val.shape}") 127 | print('-> post-renorm: x_val range is from {} to {}'.format(np.min(x_val),np.max(x_val))) 128 | # val_gen = tf.data.Dataset.from_tensor_slices(({"input_1": aux_val, "input_2": x_val}, x_val)).batch(batch_size).repeat() 129 | 130 | if opt.train_argv: 131 | aux_train, x_train = data_train 132 | if norm_range == "norm_H4": 133 | x_train = renorm_H4(x_train,minmax_file) 134 | elif norm_range == "tanh": 135 | x_train = renorm_tanh(x_train,minmax_file) 136 | 137 | data_train = aux_train, x_train 138 | # print(f"pre reshape: x_train.shape: {x_train.shape}") 139 | # x_train = np.reshape(x_train, (x_train.shape[0]*x_train.shape[1], x_train.shape[2], x_train.shape[3], x_train.shape[4])) 140 | # print(f"post reshape: x_train.shape: {x_train.shape}") 141 | # aux_train = np.tile(aux_train, (T,1)) 142 | print(f"-> aux_train.shape: {aux_train.shape} - x_train.shape: {x_train.shape}") 143 | print('-> post-renorm: x_train range is from {} to 
{}'.format(np.min(x_train),np.max(x_train))) 144 | # train_gen = tf.data.Dataset.from_tensor_slices(({"input_1": aux_train, "input_2": x_train}, x_train)).shuffle(SHUFFLE_BUFFER_SIZE).batch(batch_size).repeat() 145 | 146 | # opt.rates = [512, 128, 64, 32] 147 | print('Build and train CsiNet for rate={}'.format(opt.rate)) 148 | # reset_keras() 149 | optimizer = Adam(learning_rate=lr) 150 | if opt.aux_bool: 151 | aux = Input((opt.aux_size,)) 152 | else: 153 | aux = None 154 | 155 | # build CsiNet 156 | outpath_base = f"{model_dir}/{opt.env}" 157 | if opt.dir != None: 158 | outpath_base += "/" + opt.dir 159 | outfile_base = f"{outpath_base}/cr{opt.rate}/{network_name}" 160 | # file = 'CsiNet_'+(envir)+'_dim'+str(opt.opt.rate)+'_{}'.format(date) 161 | 162 | out_activation = 'tanh' 163 | autoencoder, encoded = CsiNet(img_channels, img_height, img_width, opt.rate, aux=aux, data_format=data_format, out_activation=out_activation) # CSINet with opt.rate dimensional latent space 164 | autoencoder = Model(inputs=autoencoder.inputs,outputs=autoencoder.outputs[0]) 165 | 166 | if opt.load_bool: 167 | # outfile = "{}/model_{}.h5".format(model_dir,file) 168 | autoencoder.load_weights(f"{outfile_base}.h5") 169 | print ("--- Pre-loaded network performance is... ---") 170 | x_hat = autoencoder.predict(data_val) 171 | 172 | print("For Adam with lr={:1.1e} // batch_size={} // norm_range={}".format(lr,batch_size,norm_range)) 173 | if norm_range == "norm_H3": 174 | x_hat_denorm = denorm_H3(x_hat,minmax_file) 175 | x_val_denorm = denorm_H3(x_val,minmax_file) 176 | elif norm_range == "norm_H4": 177 | x_hat_denorm = denorm_H4(x_hat,minmax_file) 178 | x_val_denorm = denorm_H4(x_val,minmax_file) 179 | elif norm_range == "tanh": 180 | x_hat_denorm = denorm_tanh(x_hat,minmax_file) 181 | x_val_denorm = denorm_tanh(x_val,minmax_file) 182 | print('-> x_hat range is from {} to {}'.format(np.min(x_hat_denorm),np.max(x_hat_denorm))) 183 | print('-> x_val range is from {} to {} '.format(np.min(x_val_denorm),np.max(x_val_denorm))) 184 | calc_NMSE(x_hat_denorm,x_val_denorm,T=T) 185 | else: 186 | model_json = autoencoder.to_json() 187 | # outfile = "{}/model_{}.json".format(result_dir,file) 188 | with open(f"{outfile_base}.json", "w") as json_file: 189 | json_file.write(model_json) 190 | 191 | autoencoder.compile(optimizer=optimizer, loss='mse') 192 | print(autoencoder.summary()) 193 | 194 | if opt.train_argv: 195 | class LossHistory(Callback): 196 | def on_train_begin(self, logs={}): 197 | self.losses_train = [] 198 | self.losses_val = [] 199 | 200 | def on_batch_end(self, batch, logs={}): 201 | self.losses_train.append(logs.get('loss')) 202 | 203 | def on_epoch_end(self, epoch, logs={}): 204 | self.losses_val.append(logs.get('val_loss')) 205 | 206 | history = LossHistory() 207 | 208 | # early stopping callback 209 | es = EarlyStopping(monitor='val_loss',mode='min',patience=20,verbose=1) 210 | 211 | # path = f'{outfile_base}_tensorboard' 212 | 213 | # save+serialize model to JSON 214 | # model_json = autoencoder.to_json() 215 | # outfile = "{}/model_{}.json".format(result_dir,file) 216 | # with open(outfile, "w") as json_file: 217 | # json_file.write(model_json) 218 | # serialize weights to HDF5 219 | # outfile = "{}/model_{}.h5".format(result_dir,file) 220 | # autoencoder.save_weights(outfile) 221 | 222 | outfile = f"{outfile_base}.h5" 223 | checkpoint = ModelCheckpoint(outfile, monitor="val_loss",verbose=1,save_best_only=True,mode="min") 224 | 225 | steps_per_epoch = x_train.shape[0] // batch_size 226 | val_steps = 
x_val.shape[0] // batch_size 227 | 228 | autoencoder.fit( 229 | # train_gen, 230 | data_train, 231 | x_train, 232 | epochs=epochs, 233 | # steps_per_epoch=steps_per_epoch, 234 | batch_size=batch_size, 235 | shuffle=True, 236 | validation_data=(data_val, x_val), 237 | # validation_data=val_gen, 238 | # validation_steps=val_steps, 239 | callbacks=[history, checkpoint] 240 | ) 241 | # TensorBoard(log_dir = path)]) 242 | 243 | # filename = f'{model_dir}/{opt.env}/{opt.dir}/{network_name}_trainloss.csv' 244 | # loss_history = np.array(history.losses_train) 245 | # np.savetxt(filename, loss_history, delimiter=",") 246 | 247 | # filename = f'{model_dir}/{opt.env}/{opt.dir}/{network_name}_valloss.csv' 248 | # loss_history = np.array(history.losses_val) 249 | # np.savetxt(filename, loss_history, delimiter=",") 250 | 251 | #Testing data 252 | weights_file = f"{outfile_base}.h5" 253 | print(f"--- Loading weights from {weights_file} ---") 254 | autoencoder.load_weights(weights_file) 255 | autoencoder.trainable = False 256 | # tStart = time.time() 257 | if opt.aux_bool == 1: 258 | x_hat = autoencoder.predict([aux_val, x_val]) 259 | else: 260 | x_hat = autoencoder.predict(x_val) 261 | # tEnd = time.time() 262 | # print ("It cost %f sec" % ((tEnd - tStart)/x_val.shape[0])) 263 | print(64*'=') 264 | print("For CR2={} // Adam with lr={:1.1e} // batch_size={} // norm_range={}".format(opt.rate,lr,batch_size,norm_range)) 265 | print('-> pre-denorm: x_hat range is from {} to {}'.format(np.min(x_hat),np.max(x_hat))) 266 | print('-> pre-denorm: x_val range is from {} to {} '.format(np.min(x_val),np.max(x_val))) 267 | if norm_range == "norm_H3": 268 | x_hat_denorm = denorm_H3(x_hat,minmax_file) 269 | x_val_denorm = denorm_H3(x_val,minmax_file) 270 | elif norm_range == "norm_H4": 271 | x_hat_denorm = denorm_H4(x_hat,minmax_file) 272 | x_val_denorm = denorm_H4(x_val,minmax_file) 273 | elif norm_range == "tanh": 274 | x_hat_denorm = denorm_tanh(x_hat,minmax_file) 275 | x_val_denorm = denorm_tanh(x_val,minmax_file) 276 | print('-> post-denorm: x_hat range is from {} to {}'.format(np.min(x_hat_denorm),np.max(x_hat_denorm))) 277 | print('-> post-denorm: x_val range is from {} to {} '.format(np.min(x_val_denorm),np.max(x_val_denorm))) 278 | 279 | # new method (borrowed from PyTorch impl, trace-based) 280 | x_hat_denorm = x_hat_denorm[:,0,:,:] + 1j*x_hat_denorm[:,1,:,:] 281 | x_val_denorm = x_val_denorm[:,0,:,:] + 1j*x_val_denorm[:,1,:,:] 282 | x_shape = x_val_denorm.shape 283 | mse, nmse = get_NMSE(x_hat_denorm, x_val_denorm, return_mse=True, n_ang=x_shape[1], n_del=x_shape[2]) 284 | results = {} 285 | print(f"-> Truncated NMSE = {nmse:5.3f} | MSE = {mse:.4E}") 286 | results["best_nmse"] = nmse 287 | results["best_mse"] = mse 288 | if len(diff_spec) != 0: 289 | # pow_diff = load_pow_diff(diff_spec) 290 | mse, nmse = get_NMSE(x_hat_denorm, x_val_denorm, return_mse=True, n_ang=x_shape[1], n_del=x_shape[2], pow_diff_timeslot=pow_diff) 291 | print(f"-> Full NMSE = {nmse:5.3f} | MSE = {mse:.4E}") 292 | results["best_nmse_full"] = nmse 293 | results["best_mse_full"] = mse 294 | 295 | # original method (tensorflow, CsiNet-LSTM, sum of squared errors) 296 | # if len(diff_spec) != 0: 297 | # pow_diff = load_pow_diff(diff_spec) 298 | # results = calc_NMSE(x_hat_denorm, x_val_denorm, T=1, diff_test=pow_diff) 299 | 300 | print(64*'=') 301 | 302 | # dump nmse results to pickle file 303 | if opt.train_argv: 304 | with open(f"{outfile_base}_results.pkl", "wb") as f: 305 | pickle.dump(results, f) 306 | f.close() 307 | 
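
# For reference, a minimal numpy sketch of the NMSE reported above. This is an
# assumption of what brat's get_NMSE reduces to for T=1; the actual helper also
# handles delay-domain truncation (n_ang/n_del) and the power-difference offset
# (pow_diff_timeslot), which this sketch omits.
def _nmse_db_sketch(x_hat, x):
    import numpy as np  # local import keeps the sketch self-contained
    err = np.sum(np.abs(x - x_hat) ** 2, axis=(1, 2))  # per-sample error energy
    sig = np.sum(np.abs(x) ** 2, axis=(1, 2))          # per-sample signal energy
    return 10 * np.log10(np.mean(err / sig))           # NMSE in dB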
-------------------------------------------------------------------------------- /csinet-lstm/csinet_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | 3 | #SBATCH 4 | #SBATCH --mail-type=ALL 5 | #SBATCH --mail-user=mdelrosa@ucdavis.edu 6 | #SBATCH -p GPU-AI 7 | #SBATCH --job-name=cr512_out 8 | #SBATCH --time=0-12 9 | #SBATCH --gres=gpu:volta16:1 10 | #SBATCH --mem=64000 # memory required by job 11 | 12 | source $HOME/.bashrc 13 | 14 | python csinet_train.py -d 0 -g 0 -e indoor -r 256 -l csinet -ep 1000 15 | python csinet_train.py -d 0 -g 0 -e indoor -r 128 -l csinet -ep 1000 16 | python csinet_train.py -d 0 -g 0 -e indoor -r 64 -l csinet -ep 1000 17 | python csinet_train.py -d 0 -g 0 -e indoor -r 32 -l csinet -ep 1000 18 | 19 | #wait 20 | -------------------------------------------------------------------------------- /torch/csinet_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.optim as optim 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import torchvision 7 | from torchvision import transforms 8 | 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | from tqdm import trange, tqdm 12 | 13 | class Encoder(torch.nn.Module): 14 | """ encoder for CsiNet """ 15 | def __init__(self, n_chan, H, W, latent_dim): 16 | super(Encoder, self).__init__() 17 | self.img_total = H*W 18 | self.n_chan = n_chan 19 | self.latent_dim = latent_dim 20 | self.enc_conv1 = nn.Conv2d(2, 8, 3, padding=1) 21 | self.bn_1 = nn.BatchNorm2d(8) 22 | self.enc_conv2 = nn.Conv2d(8, 16, 3, padding=1) 23 | self.bn_2 = nn.BatchNorm2d(16) 24 | self.enc_conv3 = nn.Conv2d(16, 2, 3, padding=1) 25 | self.bn_3 = nn.BatchNorm2d(2) 26 | self.enc_dense = nn.Linear(H*W*n_chan, latent_dim) 27 | 28 | # TODO: try different activation functions here (i.e., swish) 29 | self.activ = nn.LeakyReLU(0.1) # TODO: make sure slope matches TF slope 30 | 31 | def forward(self, x): 32 | x = self.activ(self.bn_1(self.enc_conv1(x))) 33 | x = self.activ(self.bn_2(self.enc_conv2(x))) 34 | x = self.activ(self.bn_3(self.enc_conv3(x))) 35 | x = torch.reshape(x, (x.size(0), -1,)) # TODO: verify -- does this return num samples in both channels? 
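        # reshape keeps dim 0 (batch) and flattens the remaining n_chan*H*W entries
        # per sample, so both real/imag channels feed the single Linear layer below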
36 | x = self.enc_dense(x) 37 | return x 38 | 39 | class Decoder(torch.nn.Module): 40 | """ decoder for CsiNet """ 41 | def __init__(self, n_chan, H, W, latent_dim, aux_dim=512): 42 | super(Decoder, self).__init__() 43 | self.H = H 44 | self.W = W 45 | self.img_total = H*W 46 | self.n_chan = n_chan 47 | self.dec_dense = nn.Linear(latent_dim+aux_dim, self.img_total*self.n_chan) 48 | self.dec_conv1 = nn.Conv2d(2, 128, 1) 49 | self.bn_1 = nn.BatchNorm2d(128) 50 | self.dec_conv2 = nn.Conv2d(128, 64, 1) 51 | self.bn_2 = nn.BatchNorm2d(64) 52 | self.dec_conv3 = nn.Conv2d(64, 32, 3, padding=1) 53 | self.bn_3 = nn.BatchNorm2d(32) 54 | self.dec_conv4 = nn.Conv2d(32, 32, 3, padding=1) 55 | self.bn_4 = nn.BatchNorm2d(32) 56 | self.dec_conv5 = nn.Conv2d(32, 16, 3, padding=1) 57 | self.bn_5 = nn.BatchNorm2d(16) 58 | self.dec_conv6 = nn.Conv2d(16, 16, 3, padding=1) 59 | self.bn_6 = nn.BatchNorm2d(16) 60 | self.dec_conv7 = nn.Conv2d(16, 2, 3, padding=1) 61 | 62 | self.activ = nn.LeakyReLU(0.1) # TODO: make sure slope matches TF slope 63 | self.out_activ = nn.Tanh() 64 | 65 | def forward(self, x): 66 | """ x = aux, input """ 67 | aux, H_in = x 68 | x = self.dec_dense(torch.cat((aux, H_in), 1)) 69 | x = torch.reshape(x, (x.size(0), self.n_chan, self.H, self.W)) 70 | x = self.activ(self.bn_1(self.dec_conv1(x))) 71 | x = self.activ(self.bn_2(self.dec_conv2(x))) 72 | x = self.activ(self.bn_3(self.dec_conv3(x))) 73 | x = self.activ(self.bn_4(self.dec_conv4(x))) 74 | x = self.activ(self.bn_5(self.dec_conv5(x))) 75 | x = self.activ(self.bn_6(self.dec_conv6(x))) 76 | x = self.out_activ(self.dec_conv7(x)) 77 | return x 78 | 79 | class CsiNet(nn.Module): 80 | """ CsiNet for csi estimation """ 81 | def __init__(self, encoder, decoder, latent_dim, device=None): 82 | super(CsiNet, self).__init__() 83 | self.decoder = decoder 84 | self.encoder = encoder 85 | self.latent_dim = latent_dim 86 | self.device = device 87 | self.training = True 88 | 89 | def forward(self, x): 90 | """forward call for CsiNet""" 91 | aux, H_in = x 92 | h_enc = self.encoder(H_in) 93 | return self.decoder((aux, h_enc)) 94 | 95 | def latent_loss(self, z_mean, z_stddev): 96 | """ if we want to do semi-supervised learning, then we could define the loss here """ 97 | pass 98 | 99 | if __name__ == "__main__": 100 | import argparse 101 | import pickle 102 | import copy 103 | import sys 104 | sys.path.append("/home/mdelrosa/git/brat") 105 | from utils.NMSE_performance import get_NMSE, denorm_H3, denorm_sphH4, denorm_H4, renorm_H4 106 | from utils.data_tools import dataset_pipeline_col, subsample_batches 107 | from utils.parsing import str2bool 108 | from utils.timing import Timer 109 | from utils.unpack_json import get_keys_from_json 110 | from utils.trainer import fit, score, save_predictions, save_checkpoint_history, load_checkpoint_history 111 | 112 | # set up timers 113 | timers = { 114 | "fit_timer": Timer("Fit"), 115 | "predict_timer": Timer("Predict"), 116 | "score_timer": Timer("Score") 117 | } 118 | 119 | # parse command line args 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument("-d", "--debug_flag", type=int, default=0, help="flag for toggling debugging mode") 122 | parser.add_argument("-g", "--gpu_num", type=int, default=0, help="number for torch device (cuda:gpu_num)") 123 | parser.add_argument("-b", "--n_batch", type=int, default=20, help="number of batches to fit on (ignored during debug mode)") 124 | parser.add_argument("-l", "--dir", type=str, default=None, help="subdirectory for saving model, checkpoint, history") 125 
| parser.add_argument("-e", "--env", type=str, default="indoor", help="environment (either indoor or outdoor)") 126 | parser.add_argument("-ep", "--epochs", type=int, default=10, help="number of epochs to train for") 127 | parser.add_argument("-tr", "--train_argv", type=str2bool, default=True, help="flag for toggling training") 128 | parser.add_argument("-t", "--n_truncate", type=int, default=32, help="value to truncate to along delay axis.") 129 | parser.add_argument("-ts", "--timeslot", type=int, default=0, help="timeslot which we are training (0-indexed).") 130 | parser.add_argument("-r", "--rate", type=int, default=512, help="number of elements in latent code (i.e., encoding rate)") 131 | parser.add_argument("-dt", "--data_type", type=str, default="norm_H4", help="type of dataset to train on (norm_H4, norm_sphH4)") 132 | parser.add_argument("-a", "--aux_bool", type=str2bool, default=True, help="bool for building CsiNet with auxiliary input") 133 | parser.add_argument("-m", "--aux_size", type=int, default=512, help="integer for auxiliary input's latent rate") 134 | opt = parser.parse_args() 135 | 136 | device = torch.device(f'cuda:{opt.gpu_num}' if torch.cuda.is_available() else 'cpu') 137 | print(f"--- Device is {device} ---") 138 | 139 | if opt.env == "outdoor": 140 | json_config = '../config/csinet_outdoor_cost2100_pow.json' 141 | elif opt.env == "indoor": 142 | json_config = '../config/csinet_indoor_cost2100_pow.json' 143 | 144 | # elif opt.data_type == "norm_sphH4": 145 | # # json_config = "../config/csinet-pro-indoor0001-sph.json" if opt.env == "indoor" else "../config/csinet-pro-outdoor300-sph.json" 146 | # json_config = "../config/csinet-pro-quadriga-indoor0001-sph.json" if opt.env == "indoor" else "../config/csinet-pro-quadriga-outdoor300-sph.json" 147 | 148 | # model_dir, norm_range, minmax_file, dataset_spec, diff_spec, batch_num, lrs, batch_sizes, network_name, T, data_format = get_keys_from_json(json_config, keys=['model_dir','norm_range','minmax_file','dataset_spec', 'diff_spec', 'batch_num', 'lrs', 'batch_sizes', 'network_name', 'T', 'df']) 149 | dataset_spec, minmax_file, img_channels, data_format, norm_range, T, network_name, model_dir, n_delay, lr, batch_size, diff_spec = get_keys_from_json(json_config, keys=["dataset_spec", "minmax_file", "img_channels", "df", "norm_range", "T", "network_name", "model_dir", "n_delay", "learning_rate", "batch_size", "diff_spec"]) 150 | # lr = lrs[0] 151 | # batch_size = batch_sizes[0] 152 | aux_bool_list = get_keys_from_json(json_config, keys=["aux_bool"], is_bool=True) 153 | aux_bool = aux_bool_list[0] # dumb, but get_keys_from_json returns list 154 | 155 | input_dim = (2,32,n_delay) 156 | epochs = 10 if opt.debug_flag else opt.epochs 157 | 158 | batch_num = 1 if opt.debug_flag else opt.n_batch # dataset batches 159 | M_1 = None # legacy holdover from CsiNet-LSTM 160 | 161 | # load all data splits 162 | 163 | data_train, data_val = dataset_pipeline_col(opt.debug_flag, opt.aux_bool, dataset_spec, opt.aux_size, T = T, img_channels = input_dim[0], img_height = input_dim[1], img_width = input_dim[2], data_format = data_format, train_argv = opt.train_argv) 164 | 165 | aux_val, x_val = data_val 166 | x_val = renorm_H4(x_val,minmax_file) 167 | print('-> post-renorm: x_val range is from {} to {}'.format(np.min(x_val),np.max(x_val))) 168 | valid_dataset = torch.utils.data.TensorDataset(torch.from_numpy(aux_val).float().to(device), torch.from_numpy(x_val).to(device)) 169 | 170 | if opt.train_argv: 171 | aux_train, x_train = data_train 172 | 
x_train = renorm_H4(x_train,minmax_file) 173 | print('-> post-renorm: x_train range is from {} to {}'.format(np.min(x_train),np.max(x_train))) 174 | train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(aux_train).float().to(device), torch.from_numpy(x_train).to(device)) 175 | 176 | model_dir += "/" + opt.env 177 | if opt.dir != None: 178 | model_dir += "/" + opt.dir 179 | 180 | # cr_list = [512, 256, 128, 64, 32] if opt.rate == 0 else [opt.rate]# rates for different compression ratios 181 | cr_list = [256, 128, 64, 32] if opt.rate == 0 else [opt.rate]# rates for different compression ratios 182 | for cr in cr_list: 183 | 184 | valid_ldr = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size) 185 | if opt.train_argv: 186 | train_ldr = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 187 | 188 | encoder = Encoder(input_dim[0], input_dim[1], opt.n_truncate, cr) 189 | decoder = Decoder(input_dim[0], input_dim[1], opt.n_truncate, cr) 190 | csinet_pro = CsiNet(encoder, decoder, cr, device=device).to(device) 191 | 192 | pickle_dir = f"{model_dir}/cr{cr}/t1" 193 | print(f"--- pickle_dir is {pickle_dir} ---") 194 | 195 | if opt.train_argv: 196 | print(f"--- Fitting on training set ({x_train.shape[0]} batches) ---") 197 | model, checkpoint, history, optimizer, timers = fit(csinet_pro, 198 | train_ldr, 199 | valid_ldr, 200 | batch_num, 201 | epochs=epochs, 202 | timers=timers, 203 | json_config=json_config, 204 | debug_flag=opt.debug_flag, 205 | pickle_dir=pickle_dir) 206 | else: 207 | print(f"--- Loading model, checkpoint, history, optimizer from {model_dir} ---") 208 | model, checkpoint, history, optimizer = load_checkpoint_history(pickle_dir, csinet_pro, network_name=network_name) 209 | 210 | checkpoint = score(csinet_pro, 211 | valid_ldr, 212 | x_val, 213 | batch_num, 214 | checkpoint, 215 | history, 216 | optimizer, 217 | timers=timers, 218 | json_config=json_config, 219 | debug_flag=opt.debug_flag, 220 | str_mod=f"CsiNet CR={cr} - {opt.env} -", 221 | diff_spec=diff_spec 222 | ) 223 | 224 | if not opt.debug_flag: 225 | 226 | # del train_ldr 227 | # train_ldr = torch.utils.data.DataLoader((torch.from_numpy(aux_train).to(device), torch.from_numpy(x_train).to(device)), batch_size=batch_size) 228 | # train_ldr = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size) 229 | # save_predictions(csinet_pro, train_ldr, data_train, optimizer, timers, json_config=json_config, dir=pickle_dir, split="train") 230 | # save_predictions(csinet_pro, valid_ldr, data_test, optimizer, timers, json_config=json_config, dir=pickle_dir, split="valid") 231 | save_checkpoint_history(checkpoint, history, optimizer, dir=pickle_dir, network_name=network_name) 232 | # del train_ldr, valid_ldr --------------------------------------------------------------------------------
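
For orientation, a minimal smoke test of the PyTorch port above; the shapes are illustrative (latent_dim=128 is arbitrary, and the auxiliary input must match the decoder's default aux_dim=512):

import torch
# from csinet_torch import Encoder, Decoder, CsiNet  # assumed import path

enc = Encoder(n_chan=2, H=32, W=32, latent_dim=128)
dec = Decoder(n_chan=2, H=32, W=32, latent_dim=128)  # aux_dim defaults to 512
model = CsiNet(enc, dec, latent_dim=128)

aux = torch.randn(4, 512)         # auxiliary input (e.g., the M_1 codeword)
H_in = torch.randn(4, 2, 32, 32)  # (batch, re/im, angle, delay)
H_hat = model((aux, H_in))        # reconstructed CSI
assert H_hat.shape == H_in.shape  # (4, 2, 32, 32)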