├── FullRadarCubeDataset.py
├── LICENSE.txt
├── README.md
├── RadarUnet.py
├── matched_filter.py
├── sigproc.py
└── train.py

/FullRadarCubeDataset.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import random
4 | from typing import Dict
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset
8 | import h5py
9 | 
10 | class FullRadarCubeDatasetConfig:
11 |     def __init__(self):
12 |         self.shuffle = True
13 |         self.input_load_callback = None
14 |         self.target_load_callback = None
15 |         self.target_filename = None
16 |         self.input_filename = None
17 |         self.mode = "train"
18 |         self.data_set_size = 1000
19 | 
20 |         self.number_train_samples = None
21 |         self.number_valid_samples = None
22 |         self.number_test_samples = None
23 | 
24 | class FullRadarCubeDataset(Dataset):
25 |     def __init__(self, data_set_config: FullRadarCubeDatasetConfig):
26 | 
27 |         self.config = data_set_config
28 |         self.rng = np.random.default_rng(seed=0)
29 | 
30 |         if self.config.number_test_samples is None:
31 |             number_test_samples = int(np.ceil(self.config.data_set_size*0.10))  # 10 % test samples
32 |         else:
33 |             number_test_samples = self.config.number_test_samples
34 | 
35 |         if self.config.number_valid_samples is None:
36 |             number_valid_samples = int(np.ceil(self.config.data_set_size*0.10))  # 10 % validation samples
37 |         else:
38 |             number_valid_samples = self.config.number_valid_samples
39 | 
40 |         # the train split defaults to all samples that remain after the validation and test splits
41 |         number_train_samples = self.config.number_train_samples if self.config.number_train_samples is not None else self.config.data_set_size - number_valid_samples - number_test_samples
42 | 
43 |         if self.config.mode == "train":
44 |             self.idx_offset = 0
45 |             self.dataset_size = number_train_samples
46 |             self.data_indices = np.arange(self.idx_offset, number_train_samples + self.idx_offset)
47 |         elif self.config.mode == "valid":
48 |             self.idx_offset = number_train_samples
49 |             self.dataset_size = number_valid_samples
50 |             self.data_indices = np.arange(self.idx_offset, number_valid_samples + self.idx_offset)
51 |         elif self.config.mode == "test":
52 |             self.idx_offset = number_train_samples + number_valid_samples
53 |             self.data_indices = np.arange(self.idx_offset, self.idx_offset + number_test_samples)
54 |             self.dataset_size = number_test_samples
55 |         else:
56 |             raise Exception(f"Load mode {self.config.mode} is not supported. 
Supported modes are: train, test, all") 57 | 58 | if self.config.shuffle: 59 | random.seed(1) 60 | random.shuffle(self.data_indices) 61 | 62 | def shuffle_data(self, seed=1): 63 | random.seed(seed) 64 | random.shuffle(self.data_indices) 65 | 66 | def __getitem__(self, idx): 67 | with h5py.File(self.config.input_filename, 'r') as input_h5: 68 | if self.config.input_load_callback is None: 69 | x_ra = np.array(input_h5.get(f'ra_data_{self.data_indices[idx]:06d}')) 70 | x_rd = np.array(input_h5.get(f'rd_data_{self.data_indices[idx]:06d}')) 71 | x = (x_ra, x_rd) 72 | else: 73 | x = self.config.input_load_callback(input_h5, self.data_indices[idx]) 74 | 75 | with h5py.File(self.config.target_filename, 'r') as target_h5: 76 | if self.config.target_load_callback is None: 77 | y = np.array(target_h5.get(f'ra_data_{self.data_indices[idx]:06d}')) 78 | else: 79 | y = self.config.target_load_callback(target_h5, self.data_indices[idx]) 80 | 81 | return x,y 82 | 83 | def __len__(self): 84 | return self.dataset_size 85 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). 
Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. 
Should any part of the License for any
94 | reason be judged legally invalid or ineffective under applicable law, such
95 | partial invalidity or ineffectiveness shall not invalidate the remainder
96 | of the License, and in such case Affirmer hereby affirms that he or she
97 | will not (i) exercise any of his or her remaining Copyright and Related
98 | Rights in the Work or (ii) assert any associated claims and causes of
99 | action with respect to the Work, in either case contrary to Affirmer's
100 | express Statement of Purpose.
101 | 
102 | 4. Limitations and Disclaimers.
103 | 
104 |  a. No trademark or patent rights held by Affirmer are waived, abandoned,
105 |     surrendered, licensed or otherwise affected by this document.
106 |  b. Affirmer offers the Work as-is and makes no representations or
107 |     warranties of any kind concerning the Work, express, implied,
108 |     statutory or otherwise, including without limitation warranties of
109 |     title, merchantability, fitness for a particular purpose, non
110 |     infringement, or the absence of latent or other defects, accuracy, or
111 |     the presence or absence of errors, whether or not discoverable, all to
112 |     the greatest extent permissible under applicable law.
113 |  c. Affirmer disclaims responsibility for clearing rights of other persons
114 |     that may apply to the Work or any use thereof, including without
115 |     limitation any person's Copyright and Related Rights in the Work.
116 |     Further, Affirmer disclaims responsibility for obtaining any necessary
117 |     consents, permissions or other rights required for any use of the
118 |     Work.
119 |  d. Affirmer understands and acknowledges that Creative Commons is not a
120 |     party to this document and has no duty or obligation with respect to
121 |     this CC0 or use of the Work.
122 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Super-Resolution Radar Imaging with Sparse Arrays Using a Deep Neural Network Trained with Enhanced Virtual Data
2 | 
3 | This repository contains the PyTorch training code for our paper "Super-Resolution Radar Imaging with Sparse Arrays Using a Deep Neural Network Trained with Enhanced Virtual Data". It comprises the dataset loader (FullRadarCubeDataset.py), the attention U-Net model (RadarUnet.py), the matched-filter pre-processing of the ground-truth data (matched_filter.py), signal-processing helpers (sigproc.py) and the training script (train.py).
--------------------------------------------------------------------------------
/RadarUnet.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | class AttentionGate(nn.Module):
7 |     """
8 |     Implemented as described in
9 |     Oktay, Ozan, et al.
10 |     "Attention u-net: Learning where to look for the pancreas." arXiv preprint arXiv:1804.03999 (2018).
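
    A minimal usage sketch (shapes chosen only for illustration):

        gate = AttentionGate(g_channels=256, x_channels=128)
        x = torch.randn(1, 128, 64, 64)   # skip-connection features
        g = torch.randn(1, 256, 32, 32)   # gating signal from the coarser level
        x_att = gate(x, g)                # same shape as x, rescaled by the attention map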
11 | """ 12 | def __init__(self, g_channels, x_channels, inter_channels=None): 13 | super(AttentionGate, self).__init__() 14 | 15 | if inter_channels is None: 16 | inter_channels = x_channels // 2 17 | if inter_channels == 0: 18 | inter_channels = 1 19 | else: 20 | inter_channels = inter_channels 21 | 22 | self.g_channels = g_channels 23 | self.x_channels = x_channels 24 | self.inter_channels = inter_channels 25 | self.g_weights = None # size of padding only known at runtime 26 | 27 | self.x_weights = nn.Conv2d(in_channels=x_channels, out_channels=inter_channels, kernel_size=(1,1), stride=2) 28 | self.in_relu = nn.ReLU(inplace=True) 29 | 30 | self.value_weights = nn.Conv2d(in_channels=inter_channels, out_channels=1, kernel_size=(1,1)) 31 | self.out_sigmoid = nn.Sigmoid() 32 | self.resampler = None 33 | 34 | def forward(self, x, g): 35 | 36 | if self.resampler is None: 37 | self.resampler = nn.Upsample(x.shape[2:], mode='bilinear', align_corners=False) 38 | 39 | if self.g_weights is None: 40 | # pad g to prevent shape mismatch 41 | padding_size_x = (x.shape[2] // 2 - g.shape[2]) // 2 42 | padding_size_y = (x.shape[3] // 2 - g.shape[3]) // 2 43 | self.g_weights = nn.Conv2d(in_channels=self.g_channels, out_channels=self.inter_channels, 44 | kernel_size=(1,1), padding=(padding_size_x, padding_size_y), padding_mode="replicate") 45 | 46 | if g.is_cuda: 47 | self.g_weights.to("cuda:0") 48 | 49 | g_weighted = self.g_weights(g) 50 | x_weighted = self.x_weights(x) 51 | 52 | # pad if necessary 53 | if g_weighted.shape[2] < x_weighted.shape[2] or g_weighted.shape[3] < x_weighted.shape[3]: 54 | pad_x = x_weighted.shape[2] - g_weighted.shape[2] 55 | pad_y = x_weighted.shape[3] - g_weighted.shape[3] 56 | # https://stackoverflow.com/questions/48686945/reshaping-a-tensor-with-padding-in-pytorch 57 | g_weighted = F.pad(g_weighted, pad=(0, pad_y, 0, pad_x)) 58 | 59 | x_g_added = x_weighted + g_weighted 60 | x_g_relued = self.in_relu(x_g_added) 61 | 62 | x_g_weighted = self.value_weights(x_g_relued) 63 | x_g_sigmoid = self.out_sigmoid(x_g_weighted) 64 | 65 | attention_values = self.resampler(x_g_sigmoid) 66 | output_x = attention_values * x 67 | 68 | return output_x 69 | 70 | class RadarUNet(nn.Module): 71 | """ 72 | Implemented as described in 73 | Orr, Itai, Moshik Cohen, and Zeev Zalevsky. 74 | "High-resolution radar road segmentation using weakly supervised learning." 75 | Nature Machine Intelligence 3.3 (2021): 239-246. 
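
    As used in this repository, the network acts as an image-to-image regressor: an encoder with
    64/128/256/512 feature channels, attention-gated skip connections (see AttentionGate above) and
    a final bilinear Upsample that resizes the prediction to output_image_shape; with
    segmentation=True, a two-channel Softmax2d output is produced instead.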
76 | 77 | A good tutorial can also be found in : 78 | https://www.youtube.com/watch?v=u1loyDCoGbE 79 | """ 80 | 81 | def double_conv(self, in_channels, out_channels, padding=0, last_relu=True): 82 | 83 | if last_relu: 84 | conv = nn.Sequential( 85 | nn.Conv2d(in_channels, out_channels, kernel_size=(3,3), padding=padding, padding_mode="replicate"), 86 | nn.ReLU(inplace=True), 87 | nn.Conv2d(out_channels, out_channels, kernel_size=(3,3), padding=padding, padding_mode="replicate"), 88 | nn.BatchNorm2d(out_channels), 89 | nn.ReLU(inplace=True)) 90 | else: 91 | conv = nn.Sequential( 92 | nn.Conv2d(in_channels, out_channels, kernel_size=(3,3), padding=padding, padding_mode="replicate"), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(out_channels, out_channels, kernel_size=(3,3), padding=padding, padding_mode="replicate")) 95 | return conv 96 | 97 | def crop_image(self, input_image, output_image): 98 | out_size_x = output_image.shape[2] 99 | in_size_x = input_image.shape[2] 100 | delta_x = in_size_x - out_size_x 101 | 102 | if delta_x % 2 == 0: 103 | delta_x_left = delta_x // 2 104 | delta_x_right = delta_x // 2 105 | else: 106 | delta_x_left = delta_x // 2 107 | delta_x_right = delta_x // 2 + 1 108 | 109 | out_size_y = output_image.shape[3] 110 | in_size_y = input_image.shape[3] 111 | delta_y = in_size_y - out_size_y 112 | if delta_y % 2 == 0: 113 | delta_y_left = delta_y // 2 114 | delta_y_right = delta_y // 2 115 | else: 116 | delta_y_left = delta_y // 2 117 | delta_y_right = delta_y // 2 + 1 118 | 119 | cropped_image = input_image[:, :, delta_x_left:-delta_x_right, delta_y_left:-delta_y_right] 120 | return cropped_image 121 | 122 | def __init__(self, output_image_shape, input_channels=1, enable_attention=True, segmentation=False): 123 | """ 124 | output_image_shape -> Shape of output image without channel shape 125 | input_channels -> Number of input shapes for exammple two for real and complex 126 | segmentation -> if segmentation is enabled a softmax is applied 127 | """ 128 | super(RadarUNet, self).__init__() 129 | 130 | self.enable_attention = enable_attention 131 | self.segmentation = segmentation 132 | 133 | self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2) 134 | self.down_conv1 = self.double_conv(input_channels, 64) 135 | self.down_conv2 = self.double_conv(64, 128) 136 | self.down_conv3 = self.double_conv(128, 256) 137 | self.down_conv4 = self.double_conv(256, 512) 138 | self.down_conv5 = self.double_conv(512, 1024) 139 | 140 | self.down_drop1 = nn.Dropout2d(0.2) 141 | self.down_drop2 = nn.Dropout2d(0.2) 142 | self.down_drop3 = nn.Dropout2d(0.2) 143 | self.down_drop4 = nn.Dropout2d(0.2) 144 | 145 | self.up_trans_conv1 = nn.ConvTranspose2d(1024, 512, kernel_size=(2,2), stride=2) 146 | self.up_trans_conv2 = nn.ConvTranspose2d(512, 256, kernel_size=(2,2), stride=2) 147 | self.up_trans_conv3 = nn.ConvTranspose2d(256, 128, kernel_size=(2,2), stride=2) 148 | self.up_trans_conv4 = nn.ConvTranspose2d(128, 64, kernel_size=(2,2), stride=2) 149 | 150 | self.up_conv1 = self.double_conv(1024, 512, padding=3) 151 | self.up_conv2 = self.double_conv(512, 256, padding=3) 152 | self.up_conv3 = self.double_conv(256, 128, padding=3) 153 | self.up_conv4 = self.double_conv(128, 64, padding=3) 154 | 155 | if segmentation: 156 | self.final_up_conv = self.double_conv(64, 2) 157 | self.final_softmax = nn.Softmax2d() 158 | else: 159 | self.final_up_conv = self.double_conv(64, 1, last_relu=False) 160 | 161 | self.final_upsampler = nn.Upsample(output_image_shape, mode='bilinear', align_corners=False) 162 | 163 
| if enable_attention: 164 | # from bottom to up 165 | self.att_layer1 = AttentionGate(g_channels=1024, x_channels=512) 166 | self.att_layer2 = AttentionGate(g_channels=512, x_channels=256) 167 | self.att_layer3 = AttentionGate(g_channels=256, x_channels=128) 168 | self.att_layer4 = AttentionGate(g_channels=128, x_channels=64) 169 | 170 | def forward(self, x): 171 | """ 172 | expecting following dimensions: (batch, channel, width, height) 173 | """ 174 | 175 | # downsampling path 176 | x1 = self.down_conv1(x) 177 | #x1 = self.down_drop1(x1) 178 | x2 = self.max_pool_2x2(x1) 179 | 180 | x3 = self.down_conv2(x2) 181 | #x3 = self.down_drop2(x3) 182 | x4 = self.max_pool_2x2(x3) 183 | 184 | x5 = self.down_conv3(x4) 185 | #x5 = self.down_drop3(x5) 186 | x6 = self.max_pool_2x2(x5) 187 | 188 | x7 = self.down_conv4(x6) 189 | #x7 = self.down_drop4(x7) 190 | 191 | if self.enable_attention: 192 | x_left2 = self.att_layer2(x5, x7) 193 | else: 194 | x_left2 = x5 195 | x_up2 = self.up_trans_conv2(x7) 196 | x5_cropped = self.crop_image(x_left2, x_up2) 197 | x_comb2 = torch.cat((x5_cropped, x_up2), 1) 198 | x_comb2 = self.up_conv2(x_comb2) 199 | 200 | if self.enable_attention: 201 | x_left3 = self.att_layer3(x3, x_comb2) 202 | else: 203 | x_left3 = x3 204 | x_up3 = self.up_trans_conv3(x_comb2) 205 | x3_cropped = self.crop_image(x_left3, x_up3) 206 | x_comb3 = torch.cat((x3_cropped, x_up3), 1) 207 | x_comb3 = self.up_conv3(x_comb3) 208 | 209 | if self.enable_attention: 210 | x_left4 = self.att_layer4(x1, x_comb3) 211 | else: 212 | x_left4 = x1 213 | x_up4 = self.up_trans_conv4(x_comb3) 214 | x1_cropped = self.crop_image(x_left4, x_up4) 215 | x_comb4 = torch.cat((x1_cropped, x_up4), 1) 216 | x_comb4 = self.up_conv4(x_comb4) 217 | 218 | final_output = self.final_up_conv(x_comb4) 219 | final_output = self.final_upsampler(final_output) 220 | 221 | if self.segmentation: 222 | final_output = self.final_softmax(final_output) 223 | 224 | return final_output -------------------------------------------------------------------------------- /matched_filter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pycuda.driver as cuda 3 | import pycuda.autoinit 4 | from pycuda.compiler import SourceModule 5 | 6 | """ 7 | This file implements a matched filter to pre-process the ground-truth data. 8 | A detailed description of the algorithm can be found in: 9 | 10 | C. Schüßler, M. Hoffmann, I. Ullmann, R. Ebelt and M. Vossiek, 11 | "Deep Learning Based Image Enhancement for Automotive Radar Trained With an Advanced Virtual Sensor," 12 | in IEEE Access, vol. 10, pp. 40419-40431, 2022, doi: 10.1109/ACCESS.2022.3166227, 13 | 14 | which implemented the algorithm described in: 15 | S. S. Ahmed, A. Schiessl, F. Gumbmann, M. Tiebout, S. Methfessel and L. -P. Schmidt, 16 | "Advanced Microwave Imaging," in IEEE Microwave Magazine, vol. 13, no. 6, pp. 26-43, Sept.-Oct. 2012, doi: 10.1109/MMM.2012.2205772. 
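
In short: for every reconstruction position p, the CUDA kernel below coherently sums, over all
TX/RX antenna pairs, the zero-padded range FFT of the beat signal, evaluated (with linear
interpolation between bins) at the beat frequency f = (bandwidth/chirp_duration)*tau(p) and
weighted with the conjugate carrier phase exp(-j*2*pi*carrier_frequency*tau(p)), where
tau(p) = (|p - p_tx| + |p - p_rx|)/c is the round-trip delay. The result is normalized by the
number of TX/RX combinations.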
17 | """ 18 | 19 | class SignalData: 20 | """ 21 | This class stores meta data of the if-signal (or beat signal) 22 | and also the if signal values for all antenna combinations itself 23 | """ 24 | 25 | def __init__(self): 26 | self.tx_positions = None 27 | self.rx_positions = None 28 | 29 | self.time_vector = None 30 | self.carrier_frequency = None 31 | self.bandwidth = None 32 | self.chirp_duration = None 33 | self.cycle_duration = None 34 | self.delta_frequency = None 35 | self.signal_type = "FMCW" 36 | # numpy array with shape (n_tx, n_rx, n_chirp, n_time) 37 | self.signals = None 38 | 39 | # add some get properties for convenience 40 | @property 41 | def number_tx(self): 42 | return self.signals.shape[0] 43 | 44 | @property 45 | def number_rx(self): 46 | return self.signals.shape[1] 47 | 48 | @property 49 | def number_chirps(self): 50 | return self.signals.shape[2] 51 | 52 | 53 | def holo_reconstruction_range_angle_cuda( 54 | range_positions, angle_positions, signal_data, z_pos=0.0, 55 | chirp_index=0, zero_padding_factor=4, transformation_matrix=np.identity(4)): 56 | """ 57 | Applies a holographic/matched filter reconstruction in range, sin(angle) space, 58 | whereby the angle_position is in sine space -> therefore it takes not angles but sin(angles) 59 | 60 | The z_pos indicates in which slice in cartesian coordinates the reconstruction in polar coordinates should be performed. 61 | A transformation_matrix can be given, which is applied to all (cartesian converted) positions before reconstruction. 62 | 63 | The signal is padded by a zero_padding_factor before an FFT is applied. 64 | Any window function has to be applied by the user in beforehand. 65 | """ 66 | 67 | mod = SourceModule(""" 68 | # define _USE_MATH_DEFINES 69 | #include 70 | #include 71 | #include 72 | #include 73 | __device__ __forceinline__ cuComplex comp_exp (float phase) 74 | { 75 | cuComplex result = make_cuComplex(cos(phase), sin(phase)); 76 | return result; 77 | } 78 | 79 | __device__ float cargf(const cuComplex& z) 80 | { 81 | return atan2(cuCimagf(z), cuCrealf(z)); 82 | } 83 | 84 | __device__ float cabsf(const cuComplex& z) 85 | { 86 | return sqrtf(cuCimagf(z)*cuCimagf(z) + cuCrealf(z)*cuCrealf(z)); 87 | } 88 | 89 | __device__ float3 operator-(const float3 &a, const float3 &b) 90 | { 91 | return make_float3(a.x-b.x, a.y-b.y, a.z-b.z); 92 | } 93 | 94 | __device__ __forceinline__ float vec_length(const float3 &a) 95 | { 96 | return sqrt(a.x*a.x + a.y*a.y + a.z*a.z); 97 | } 98 | 99 | __device__ cuComplex operator*(const cuComplex &a, const cuComplex &b) 100 | { 101 | return make_cuComplex(a.x*b.x - a.y*b.y, a.x*b.y + a.y*b.x); 102 | } 103 | 104 | __device__ cuComplex interpolate_complex(float x0, float x1, cuComplex f0, cuComplex f1, float x) 105 | { 106 | float phase0 = cargf(f0); 107 | float phase1 = cargf(f1); 108 | 109 | if (phase1 < phase0) 110 | phase1 += 2*M_PI; 111 | 112 | float mag_interp = cabsf(f0) + (cabsf(f1) - cabsf(f0))/(x1 - x0) * (x-x0); 113 | float phase_interp = phase0 + ((phase1 -phase0)/(x1 - x0)) * (x-x0); 114 | return make_cuComplex(mag_interp*cos(phase_interp), mag_interp*sin(phase_interp)); 115 | } 116 | 117 | __global__ void holo_reco 118 | ( 119 | float3* tx_antennas, 120 | int num_tx_antennas, 121 | float3* rx_antennas, 122 | int num_rx_antennas, 123 | float* trans_matrix, 124 | float* range_positions, 125 | int num_range_positions, 126 | float* angle_positions, 127 | int num_angle_positions, 128 | cuComplex* reco_image, 129 | cuComplex* frequency_signal, 130 | int 
frequency_signal_length, 131 | float frequency_slope, 132 | float delta_frequency, 133 | float carrier_frequency, 134 | float z_pos 135 | ) 136 | { 137 | int index_range = threadIdx.x + blockIdx.x*blockDim.x; 138 | int index_angle = threadIdx.y + blockIdx.y*blockDim.y; 139 | 140 | if (index_range >= num_range_positions || index_angle >= num_angle_positions) 141 | return; 142 | 143 | const float c = 3e8; 144 | float range_pos = range_positions[index_range]; 145 | float angle_pos = angle_positions[index_angle]; 146 | 147 | // convert to cartesian and transform 148 | float3 reco_pos_orig; 149 | reco_pos_orig.y = range_pos*cosf(asinf(angle_pos)); 150 | reco_pos_orig.x = range_pos*angle_pos; 151 | 152 | // apply matrix 153 | float3 reco_pos; 154 | reco_pos.x = trans_matrix[0]*reco_pos_orig.x + trans_matrix[1]*reco_pos_orig.y + trans_matrix[3]; 155 | reco_pos.y = trans_matrix[4]*reco_pos_orig.x + trans_matrix[5]*reco_pos_orig.y + trans_matrix[7]; 156 | reco_pos.z = z_pos; 157 | 158 | cuComplex result = make_cuComplex(1e-3,1e-3); 159 | for(int tx_index = 0; tx_index < num_tx_antennas; tx_index++) 160 | { 161 | float3 tx_pos = tx_antennas[tx_index]; 162 | for(int rx_index = 0; rx_index < num_rx_antennas; rx_index++) 163 | { 164 | float3 rx_pos = rx_antennas[rx_index]; 165 | 166 | float delay = (vec_length(reco_pos-tx_pos) + vec_length(reco_pos-rx_pos))/c; 167 | float exp_phase = -2*M_PI*carrier_frequency*delay; 168 | cuComplex weight = comp_exp(exp_phase); 169 | 170 | float x = (delay*frequency_slope); 171 | int signal_index = (int)(delay*frequency_slope / delta_frequency); 172 | 173 | int x0_array_idx = tx_index*num_rx_antennas*frequency_signal_length + rx_index*frequency_signal_length + signal_index; 174 | int x1_array_idx = tx_index*num_rx_antennas*frequency_signal_length + rx_index*frequency_signal_length + (signal_index+1); 175 | 176 | cuComplex f0 = frequency_signal[x0_array_idx]; 177 | cuComplex f1 = frequency_signal[x1_array_idx]; 178 | 179 | float x0 = signal_index*delta_frequency; 180 | float x1 = (signal_index+1)*delta_frequency; 181 | cuComplex signal_value = interpolate_complex(x0, x1, f0, f1, x); 182 | 183 | cuComplex part_result =signal_value*weight; 184 | result.x = result.x + part_result.x; 185 | result.y = result.y + part_result.y; 186 | } 187 | } 188 | int result_index = index_angle*num_range_positions + index_range; 189 | reco_image[result_index] = result; 190 | } 191 | """) 192 | 193 | bandwidth = signal_data.bandwidth 194 | chirp_duration = signal_data.chirp_duration 195 | 196 | frequency_slope = np.float32(bandwidth/chirp_duration) 197 | delta_frequency = np.float32(1.0/(chirp_duration*zero_padding_factor)) 198 | 199 | tx_antennas = np.asarray(signal_data.tx_positions) 200 | rx_antennas = np.asarray(signal_data.rx_positions) 201 | 202 | reco_trans_matrix = transformation_matrix.astype(np.float32) 203 | reco_positions_range = range_positions.astype(np.float32) 204 | reco_positions_angle = angle_positions.astype(np.float32) 205 | 206 | num_tx_antennas = np.int32(len(tx_antennas)) 207 | num_rx_antennas = np.int32(len(rx_antennas)) 208 | 209 | tx_antennas = np.ravel(tx_antennas.astype(np.float32)) 210 | rx_antennas = np.ravel(rx_antennas.astype(np.float32)) 211 | 212 | reco_image = 1e-9*np.ones((reco_positions_range.shape[0], reco_positions_angle.shape[0]), dtype=np.complex64) 213 | reco_image = np.ravel(reco_image) 214 | raw_signal = signal_data.signals 215 | 216 | # create frequency signal 217 | frequency_signal = np.empty((raw_signal.shape[0], raw_signal.shape[1], 
raw_signal.shape[3]*zero_padding_factor), dtype=np.complex64) 218 | for tx_index in range(len(signal_data.tx_positions)): 219 | for rx_index in range(len(signal_data.rx_positions)): 220 | zero_padded_signal = np.zeros(frequency_signal.shape[2], dtype=np.complex128) 221 | zero_padded_signal[:raw_signal.shape[3]] = raw_signal[tx_index, rx_index, chirp_index] 222 | frequency_signal[tx_index, rx_index] = np.fft.fft(zero_padded_signal) 223 | 224 | frequency_signal_length = np.int32(frequency_signal.shape[2]) 225 | 226 | # copy data to gpu 227 | reco_trans_matrix_gpu = cuda.mem_alloc(reco_trans_matrix.nbytes) 228 | reco_positions_range_gpu = cuda.mem_alloc(reco_positions_range.nbytes) 229 | reco_positions_angle_gpu = cuda.mem_alloc(reco_positions_angle.nbytes) 230 | 231 | tx_antennas_gpu = cuda.mem_alloc(tx_antennas.nbytes) 232 | rx_antennas_gpu = cuda.mem_alloc(rx_antennas.nbytes) 233 | reco_image_gpu = cuda.mem_alloc(reco_image.nbytes) 234 | frequency_signal = np.ravel(frequency_signal) 235 | frequency_signal_gpu = cuda.mem_alloc(frequency_signal.nbytes) 236 | 237 | cuda.memcpy_htod(reco_trans_matrix_gpu, reco_trans_matrix) 238 | cuda.memcpy_htod(tx_antennas_gpu, tx_antennas) 239 | cuda.memcpy_htod(rx_antennas_gpu, rx_antennas) 240 | cuda.memcpy_htod(reco_image_gpu, reco_image) 241 | cuda.memcpy_htod(frequency_signal_gpu, frequency_signal) 242 | cuda.memcpy_htod(reco_positions_range_gpu, reco_positions_range) 243 | cuda.memcpy_htod(reco_positions_angle_gpu, reco_positions_angle) 244 | 245 | num_reco_positions_range = np.int32(len(reco_positions_range)) 246 | num_reco_positions_angle = np.int32(len(reco_positions_angle)) 247 | carrier_frequency = np.float32(signal_data.carrier_frequency) 248 | 249 | func = mod.get_function("holo_reco") 250 | func( 251 | tx_antennas_gpu,\ 252 | num_tx_antennas, \ 253 | rx_antennas_gpu,\ 254 | num_rx_antennas,\ 255 | reco_trans_matrix_gpu,\ 256 | reco_positions_range_gpu,\ 257 | num_reco_positions_range,\ 258 | reco_positions_angle_gpu,\ 259 | num_reco_positions_angle,\ 260 | reco_image_gpu,\ 261 | frequency_signal_gpu, 262 | frequency_signal_length,\ 263 | frequency_slope,\ 264 | delta_frequency,\ 265 | carrier_frequency,\ 266 | np.float32(z_pos),\ 267 | block=(8, 8, 1), grid=(128, 128, 1)) 268 | cuda.memcpy_dtoh(reco_image, reco_image_gpu) 269 | reco_image = reco_image.reshape((reco_positions_angle.shape[0], reco_positions_range.shape[0])) 270 | reco_image = reco_image / (num_rx_antennas*num_tx_antennas) 271 | 272 | return reco_image 273 | -------------------------------------------------------------------------------- /sigproc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def create_steering_vector(tx_positions : np.ndarray, rx_positions : np.ndarray, sample_vector = None) -> np.ndarray: 5 | """ 6 | Create a steering vector, which samples in the sine domain from -0.85 to 0.85 if not sample_vector is given 7 | """ 8 | 9 | if sample_vector is None: 10 | angular_size=450 11 | fc = 77e9 12 | c=3e8 13 | wavelength = c/fc 14 | sample_vector = np.linspace(-0.85, 0.85, angular_size, endpoint=False) 15 | sample_vector = np.flip(sample_vector)[:, np.newaxis] 16 | 17 | if sample_vector.ndim == 1: 18 | sample_vector = sample_vector[:, np.newaxis] 19 | 20 | input_virtual_positions = np.empty((1, len(tx_positions)*len(rx_positions))) 21 | for pos_idx, rx_pos in enumerate(rx_positions): 22 | virt_pos = rx_pos + tx_positions 23 | input_virtual_positions[0, 
pos_idx*len(tx_positions):(pos_idx+1)*len(tx_positions)] = virt_pos 24 | 25 | steering_vector = np.exp(2.0j*np.pi/wavelength*(sample_vector@input_virtual_positions)) 26 | return steering_vector 27 | 28 | 29 | def create_normalized_cov_mat(channel_signal : np.ndarray): 30 | """ 31 | expects a 1D-signal and creates a normalized sample covariance matrix 32 | """ 33 | 34 | channel_signal_ext = channel_signal[:, np.newaxis] 35 | cov_mat = channel_signal_ext@np.conj(channel_signal_ext.T) 36 | cov_mat /= np.linalg.norm(cov_mat) 37 | 38 | return cov_mat 39 | 40 | def normalize_data(input : np.ndarray): 41 | """ 42 | normalize data to the range 0->1 for numpy arrays and pytorch tensors 43 | """ 44 | 45 | if type(input) == np.ndarray: 46 | min_val = np.min(input) 47 | max_val = np.max(input) 48 | input_out = (input - min_val) / (max_val - min_val) 49 | return input_out 50 | elif type(input) is torch.Tensor: 51 | min_val = torch.min(input) 52 | max_val = torch.max(input) 53 | input_out = (input - min_val) / (max_val - min_val) 54 | return input_out 55 | else: 56 | raise Exception("Data type not supported. Supported is np.ndarray and torch.tensor") 57 | 58 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from typing import Dict 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | import matplotlib.pyplot as plt 8 | from .FullRadarCubeDataset import FullRadarCubeDataset, FullRadarCubeDatasetConfig 9 | from .RadarUnet import RadarUNet 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch import nn 12 | from sigproc import * 13 | from torch.optim.lr_scheduler import MultiStepLR 14 | 15 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 16 | print('Using {} device'.format(device)) 17 | 18 | training_artifcats_dir = "/" 19 | 20 | tx_positions = np.asarray([np.array([-5.0e-3, 0.0, 0]), 21 | np.array([-3.0e-3, 0.0, 0]), 22 | np.array([-1.0e-3, 0.0, 0])]) 23 | 24 | rx_positions = np.asarray([np.array([0.0e-3, 0.0, 0.0]), 25 | np.array([6.0e-3, 0.0, 0.0]), 26 | np.array([12.0e-3, 0.0, 0.0]), 27 | np.array([18.0e-3, 0.0, 0.0]), 28 | np.array([24.0e-3, 0.0, 0.0]), 29 | np.array([30.0e-3, 0.0, 0.0]), 30 | np.array([36.0e-3, 0.0, 0.0]), 31 | np.array([42.0e-3, 0.0, 0.0]), 32 | np.array([48.0e-3, 0.0, 0.0]), 33 | np.array([54.0e-3, 0.0, 0.0]), 34 | np.array([60.0e-3, 0.0, 0.0]), 35 | np.array([66.0e-3, 0.0, 0.0]), 36 | np.array([72.0e-3, 0.0, 0.0]), 37 | np.array([78.0e-3, 0.0, 0.0]), 38 | np.array([84.0e-3, 0.0, 0.0]), 39 | np.array([90.0e-3, 0.0, 0.0])]) 40 | 41 | 42 | rx_dict = {"sparse16":[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15], 43 | "sparse4":[0,1,10,15], 44 | "sparse3":[0,5,10], 45 | "sparse6":[0, 1, 6, 10, 13, 15]} 46 | 47 | virt_indices_dict = {} 48 | steering_vectors_dict = {} 49 | tx_indices = [0,1,2] 50 | for key in rx_dict.keys(): 51 | # construct steering vector 52 | steering_vector = create_steering_vector(tx_positions[:,0], rx_positions[rx_dict[key]][:,0]) 53 | steering_vectors_dict[key] = steering_vector 54 | 55 | # construct indices for virtual array (array occupation) 56 | rx_indices = np.array(rx_dict[key])*3 57 | virt_indices = [] 58 | for rx_index in rx_indices: 59 | for tx_index in tx_indices: 60 | virt_indices.append(rx_index+tx_index) 61 | virt_indices_dict[key] = virt_indices 62 | 63 | selected_antenna_config = "sparse16" 64 | 65 | def input_transform_channel(file, idx)-> 
torch.tensor: 66 | """ 67 | This method transforms the input data of the neural network, by 68 | requiring a h5-file handle and the current data index 69 | """ 70 | 71 | # range-channel-data (range-Doppler-fft + Doppler selection applied, but channel data unchanged) 72 | rc_image_np = np.array(file.get(f'rc_data_{idx:06d}'), dtype=np.complex64) 73 | 74 | rc_image_np = rc_image_np[virt_indices_dict[selected_antenna_config], :] 75 | rc_image = rc_image_np.T[:, np.newaxis, :] 76 | rc_image = torch.tensor(rc_image, device=device) 77 | 78 | steering_vector = torch.tensor(steering_vectors_dict[selected_antenna_config], device=device, dtype=torch.complex64) 79 | steering_vector = torch.unsqueeze(steering_vector.T, 0) 80 | ra_image = rc_image@torch.conj(steering_vector) 81 | ra_image = torch.squeeze(ra_image, 1).T 82 | ra_image_abs = torch.log10(torch.abs(ra_image)) 83 | ra_image_abs = normalize_data(ra_image_abs) 84 | 85 | rc_image_T = torch.swapaxes(rc_image, 1, 2) 86 | cov_mat_image_torch = rc_image_T@torch.conj(rc_image) 87 | cov_mat_image_torch /= torch.linalg.norm(cov_mat_image_torch) 88 | 89 | max_cov_size = min(29, rc_image_np.shape[0]) 90 | cov_mat_image_torch = cov_mat_image_torch[:, :max_cov_size, :max_cov_size] 91 | idx = torch.triu_indices(*cov_mat_image_torch.shape[1:]) 92 | cov_mat_line = cov_mat_image_torch[:, idx[0], idx[1]].T 93 | 94 | cov_mat_image = torch.zeros(ra_image.shape, dtype=torch.complex64, device=device) 95 | cov_mat_image[:cov_mat_line.shape[0], :] = cov_mat_line 96 | 97 | input_tensor = torch.stack((ra_image_abs, torch.angle(ra_image), torch.real(cov_mat_image), torch.imag(cov_mat_image), torch.angle(cov_mat_image))) 98 | return input_tensor 99 | 100 | 101 | def target_transform_simple(file, idx) -> torch.tensor: 102 | """ 103 | This method transforms the target/output data of the neural network, by 104 | requiring a h5-file handle and the current data index 105 | """ 106 | 107 | # expect image, which was created by a matched filter in range-sin(angle) coordinates 108 | ra_image = np.array(file.get(f'ra_data_matched_{idx:06d}'), dtype=np.float32) 109 | 110 | ra_image -= np.min(ra_image) 111 | ra_image /= (np.max(ra_image)+1e-2) 112 | scale_factor = 10 113 | ra_image = ra_image*scale_factor 114 | 115 | return torch.tensor(ra_image) 116 | 117 | def train(experiment_name, model, num_epochs, learning_rate, out_directory): 118 | 119 | if not os.path.exists(out_directory): 120 | os.makedirs(out_directory) 121 | 122 | # Tensorboard writer 123 | writer = SummaryWriter(f'runs/{experiment_name}') 124 | model_param_name = f"{experiment_name}.params" 125 | loss_fn = nn.MSELoss(reduction="sum") 126 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 127 | 128 | avg_test_loss_before = 10 129 | batch_step = 0 130 | epoch_step = 0 131 | avg_train_loss = 10 132 | 133 | scheduler = MultiStepLR(optimizer, milestones=[5], gamma=0.1) 134 | for epoch in range(num_epochs): 135 | print(f"[Epoch {epoch} / {num_epochs}] -> last-test-loss: {avg_test_loss_before}, last-train-loss: {avg_train_loss}, learning_rate: {scheduler.get_last_lr()}") 136 | 137 | train_data_set.shuffle_data(epoch) 138 | model.train() 139 | train_loss = 0 140 | for batch_idx, (X, y) in enumerate(train_data_loader): 141 | 142 | X = X.to(device) 143 | y = y.to(device) 144 | # Forward pass 145 | pred = model(X) 146 | if batch_idx == 0: 147 | fig, axes = plt.subplots(max(2,train_data_loader.batch_size), 3, figsize=(15, 4*train_data_loader.batch_size)) 148 | for image_idx in 
range(train_data_loader.batch_size): 149 | 150 | plt1 = axes[image_idx,0].imshow(X[image_idx,0].cpu().detach().numpy()) 151 | plt.colorbar(plt1 ,ax=axes[image_idx,0]) 152 | 153 | for image_idx in range(train_data_loader.batch_size): 154 | plt1 = axes[image_idx,1].imshow(pred[image_idx,0].cpu().detach().numpy()) 155 | plt2 = axes[image_idx,2].imshow(y[image_idx].cpu().detach().numpy()) 156 | plt.colorbar(plt1 ,ax=axes[image_idx,1]) 157 | plt.colorbar(plt2 ,ax=axes[image_idx,2]) 158 | writer.add_figure("Predicted-Train", fig, global_step=epoch_step) 159 | plt.close() 160 | 161 | N = 4 # increase batch size even with insufficient video memory by accumulating gradients 162 | loss = loss_fn(torch.squeeze(pred, dim=1), y) / N 163 | 164 | loss.backward() 165 | if (batch_idx+1) % N == 0: 166 | train_loss += loss.item() 167 | optimizer.step() 168 | writer.add_scalar("Training loss - Batch", loss, global_step=batch_step) 169 | batch_step += 1 170 | 171 | max_value = 0 172 | total_norm = 0 173 | for p in model.parameters(): 174 | if p.grad is None: 175 | continue 176 | param_norm = p.grad.data.norm(2) 177 | if param_norm > max_value: 178 | max_value = p.data.norm(2) 179 | total_norm += param_norm.item() ** 2 180 | total_norm = total_norm ** (1. / 2) 181 | 182 | if batch_idx % 100 == 0: 183 | print(f"batch-idx: {batch_idx:06d} -> loss: {loss.item():.4f} -> grad_sum: {total_norm:.4f} -> grad_max: {max_value:.4f}") 184 | 185 | optimizer.zero_grad() 186 | 187 | if batch_idx % 1000 == 0: 188 | model_filename = os.path.join(out_directory, f"epoch_{epoch}_{batch_idx}_{model_param_name}") 189 | torch.save(model.state_dict(), model_filename) 190 | print(f"saved model {model_filename}") 191 | 192 | # add mean avg loss for current epoch 193 | avg_train_loss = train_loss / len(train_data_loader) 194 | 195 | # Evaluate the model at the beginning of each epoch 196 | model.eval() 197 | test_loss = 0 198 | with torch.no_grad(): 199 | for batch_idx, (X, y) in enumerate(test_data_loader): 200 | 201 | X = X.to(device) 202 | y = y.to(device) 203 | pred = model(X) 204 | 205 | if batch_idx == 0: 206 | fig, axes = plt.subplots(1, 3, figsize=(15, 5)) 207 | 208 | plt0 = axes[0].imshow(X[0,0].cpu().detach().numpy()) 209 | plt.colorbar(plt0 ,ax=axes[0]) 210 | 211 | plt1 = axes[1].imshow(pred[0,0].cpu().detach().numpy()) 212 | plt2 = axes[2].imshow(y[0].cpu().detach().numpy()) 213 | plt.colorbar(plt1 ,ax=axes[1]) 214 | plt.colorbar(plt2 ,ax=axes[2]) 215 | writer.add_figure("Predicted-Test", fig, global_step=epoch_step) 216 | plt.close() 217 | test_loss += loss_fn(torch.squeeze(pred, dim=1), y).item() 218 | 219 | 220 | avg_test_loss = test_loss / len(test_data_loader) 221 | torch.save(model.state_dict(), os.path.join(out_directory, f"epoch_{epoch}_{model_param_name}")) 222 | avg_test_loss_before = avg_test_loss 223 | 224 | # plot both in tensorboard 225 | loss_dict = {"avg_train_loss":avg_train_loss, "avg_test_loss":avg_test_loss} 226 | writer.add_scalars("Loss-Epoch", loss_dict, global_step=epoch_step) 227 | 228 | epoch_step += 1 229 | scheduler.step() 230 | 231 | antenna_configs = ["sparse16"] 232 | for antenna_config_idx, antenna_config in enumerate(antenna_configs): 233 | 234 | selected_antenna_config = antenna_config 235 | experiment_name = f"U-Net_{antenna_config}_pub" 236 | 237 | print(f"run expriment: {experiment_name}") 238 | selected_rx = rx_dict[antenna_config] 239 | train_data_config = FullRadarCubeDatasetConfig() 240 | 241 | train_data_config.data_set_size = 18500 242 | train_data_config.number_valid_samples = 500 
243 |     train_data_config.number_test_samples = 500
244 |     train_data_config.number_train_samples = 17500
245 |     train_data_config.input_filename = "input_data.h5"
246 |     train_data_config.target_filename = "target_data.h5"
247 | 
248 |     train_data_config.mode = "train"
249 |     train_data_config.input_load_callback = input_transform_channel
250 |     train_data_config.target_load_callback = target_transform_simple
251 | 
252 |     train_data_set = FullRadarCubeDataset(train_data_config)
253 |     eval_data_config = FullRadarCubeDatasetConfig()
254 |     eval_data_config.data_set_size = 18500
255 |     eval_data_config.number_valid_samples = 500
256 |     eval_data_config.number_test_samples = 500
257 |     eval_data_config.number_train_samples = 17500
258 |     eval_data_config.input_filename = "input_data.h5"
259 |     eval_data_config.target_filename = "target_data.h5"
260 |     eval_data_config.mode = "valid"
261 | 
262 |     eval_data_config.input_load_callback = input_transform_channel
263 |     eval_data_config.target_load_callback = target_transform_simple
264 | 
265 |     eval_data_set = FullRadarCubeDataset(eval_data_config)
266 |     x, y = train_data_set[0]
267 | 
268 |     # initialize model and training parameters
269 |     train_data_loader = DataLoader(train_data_set, batch_size=2)
270 |     test_data_loader = DataLoader(eval_data_set, batch_size=1)
271 | 
272 |     torch.random.manual_seed(99)
273 |     model = RadarUNet(y.shape, x.shape[0]).to(device)
274 | 
275 |     learning_rate = 1e-4
276 |     num_epochs = 6
277 | 
278 |     train(experiment_name, model, num_epochs, learning_rate, training_artifcats_dir)  # out_directory is a required argument of train()
279 | 
280 | 
281 | 
282 | 
283 | 
--------------------------------------------------------------------------------
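
A minimal inference sketch (not part of the repository): it assumes that the input_transform_channel and
target_transform_simple callbacks from train.py are available in scope (importing train.py directly
would start a training run), that input_data.h5 and target_data.h5 exist, and a hypothetical
checkpoint name epoch_5_U-Net_sparse16_pub.params following the naming pattern used by train().

import torch
from FullRadarCubeDataset import FullRadarCubeDataset, FullRadarCubeDatasetConfig
from RadarUnet import RadarUNet

# build the held-out test split with the same pre-processing callbacks as in training
config = FullRadarCubeDatasetConfig()
config.data_set_size = 18500
config.number_train_samples = 17500
config.number_valid_samples = 500
config.number_test_samples = 500
config.input_filename = "input_data.h5"
config.target_filename = "target_data.h5"
config.mode = "test"
config.input_load_callback = input_transform_channel
config.target_load_callback = target_transform_simple
test_set = FullRadarCubeDataset(config)

x, y = test_set[0]                        # x: (5, range, angle) input stack, y: target range-angle image
model = RadarUNet(y.shape, x.shape[0]).to(x.device)

with torch.no_grad():
    model(x.unsqueeze(0))                 # dummy pass: the attention gates create part of their layers lazily
model.load_state_dict(torch.load("epoch_5_U-Net_sparse16_pub.params", map_location=x.device))
model.eval()

with torch.no_grad():
    prediction = model(x.unsqueeze(0))    # output shape (1, 1, *y.shape)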