├── images
│   ├── cnn.png
│   ├── logo.png
│   ├── model.png
│   ├── logo128.png
│   ├── mannequin.png
│   ├── rdm2RDspects.gif
│   ├── spectrograms.png
│   └── real_synth_skeleton.gif
├── ifxaion
│   ├── img
│   │   └── aion.png
│   └── README.md
├── EUSIPCO2022
│   ├── images
│   │   ├── soli.gif
│   │   └── spectrograms.gif
│   └── README.md
├── _config.yml
├── configurations
│   └── radar_configs.csv
├── assets
│   └── css
│       └── style.scss
├── .gitignore
├── utils
│   ├── radar.py
│   ├── skeletons.py
│   ├── provider.py
│   ├── synthesize.py
│   ├── visualization.py
│   ├── slicer.py
│   └── preprocess.py
├── CITATION.cff
├── networks
│   ├── img_transf.py
│   └── perceptual.py
├── README.md
├── visualize
│   └── notebook.ipynb
└── main.py

Binary image assets (each served from https://raw.githubusercontent.com/fraunhoferhhi/racpit/HEAD/ followed by the path below):

- /images/cnn.png
- /images/logo.png
- /images/model.png
- /images/logo128.png
- /images/mannequin.png
- /images/rdm2RDspects.gif
- /images/spectrograms.png
- /images/real_synth_skeleton.gif
- /ifxaion/img/aion.png
- /EUSIPCO2022/images/soli.gif
- /EUSIPCO2022/images/spectrograms.gif
-------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | remote_theme: pages-themes/cayman@v0.2.0 2 | plugins: 3 | - jekyll-remote-theme 4 | title: Machine Learning for Radar 5 | description: Additional material 6 | -------------------------------------------------------------------------------- /configurations/radar_configs.csv: -------------------------------------------------------------------------------- 1 | name,I,II,III,IV 2 | ChirpsPerFrame,64,64,64,128 3 | SamplesPerChirp,256,256,256,256 4 | AdcSamplerate,2000,1500,1500,1500 5 | FramePeriod,50000,32000,50000,50000 6 | LowerFrequency,59000,58000,58000,58000 7 | UpperFrequency,61000,62000,62000,62000 8 | -------------------------------------------------------------------------------- /ifxaion/README.md: -------------------------------------------------------------------------------- 1 | # Infineon Data Acquisition ![Aion](img/aion.png) 2 | 3 | Infineon's Data Acquisition (**ifxaion**) radar module is proprietary and cannot be publicly disclosed. 4 | For inquiries hereby, please contact 5 | [Lorenzo Servadei](https://linkedin.com/in/lorenzo-servadei-32140937). 6 | -------------------------------------------------------------------------------- /assets/css/style.scss: -------------------------------------------------------------------------------- 1 | --- 2 | --- 3 | 4 | // Breakpoints 5 | $large-breakpoint: 64em !default; 6 | $medium-breakpoint: 42em !default; 7 | 8 | // Headers 9 | $header-heading-color: #fff !default; 10 | $header-bg-color: #aec16e; 11 | $header-bg-color-secondary: #88b7a7; 12 | 13 | // Text 14 | $section-headings-color: #db0036; 15 | $body-text-color: #606c71 !default; 16 | $body-link-color: #1e6bb8 !default; 17 | $blockquote-text-color: #819198 !default; 18 | 19 | // Code 20 | $code-bg-color: #f3f6fa !default; 21 | $code-text-color: #567482 !default; 22 | 23 | // Borders 24 | $border-color: #dce6f0 !default; 25 | $table-border-color: #e9ebec !default; 26 | $hr-border-color: #eff0f1 !default; 27 | 28 | @import "{{ site.theme }}"; 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | visualization/ 105 | log/ 106 | -------------------------------------------------------------------------------- /utils/radar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Jul 3 16:59:19 2018 4 | 5 | @author: Rodrigo Hernangomez 6 | """ 7 | 8 | import numpy as np 9 | 10 | constants = {'c': 3e8, 'f0': 60e9} 11 | 12 | eps = 10.0 ** (- 100) 13 | 14 | 15 | def range_axis(bw, n_samples, c=constants['c']): 16 | """ 17 | Calculate range resolution and scope. 18 | :param bw: FMCW bandwidth in Hz. 19 | :param n_samples: Number of samples per chirp. 20 | :param c: Speed of light in m/s. 21 | :return: Range resolution dr and scope rmax 22 | """ 23 | dr = c / (2 * bw) 24 | rmax = dr * n_samples / 2 25 | return dr, rmax 26 | 27 | 28 | def doppler_axis(prt, n_chirps, lambda0=constants['c']/constants['f0']): 29 | """ 30 | Calculate velocity resolution and scope. 31 | :param prt: Pulse Repetition Time in seconds. 32 | :param n_chirps: Number of chirps per frame. 33 | :param lambda0: Wavelength of the radar in meters. For 60 Ghz, this equals around 5 mm. 
34 | :return: Velocity resolution dv and scope vmax 35 | """ 36 | vmax = lambda0 / (4 * prt) 37 | dv = 2 * vmax / n_chirps 38 | return dv, vmax 39 | 40 | 41 | def rx_mask(rx1, rx2, rx3): 42 | mask = 0b0 43 | for i, rx in enumerate([rx1, rx2, rx3]): 44 | mask += rx << i 45 | return mask 46 | 47 | 48 | def normalize_db(db_data, axis=None): 49 | return db_data - np.max(db_data, axis=axis, keepdims=True) 50 | 51 | 52 | def absolute(rdm, normalize=False): 53 | rdm_abs = np.abs(rdm) 54 | if normalize: 55 | rdm_abs /= rdm_abs.max() 56 | return rdm_abs 57 | 58 | 59 | def mag2db(rdm, normalize=True): 60 | rdm_ = np.abs(rdm) 61 | rdm_db = 20 * np.log10(rdm_ + eps) 62 | if normalize: 63 | rdm_db = normalize_db(rdm_db) 64 | return rdm_db 65 | 66 | 67 | def db2mag(db): 68 | return np.power(10, db / 20) 69 | 70 | 71 | def complex2vector(array_cx): 72 | return np.stack([array_cx.real, array_cx.imag], axis=-1) 73 | 74 | 75 | def vector2complex(array_vec): 76 | cx_vec = array_vec.astype(np.complex) 77 | return cx_vec[..., 0] + cx_vec[..., 1] * 1j 78 | 79 | 80 | def affine_transform(x, a, b): 81 | return a * x + b 82 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use RACPIT's code or you take the publication as a reference for your research, please cite it using the preferred-citation information." 3 | authors: 4 | - family-names: Hernangómez 5 | given-names: Rodrigo 6 | orcid: https://orcid.org/0000-0002-1284-4951 7 | affiliation: Fraunhofer Heinrich Hertz Institute 8 | - family-names: Visentin 9 | given-names: Tristan 10 | affiliation: Fraunhofer Heinrich Hertz Institute 11 | - family-names: Servadei 12 | given-names: Lorenzo 13 | orcid: https://orcid.org/0000-0003-4322-834X 14 | affiliation: Infineon Technologies AG 15 | - family-names: Khodabakhshandeh 16 | given-names: Hamid 17 | affiliation: Fraunhofer Heinrich Hertz Institute 18 | - family-names: Stanczak 19 | given-names: Slawomir 20 | orcid: https://orcid.org/0000-0003-3829-4668 21 | affiliation: Fraunhofer Heinrich Hertz Institute 22 | title: "Code and additional information to 'RACPIT: Improving Radar Human Activity Classification Using Synthetic Data with Image Transformation'" 23 | version: 1.0.0 24 | doi: 10.3390/s22041519 25 | url: "https://www.mdpi.com/1424-8220/22/4/1519" 26 | preferred-citation: 27 | type: article 28 | title: 'Improving Radar Human Activity Classification Using Synthetic Data with Image Transformation' 29 | abstract: 30 | authors: 31 | - family-names: Hernangómez 32 | given-names: Rodrigo 33 | orcid: https://orcid.org/0000-0002-1284-4951 34 | affiliation: Fraunhofer Heinrich Hertz Institute 35 | - family-names: Visentin 36 | given-names: Tristan 37 | affiliation: Fraunhofer Heinrich Hertz Institute 38 | - family-names: Servadei 39 | given-names: Lorenzo 40 | orcid: https://orcid.org/0000-0003-4322-834X 41 | affiliation: Infineon Technologies AG 42 | - family-names: Khodabakhshandeh 43 | given-names: Hamid 44 | affiliation: Fraunhofer Heinrich Hertz Institute 45 | - family-names: Stanczak 46 | given-names: Slawomir 47 | orcid: https://orcid.org/0000-0003-3829-4668 48 | affiliation: Fraunhofer Heinrich Hertz Institute 49 | year: 2022 50 | month: 1 51 | volume: 22 52 | issue: 4 53 | start: 1519 54 | issn: '1424-8220' 55 | journal: Sensors 56 | doi: "10.3390/s22041519" 57 | 58 | 
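As a usage sketch (added here for illustration, not part of the repository), the axis helpers from `utils/radar.py` above can be combined with `configurations/radar_configs.csv` to derive range and velocity resolution and scope for a given configuration. The chirp repetition time is not listed in the CSV, so the value below is a placeholder assumption:

```python
import pandas as pd
from utils.radar import range_axis, doppler_axis

# Configurations are stored as columns I-IV, parameters as rows
configs = pd.read_csv("configurations/radar_configs.csv", index_col="name")
cfg = configs["I"]

bw = (cfg["UpperFrequency"] - cfg["LowerFrequency"]) * 1e6  # MHz -> Hz
dr, r_max = range_axis(bw, int(cfg["SamplesPerChirp"]))

prt = 500e-6  # chirp repetition time in seconds (assumed, not in the CSV)
dv, v_max = doppler_axis(prt, int(cfg["ChirpsPerFrame"]))

print(f"Range:    {dr * 100:.1f} cm resolution, {r_max:.1f} m scope")
print(f"Velocity: {dv:.3f} m/s resolution, +/-{v_max:.2f} m/s scope")
```

For configuration I this gives a range resolution of 7.5 cm over a 9.6 m scope; the velocity figures depend on the assumed chirp timing.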
-------------------------------------------------------------------------------- /networks/img_transf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # Conv Layer 5 | class ConvLayer(nn.Module): 6 | def __init__(self, in_channels, out_channels, kernel_size, stride): 7 | super(ConvLayer, self).__init__() 8 | padding = kernel_size // 2 9 | self.reflection_pad = nn.ReflectionPad2d(padding) 10 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride) #, padding) 11 | 12 | def forward(self, x): 13 | out = self.reflection_pad(x) 14 | out = self.conv2d(out) 15 | return out 16 | 17 | # Upsample Conv Layer 18 | class UpsampleConvLayer(nn.Module): 19 | def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None): 20 | super(UpsampleConvLayer, self).__init__() 21 | self.upsample = upsample 22 | if upsample: 23 | self.upsample = nn.Upsample(scale_factor=upsample, mode='nearest') 24 | reflection_padding = kernel_size // 2 25 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 26 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride) 27 | 28 | def forward(self, x): 29 | if self.upsample: 30 | x = self.upsample(x) 31 | out = self.reflection_pad(x) 32 | out = self.conv2d(out) 33 | return out 34 | 35 | # Residual Block 36 | # adapted from pytorch tutorial 37 | # https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02- 38 | # intermediate/deep_residual_network/main.py 39 | class ResidualBlock(nn.Module): 40 | def __init__(self, channels): 41 | super(ResidualBlock, self).__init__() 42 | self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1) 43 | self.in1 = nn.InstanceNorm2d(channels, affine=True) 44 | self.relu = nn.ReLU() 45 | self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1) 46 | self.in2 = nn.InstanceNorm2d(channels, affine=True) 47 | 48 | def forward(self, x): 49 | residual = x 50 | out = self.relu(self.in1(self.conv1(x))) 51 | out = self.in2(self.conv2(out)) 52 | out = out + residual 53 | out = self.relu(out) 54 | return out 55 | 56 | # Image Transform Network 57 | class ImageTransformNet(nn.Module): 58 | def __init__(self, num_channels=3): 59 | super(ImageTransformNet, self).__init__() 60 | 61 | # nonlineraity 62 | self.relu = nn.ReLU() 63 | self.tanh = nn.Tanh() 64 | 65 | # encoding layers 66 | self.conv1 = ConvLayer(num_channels, 32, kernel_size=9, stride=1) 67 | self.in1_e = nn.InstanceNorm2d(32, affine=True) 68 | 69 | self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2) 70 | self.in2_e = nn.InstanceNorm2d(64, affine=True) 71 | 72 | self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2) 73 | self.in3_e = nn.InstanceNorm2d(128, affine=True) 74 | 75 | # residual layers 76 | self.res1 = ResidualBlock(128) 77 | self.res2 = ResidualBlock(128) 78 | self.res3 = ResidualBlock(128) 79 | self.res4 = ResidualBlock(128) 80 | self.res5 = ResidualBlock(128) 81 | 82 | # decoding layers 83 | self.deconv3 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2 ) 84 | self.in3_d = nn.InstanceNorm2d(64, affine=True) 85 | 86 | self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2 ) 87 | self.in2_d = nn.InstanceNorm2d(32, affine=True) 88 | 89 | self.deconv1 = UpsampleConvLayer(32, num_channels, kernel_size=9, stride=1) 90 | self.in1_d = nn.InstanceNorm2d(num_channels, affine=True) 91 | 92 | def forward(self, x): 93 | # encode 94 | y = self.relu(self.in1_e(self.conv1(x))) 95 | y = 
self.relu(self.in2_e(self.conv2(y))) 96 | y = self.relu(self.in3_e(self.conv3(y))) 97 | 98 | # residual layers 99 | y = self.res1(y) 100 | y = self.res2(y) 101 | y = self.res3(y) 102 | y = self.res4(y) 103 | y = self.res5(y) 104 | 105 | # decode 106 | y = self.relu(self.in3_d(self.deconv3(y))) 107 | y = self.relu(self.in2_d(self.deconv2(y))) 108 | #y = self.tanh(self.in1_d(self.deconv1(y))) 109 | y = self.deconv1(y) 110 | 111 | return y 112 | 113 | 114 | class MultiTransformNet(nn.Module): 115 | def __init__(self, num_inputs=2, num_channels=1): 116 | super(MultiTransformNet, self).__init__() 117 | self.transformers = nn.ModuleList([ImageTransformNet(num_channels) for _ in range(num_inputs)]) 118 | 119 | def forward(self, inputs): 120 | outputs = [trans(x) for x, trans in zip(inputs, self.transformers)] 121 | return outputs 122 | -------------------------------------------------------------------------------- /utils/skeletons.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ifxaion.daq import Daq 3 | import scipy.interpolate 4 | import xarray as xr 5 | 6 | 7 | def rcsellipsoid(a, b, c, phi,theta): 8 | rcs = (np.pi * a**2 * b**2 * c**2) / (a**2 * (np.sin(theta))**2 * (np.cos(phi))**2 + 9 | b**2 * (np.sin(theta))**2 * (np.sin(phi))**2 + c**2 * (np.cos(theta))**2)**2 10 | return rcs 11 | 12 | 13 | class Segment: 14 | """ Defines a body segment and allows to calculate RCS of segment """ 15 | 16 | def __init__(self, segmentPositions, ellipsoidParams, aspect): 17 | self.ellipsoidParams = ellipsoidParams 18 | self.segmentPositions = segmentPositions 19 | self.aspect = aspect 20 | 21 | 22 | def calculateRange(self, radarloc): 23 | """ 24 | Calculates the distance to the radar for all segment positions 25 | :return: vector of the distances [d1, d2, d3, ...] 26 | """ 27 | self.r_dist = np.abs(self.segmentPositions - radarloc) 28 | self.r_total = np.sqrt(self.r_dist[:, 0] ** 2 + self.r_dist[:, 1] ** 2 + self.r_dist[:, 2] ** 2) 29 | return self.r_total 30 | 31 | 32 | def calculateAngles(self, radarloc): 33 | """ 34 | Calculates the angles to the radar for all segment positions 35 | :return: vector of phi [phi1, phi2, phi3,...] and vector of theta [theta1,theta2,theta3,...] 36 | """ 37 | A = np.column_stack(( 38 | radarloc[0] - self.segmentPositions[:, 0], 39 | radarloc[1] - self.segmentPositions[:, 1], 40 | radarloc[2] - self.segmentPositions[:, 2] 41 | )) 42 | 43 | B = np.column_stack(( 44 | self.aspect[:, 0], 45 | self.aspect[:, 1], 46 | self.aspect[:, 2] 47 | )) 48 | 49 | a_dot_b = np.sum(A * B, axis=1) 50 | a_sum_sqrt = np.sqrt(np.sum(A * A, axis=1)) 51 | b_sum_sqrt = np.sqrt(np.sum(B * B, axis=1)) 52 | 53 | theta = np.arccos(a_dot_b / (a_sum_sqrt * b_sum_sqrt)) 54 | phi = np.arcsin((radarloc[1] - self.segmentPositions[:, 1]) / 55 | np.sqrt(self.r_dist[:, 0] ** 2 + self.r_dist[:, 1] ** 2)) 56 | 57 | return phi, theta 58 | 59 | def calculateRCS(self, phiAngle, thetaAngle): 60 | """ 61 | Calculates the RCS 62 | :return: vector of RCS [rcs1, rcs2, rcs3,...] 63 | """ 64 | a = self.ellipsoidParams[0] 65 | b = self.ellipsoidParams[1] 66 | c = self.ellipsoidParams[2] 67 | rcs = rcsellipsoid(a, b, c, phiAngle, thetaAngle) 68 | return rcs 69 | 70 | 71 | def load(path, verbose=False): 72 | """ 73 | Read the skeleton data with help of the Daq module. 
74 | """ 75 | 76 | daq = Daq(rec_dir=path) 77 | 78 | skeleton_df = daq.skeletons.data 79 | 80 | skeletonTimestamps = skeleton_df.index 81 | 82 | duration = (skeletonTimestamps[-1]-skeletonTimestamps[0]).total_seconds() 83 | 84 | if verbose: 85 | print("Reading skeleton data: {} skeletons, in {} second recording".format(len(skeleton_df), duration)) 86 | 87 | return skeleton_df 88 | 89 | 90 | def interpolate(skeletons, timestamp_seconds): 91 | sk_data = np.array(skeletons.Data) 92 | sk_data = np.stack(sk_data) 93 | sk_timestamps = skeletons.index 94 | 95 | sk_interpolation = scipy.interpolate.interp1d(sk_timestamps.total_seconds(), sk_data, 96 | axis=0, fill_value="extrapolate", kind='linear') 97 | sk_new = sk_interpolation(timestamp_seconds) 98 | return sk_new 99 | 100 | 101 | def to_xarray(data, timestamps, name="Skeletons", attrs=None): 102 | skeleton_da = xr.DataArray(data, dims=("time", "space", "keypoints"), 103 | name=name, attrs=attrs, 104 | coords={"time": timestamps.to_numpy(), 105 | "space": ["x", "y", "z"], 106 | "keypoints": list(keypoints)}).assign_attrs(units="m") 107 | return skeleton_da 108 | 109 | 110 | def get_edges(skeleton): 111 | if isinstance(skeleton, xr.DataArray): 112 | sk_data = skeleton.values 113 | else: 114 | sk_data = skeleton 115 | return np.stack([[sk_frame[:, segment] for segment in coco_edges] for sk_frame in sk_data]) 116 | 117 | 118 | keypoints = ("nose", "left_eye", "right_eye", "left_ear", "right_ear", 119 | "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", 120 | "left_wrist", "right_wrist", "left_hip", "right_hip", 121 | "left_knee", "right_knee", "left_ankle", "right_ankle") 122 | 123 | # Edges between keypoints used for the COCO API to visualize skeletons 124 | # See https://github.com/facebookresearch/Detectron/issues/640 125 | coco_edges = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], 126 | [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]] 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RACPIT 2 | [![NumPy](https://img.shields.io/badge/numpy-%23013243.svg?style=for-the-badge&logo=numpy&logoColor=white)][numpy] 3 | [![PyTorch](https://img.shields.io/badge/PyTorch-%23EE4C2C.svg?style=for-the-badge&logo=PyTorch&logoColor=white)][pytorch] 4 | [![Pandas](https://img.shields.io/badge/pandas-%23150458.svg?style=for-the-badge&logo=pandas&logoColor=white)][pandas] 5 | 6 | This repository contains supplementary material for our article 7 | ["Improving Radar Human Activity Classification 8 | Using Synthetic Data with Image Transformation"](https://www.mdpi.com/1424-8220/22/4/1519), 9 | published in 10 | [MDPI Sensors](https://www.mdpi.com/journal/sensors) 11 | as part of the 12 | [Special Issue "Advances in Radar Sensors"](https://www.mdpi.com/journal/sensors/special_issues/radar_application). 13 | There we introduce **RACPIT**: 14 | Radar Activity Classification with Perceptual Image Transformation, 15 | a deep-learning approach to human activity classification using 16 | [FMCW radar](https://community.infineon.com/t5/Knowledge-Base-Articles/Understanding-FMCW-Radars-Features-and-operational-principles/ta-p/767198) 17 | and enhanced with synthetic data. 18 | 19 | ## Background 20 | 21 | ### Radar data 22 | 23 | We use **Range Doppler Maps (RDMs)** 24 | as a basis for our input data. 
These can be either real data acquired 25 | with Infineon's 26 | [Radar sensors for IoT](https://www.infineon.com/cms/en/product/sensor/radar-sensors/radar-sensors-for-iot/) 27 | or synthetic data generated from kinematic data with the following model: 28 | 29 | $\Large s\left(t\right)=\sum_{k}{\sqrt{\frac{A_{k,t}}{L_{k,t}}}\sin{\left(2\pi f_{k,t}t+\phi_{k,t}\right)}}$ 30 | 31 |
32 | ![Human reflection model](images/mannequin.png)
34 | 35 | $A_{k,t}$, 36 | $L_{k,t}$, 37 | $f_{k,t}$ and 38 | $\phi_{k,t}$ 39 | represent the radar cross section, free-space path loss, 40 | instant frequency and instant phase, respectively, 41 | of the returned and mixed-down signal for every modelled human limb 42 | $k$ 43 | and instant 44 | $t$. 45 | The latter three parameters depend 46 | on the instantaneous distance of the limb to the radar sensor, 47 | $d_{k,t}$, 48 | and are calculated using the customary 49 | [radar](https://www.radartutorial.eu/01.basics/The%20Radar%20Range%20Equation.en.html) and 50 | [FMCW](https://www.radartutorial.eu/02.basics/Frequency%20Modulated%20Continuous%20Wave%20Radar.en.html) 51 | equations. 52 | 53 | ![Simulation animation](images/real_synth_skeleton.gif) 54 | 55 | We further preprocess the RDMs by stacking them and summing over Doppler and range axis 56 | to obtain range and Doppler spectrograms, respectively: 57 | 58 | ![Radar spectrogram extraction](images/rdm2RDspects.gif) 59 | 60 | ### Deep learning 61 | 62 | We train our image transformation networks with an adapted version of 63 | [Perceptual Losses for Real-Time Style Transfer and Super-Resolution][perceptual]. 64 | 65 | [perceptual]: https://arxiv.org/abs/1603.08155 66 | 67 | ![RACPIT model](images/model.png) 68 | 69 | Since we are working with radar data, we substitute VGG16 as the perceptual network 70 | with our two-branch convolutional neural network from 71 | [Domain Adaptation Across Configurations of FMCW Radar for Deep Learning Based Human Activity Classification](https://doi.org/10.23919/IRS51887.2021.9466179). 72 | 73 | 74 | 75 | If we train the image transformation networks with real data as our input and synthetic data as our ground truth, 76 | we obtain a denoising behavior for the image transformation networks. 77 | 78 | 79 | 80 | ## Implementation 81 | 82 | The code has been written for 83 | PyTorch based on 84 | [Daniel Yang's implementation](https://github.com/dxyang/StyleTransfer) 85 | of [Perceptual loss][perceptual]. 86 | 87 | Data preprocessing is heavily based on 88 | *x*array. You can take a closer look at it 89 | in our 90 | [example](./visualize). 
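As a rough illustration (a minimal sketch, independent of the repository's `utils/preprocess.py`; the dimension names and toy shapes are assumptions), the range and Doppler spectrograms described above can be obtained from a stack of Range-Doppler maps with a few *x*array operations:

```python
import numpy as np
import xarray as xr

# Toy stack of Range-Doppler maps: time x Doppler bins x range bins (linear magnitude)
rdms = xr.DataArray(np.random.rand(100, 64, 256), dims=("time", "doppler", "range"))

range_spectrogram = rdms.sum(dim="doppler")    # time x range
doppler_spectrogram = rdms.sum(dim="range")    # time x doppler

# Convert to dB afterwards, cf. mag2db in utils/radar.py
range_spectrogram_db = 20 * np.log10(range_spectrogram + 1e-10)
```

The actual batch preprocessing, including the `--marginalize` option shown in the Usage section below, is handled by `utils/preprocess.py`.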
91 | 92 | ### Prerequisites 93 | - [Python 3.8](https://www.python.org/) 94 | - [PyTorch 1.7.0][pytorch] 95 | - [*x*array](https://xarray.pydata.org) 96 | - [NumPy][numpy] 97 | - [Pandas][pandas] 98 | - [Matplotlib](https://matplotlib.org/) 99 | - [Cuda 11.0](https://developer.nvidia.com/cuda-11.0-download-archive) 100 | (For GPU training) 101 | 102 | [numpy]: http://www.numpy.org/ 103 | [pytorch]: http://pytorch.org/ 104 | [pandas]: https://pandas.pydata.org/ 105 | 106 | ### Usage 107 | 108 | Radar data can be batch-preprocessed and stored 109 | for faster training: 110 | 111 | ```bash 112 | $ python utils/preprocess.py --raw "/path/to/data/raw" --output "/path/to/data/real" --value "db" --marginalize "incoherent" 113 | $ python utils/preprocess.py --raw "/path/to/data/raw" --output "/path/to/data/synthetic" --synthetic --value "db" --marginalize "incoherent" 114 | ``` 115 | 116 | After this, you can train your CNN, that will serve as a perceptual network: 117 | 118 | ```bash 119 | $ python main.py --log "cnn" train-classify --range --config "I" --gpu 0 --no-split --dataset "/path/to/data/synthetic" 120 | ``` 121 | 122 | Then you can train the image transformation networks: 123 | 124 | ```bash 125 | $ python main.py --log "trans" train-transfer --range --config "I" --gpu 0 --visualize 5 --input "/path/to/data/real" --output "/path/to/data/synthetic" --recordings first --model "models/cnn.model" 126 | ``` 127 | 128 | And finally test the whole pipeline: 129 | 130 | ```bash 131 | $ python main.py test --range --config "I" --gpu 0 --visualize 10 --dataset "/path/to/data/real" --recordings last --transformer "models/trans.model" --model "models/cnn.model" 132 | ``` 133 | 134 | ## Citation 135 | 136 | If you use RACPIT's code or you take the publication as a reference for your research, 137 | please cite our work in the following way: 138 | 139 | ```bibtex 140 | @Article{s22041519, 141 | AUTHOR = {Hernang{\'o}mez, Rodrigo and Visentin, Tristan and Servadei, Lorenzo and Khodabakhshandeh, Hamid and Sta{\'n}czak, S{\l}awomir}, 142 | TITLE = {Improving Radar Human Activity Classification Using Synthetic Data with Image Transformation}, 143 | JOURNAL = {Sensors}, 144 | VOLUME = {22}, 145 | YEAR = {2022}, 146 | NUMBER = {4}, 147 | ARTICLE-NUMBER = {1519}, 148 | URL = {https://www.mdpi.com/1424-8220/22/4/1519}, 149 | ISSN = {1424-8220}, 150 | DOI = {10.3390/s22041519} 151 | } 152 | ``` 153 | -------------------------------------------------------------------------------- /networks/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | from networks.img_transf import ImageTransformNet, MultiTransformNet 5 | 6 | 7 | class Vgg16(nn.Module): 8 | def __init__(self): 9 | super(Vgg16, self).__init__() 10 | # Convert 1-channel image to 3-channel 11 | model = models.vgg16(pretrained=True) 12 | first_conv_layer = [nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)] 13 | first_conv_layer.extend(list(model.features)) 14 | model.features = nn.Sequential(*first_conv_layer) 15 | features = model.features 16 | 17 | self.to_relu_1_2 = nn.Sequential() 18 | self.to_relu_2_2 = nn.Sequential() 19 | self.to_relu_3_3 = nn.Sequential() 20 | self.to_relu_4_3 = nn.Sequential() 21 | 22 | for x in range(4): 23 | self.to_relu_1_2.add_module(str(x), features[x]) 24 | for x in range(4, 9): 25 | self.to_relu_2_2.add_module(str(x), features[x]) 26 | for x in range(9, 16): 27 | 
self.to_relu_3_3.add_module(str(x), features[x]) 28 | for x in range(16, 23): 29 | self.to_relu_4_3.add_module(str(x), features[x]) 30 | 31 | # don't need the gradients, just want the features 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | def forward(self, x): 36 | h = self.to_relu_1_2(x) 37 | h_relu_1_2 = h 38 | h = self.to_relu_2_2(h) 39 | h_relu_2_2 = h 40 | h = self.to_relu_3_3(h) 41 | h_relu_3_3 = h 42 | h = self.to_relu_4_3(h) 43 | h_relu_4_3 = h 44 | out = (h_relu_1_2, h_relu_2_2, h_relu_3_3, h_relu_4_3) 45 | return out 46 | 47 | 48 | class RDConv(nn.Module): 49 | """ 50 | Convolutional layer(s) for radar spectrograms 51 | """ 52 | def __init__(self, num_branches): 53 | super(RDConv, self).__init__() 54 | self.branches = nn.ModuleList([self._branch() for _ in range(num_branches)]) 55 | 56 | @staticmethod 57 | def _branch(): 58 | conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(7, 7), padding=(3, 3)) 59 | relu1 = nn.ReLU() 60 | max1 = nn.MaxPool2d(4, 3) 61 | conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5, 5), padding=(2, 2)) 62 | relu2 = nn.ReLU() 63 | max2 = nn.MaxPool2d(3, 2) 64 | conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), padding=(1, 1)) 65 | relu3 = nn.ReLU() 66 | max3 = nn.MaxPool2d(3, 2) 67 | 68 | return nn.ModuleList([conv1, relu1, max1, conv2, relu2, max2, 69 | conv3, relu3, max3]) 70 | 71 | def _branch_forward(self, x, branch=0, flatten=True): 72 | for layer in self.branches[branch]: 73 | x = layer(x) 74 | if flatten: 75 | x = torch.flatten(x, start_dim=1) 76 | return x 77 | 78 | def forward(self, inputs): 79 | if len(self.branches) > 1 or any(isinstance(inputs, _type) for _type in (list, tuple)): 80 | features = [] 81 | for i, x in enumerate(inputs): 82 | features.append(self._branch_forward(x, i)) 83 | return torch.cat(features, dim=1) 84 | else: 85 | return self._branch_forward(inputs) 86 | 87 | def output_num(self, input_shapes): 88 | inputs = [torch.zeros(1, *s[1:]) for s in input_shapes] 89 | if len(inputs) == 1: 90 | inputs = inputs[0] 91 | y = self.forward(inputs) 92 | return y.size()[-1] 93 | 94 | 95 | class RDNet(nn.Module): 96 | def __init__(self, input_shapes=None, class_num=5): 97 | super(RDNet, self).__init__() 98 | # set base network 99 | 100 | branches = len(input_shapes) 101 | self.conv_layers = RDConv(branches) 102 | features_num = self.conv_layers.output_num(input_shapes) 103 | self.classifier_layer_list = [nn.Linear(features_num, 16), nn.ReLU(), nn.Dropout(0.1), 104 | nn.Linear(16, 16), nn.ReLU(), nn.Dropout(0.1), 105 | nn.Linear(16, 8), nn.ReLU(), nn.Dropout(0.1), 106 | nn.Linear(8, class_num)] 107 | self.classifier_layer = nn.Sequential(*self.classifier_layer_list) 108 | self.softmax = nn.Softmax(dim=1) 109 | 110 | # initialization 111 | for dep in range(4): 112 | self.classifier_layer[dep * 3].weight.data.normal_(0, 0.01) 113 | self.classifier_layer[dep * 3].bias.data.fill_(0.0) 114 | 115 | def forward(self, inputs): 116 | features = self.conv_layers(inputs) 117 | outputs = self.classifier_layer(features) 118 | return outputs 119 | 120 | def predict(self, inputs): 121 | outputs = self.forward(inputs) 122 | return self.softmax(outputs) 123 | 124 | 125 | class RDPerceptual(RDNet): 126 | """ 127 | Use a trained RD classifier for perceptual loss 128 | """ 129 | 130 | def __init__(self, model_path, input_shapes=None, class_num=5): 131 | super(RDPerceptual, self).__init__(input_shapes=input_shapes, class_num=class_num) 132 | 
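        # Load the weights of a pre-trained RD classifier; all parameters are frozen
        # below, so the network only supplies fixed features for the perceptual loss.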
self.load_state_dict(torch.load(model_path)) 133 | 134 | for param in self.parameters(): 135 | param.requires_grad = False 136 | 137 | def forward(self, inputs): 138 | features = self.conv_layers(inputs) 139 | outputs = self.classifier_layer(features) 140 | return outputs, features 141 | 142 | 143 | class RACPIT(RDNet): 144 | """ 145 | RD classifier including image transformer 146 | """ 147 | def __init__(self, trans_path, model_path, input_shapes=None, class_num=5, train_classifier=False): 148 | super(RACPIT, self).__init__(input_shapes=input_shapes, class_num=class_num) 149 | self.load_state_dict(torch.load(model_path)) 150 | 151 | if len(input_shapes) > 1: 152 | self.transformer = MultiTransformNet(num_inputs=len(input_shapes), num_channels=1) 153 | else: 154 | self.transformer = ImageTransformNet(num_channels=1) 155 | self.transformer.load_state_dict(torch.load(trans_path)) 156 | 157 | frozen_params = self.transformer.parameters() if train_classifier else self.parameters() 158 | for param in frozen_params: 159 | param.requires_grad = False 160 | 161 | def forward(self, x): 162 | if len(x) == 1: 163 | x = x[0] 164 | x_trans = self.transformer(x) 165 | return super(RACPIT, self).forward(x_trans) 166 | -------------------------------------------------------------------------------- /EUSIPCO2022/README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Domain Adaptation across FMCW Radar Configurations Using Margin Disparity Discrepancy 2 | 3 | The content of this page serves as supplementary material to the presentation of our work at [EUSIPCO2022](https://2022.eusipco.org/) in Belgrade. [A preprint is also available][preprint]. 4 | 5 | [preprint]: https://arxiv.org/abs/2203.04588 6 | 7 | ## Radar ML 8 | 9 | Radar sensing has gained research interest over the past few years due to several 10 | factors. On the one hand, semiconductor companies have managed to produce highly 11 | integrated radar chipsets thanks to the 12 | [Frequency-modulated continuous 13 | wave (FMCW)](https://www.infineon.com/dgdl/Infineon-Radar%20FAQ-PI-v02_00-EN.pdf?fileId=5546d46266f85d6301671c76d2a00614) technology. An example of this is the 14 | [60 GHz sensor frontend](https://www.infineon.com/cms/en/product/promopages/60GHz/) 15 | developed by [Infineon](https://www.infineon.com/) 16 | that lies at the core of Google's [Soli Project](https://atap.google.com/soli/). 17 | 18 | ![Google's Soli Project](images/soli.gif) 19 | 20 | On the other hand, radar sensors enable interesting 21 | [IoT applications](https://www.infineon.com/cms/en/product/sensor/radar-sensors/radar-sensors-for-iot/), such as human activity and surveillance or hand gesture 22 | recognition, and it presents some advantages to cameras when 23 | it comes to privacy concerns and ill-posed optical scenarios, including 24 | through-wall and bad-lighting situations. 25 | 26 | All the data used in this work has been recorded with Infineon's 27 | [BGT60TR13C](https://www.infineon.com/cms/en/product/sensor/radar-sensors/radar-sensors-for-iot/60ghz-radar/bgt60tr13c/) 28 | radar sensors. 29 | 30 | ### Signal preprocessing 31 | 32 | Range and speed (also known as Doppler shift) of FMCW radar targets can 33 | be explored through so-called Range-Doppler maps, which convey information 34 | from the radar targets over both dimensions. 
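As a side note for readers unfamiliar with FMCW processing, the following minimal sketch shows how a Range-Doppler map is typically computed from one raw frame: a range FFT over the fast-time samples of each chirp, followed by a Doppler FFT over the chirps. It is not necessarily the exact processing chain used for the paper; the windowing and scaling choices are assumptions:

```python
import numpy as np

def range_doppler_map(frame):
    """frame: (n_chirps, n_samples) array of raw ADC samples of one FMCW frame."""
    frame = frame - frame.mean(axis=-1, keepdims=True)                    # remove per-chirp DC offset
    range_fft = np.fft.rfft(frame * np.hanning(frame.shape[-1]), axis=-1)  # range FFT (fast time)
    rdm = np.fft.fftshift(np.fft.fft(range_fft, axis=0), axes=0)           # Doppler FFT (slow time)
    return 20 * np.log10(np.abs(rdm) + 1e-10)                             # magnitude in dB

# e.g. 64 chirps x 256 samples, as in configuration I of the RACPIT repository
rdm = range_doppler_map(np.random.randn(64, 256))
```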
35 | 36 | Since a series of Range Doppler maps is 3-dimensional and thus requires large neural networks to work with, 37 | we extract so-called range and Doppler spectrograms by summing over 38 | the corresponding axis. This allows us to train a light neural network that takes both range and 39 | Doppler dynamic signatures and extracts the most relevant features from 40 | them through two separate branches convolutional layers. 41 | 42 | ![Extraction of range and Doppler spectrograms from Range-Doppler maps](images/spectrograms.gif) 43 | 44 | ![CNN architecture details](images/cnn.svg) 45 | 46 | ## Domain adaptation 47 | 48 | The complexity of the classification of radar signatures has driven radar approaches to 49 | resort to machine learning and deep learning techniques that require 50 | high amounts of data. This poses a challenge to radar, especially 51 | when aspects such as the sensor's configuration, the environment or user 52 | settings have a large diversity. Each of these aspects may indeed lead to changes in the data domain in the sense of **domain adaptation theory**. 53 | 54 | Domain adaptation techniques consider the case where the underlying probability distribution of data differs 55 | between a **source domain**, from which we can draw sufficient training 56 | data, and a **target domain** that we ultimately aim for deployment. 57 | 58 | Such a framework allows an easier deployment from pre-trained models, 59 | where only few labeled data from the target domain is required for fine-tuning. 60 | This is known as **supervised domain adaptation** and we 61 | already investigated for radar in [[3]](#ref3). For this paper, we have 62 | moved on to the more demanding situation where the target data is 63 | provided without labels from the label space, also called **unsupervised 64 | domain adaptation**. 65 | 66 | ### Margin Disparity Discrepancy 67 | 68 | The unsupervised domain adaptation 69 | method we have used is called Margin Disparity Discrepancy (MDD) and it has been recently 70 | developed by Zhang et al. [[1]](#ref1) based on theoretical guarantees and tested on 71 | computer vision datasets. The authors have proved that the true error in the 72 | target domain can be bounded by the empirical error in the source domain 73 | plus a residual ideal loss and the so-called MDD term. 74 | 75 | As a small **theoretical contribution to MDD**, we have managed to upper 76 | bound the empirical source error from the original, which originally 77 | uses the ramp loss, by its cross-entropy counterpart. 78 | This is important since the ramp loss is non-convex und thus unpractical for ML training. 79 | In fact, the authors of [[1]](#ref1) perform their experiments with the cross-entropy loss. 80 | 81 | The full proof for the cross-entropy bound is in [our paper][preprint]. As a visual intuition, 82 | though, one can just regard the cross-entropy as a smooth 83 | version of hinge loss, which in turn is nothing more 84 | than a convex relaxation of the ramp loss. 85 | 86 | ![Different loss functions](images/losses.svg) 87 | 88 | ## References 89 | 90 | 1. Zhang, Y., Liu, T., Long, M. and Jordan, M. (2019) 'Bridging Theory 91 | and Algorithm for Domain Adaptation', in *International Conference 92 | on Machine Learning*. *International Conference on Machine 93 | Learning*, PMLR, pp. 7404--7413. Available at: . 94 | 2. Liang, X., Wang, X., Lei, Z., Liao, S. and Li, S.Z. (2017) 95 | 'Soft-Margin Softmax for Deep Classification', in D. Liu, S. Xie, Y. 96 | Li, D. Zhao, and E.-S.M. 
El-Alfy (eds) *Neural Information 97 | Processing*. Cham: Springer International Publishing (Lecture Notes 98 | in Computer Science), pp. 413--421. Available at: 99 | . 100 | 3. Khodabakhshandeh, H., Visentin, T., Hernangómez, R. and 101 | Pütz, M. (2021) 'Domain Adaptation Across Configurations of FMCW 102 | Radar for Deep Learning Based Human Activity Classification', in 103 | *2021 21st International Radar Symposium (IRS)*. *2021 21st 104 | International Radar Symposium (IRS)*, Berlin, Germany, pp. 1--10. Available at: . 105 | 4. Hernangómez, R., Bjelakovic, I., Servadei, L. and 106 | Stańczak, S. (2022) 'Unsupervised Domain Adaptation across FMCW 107 | Radar Configurations Using Margin Disparity Discrepancy', in *2022 108 | 30th European Signal Processing Conference (EUSIPCO)*. Belgrade, 109 | Serbia. Available at: . 110 | 5. Hernangómez, R., Santra, A. and Stańczak, S. (2019) 'Human Activity 111 | Classification with Frequency Modulated Continuous Wave Radar Using 112 | Deep Convolutional Neural Networks', in *2019 International Radar 113 | Conference (RADAR)*. *2019 International Radar Conference (RADAR)*, 114 | Toulon, France: IEEE, pp. 1--6. Available at: 115 | . 116 | 6. Hernangómez, R., Santra, A. and Stańczak, S. (2021) 'Study on 117 | feature processing schemes for deep-learning-based human activity 118 | classification using frequency-modulated continuous-wave radar', 119 | *IET Radar, Sonar & Navigation*, 15(8), pp. 932--944. Available at: 120 | . 121 | 7. Lien, J., Gillian, N., Karagozler, M.E., Amihood, P., Schwesig, C., 122 | Olson, E., Raja, H. and Poupyrev, I. (2016) 'Soli: ubiquitous 123 | gesture sensing with millimeter wave radar', *ACM Transactions on 124 | Graphics*, 35(4), p. 142:1-142:19. Available at: 125 | . 126 | 8. Motiian, S., Jones, Q., Iranmanesh, S. and Doretto, G. (2017) 127 | 'Few-Shot Adversarial Domain Adaptation', *Advances in Neural 128 | Information Processing Systems*, 30. Available at: 129 | . 130 | 9. Santra, A. and Hazra, S. (2020) *Deep learning applications of 131 | short-range radars*. Artech House. 132 | -------------------------------------------------------------------------------- /utils/provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from collections.abc import Sequence 4 | 5 | from utils import preprocess 6 | from utils.slicer import train_test_slice 7 | from utils.radar import affine_transform, normalize_db as norm_db 8 | from utils.visualization import spec_plot 9 | 10 | import torch.utils.data as util_data 11 | 12 | 13 | class _TimeInfo(list): 14 | def __init__(self, dataset, time_labels=("date", "offset"), dim="batch_index", drop_attrs=True): 15 | self._ds = dataset[list(time_labels)] 16 | if drop_attrs: 17 | self._ds.attrs = {} 18 | self.iter_dim = dim 19 | super(_TimeInfo, self).__init__() 20 | 21 | def __getitem__(self, index): 22 | return self._ds[{self.iter_dim: index}] 23 | 24 | def __setitem__(self, key, value): 25 | raise TypeError("Assignment operation not permitted for TimeInfo") 26 | 27 | def __len__(self): 28 | return self._ds.sizes[self.iter_dim] 29 | 30 | def __str__(self): 31 | return self._ds.__str__() 32 | 33 | 34 | class RadarDataset(util_data.Dataset): 35 | """ 36 | Base class to use as a dataset. 
37 | Sequence superclass is straightforwardly exchangeable with the appropriate Tensorflow/Pytorch superclass 38 | """ 39 | def __init__(self, recordings, slices=None, normalize_db=True, clip=(-40, 0), add_channel_dim=-3, 40 | dim="batch_index", class_attr="activities", label_var="label", ignore_dims=None, 41 | norm_values=None, feature_dtype=np.single, label_dtype=np.longlong): 42 | 43 | self.label = label_var 44 | self.iter_dim = dim 45 | self.class_attr = class_attr 46 | 47 | self.ftype = feature_dtype 48 | self.ltype = label_dtype 49 | 50 | process_funcs = [] 51 | if normalize_db: 52 | preprocess.proc_register(process_funcs, norm_db, axis=(-1, -2)) 53 | if clip is not None: 54 | preprocess.proc_register(process_funcs, np.clip, clip[0], clip[1]) 55 | if norm_values is not None: # Change the range of the data from `clip` to `norm_values` 56 | a = (norm_values[1] - norm_values[0])/(clip[1] - clip[0]) 57 | b = norm_values[0] - a * clip[0] 58 | preprocess.proc_register(process_funcs, affine_transform, a, b) 59 | self._process_funcs = process_funcs 60 | 61 | if ignore_dims is True: 62 | self.ignore_dims = tuple(c for c in recordings[0].coords) 63 | else: 64 | self.ignore_dims = ignore_dims 65 | self._slices = slices 66 | self.channel_dim = add_channel_dim 67 | self._recordings = recordings 68 | self.attrs = dict_intersection([r.attrs for r in recordings]) 69 | self.total_bytes = preprocess.get_size(recordings, loaded_only=False) 70 | 71 | self.dataset = self.Iterable(self) 72 | 73 | super(RadarDataset, self).__init__() 74 | 75 | def _access_item(self, index): 76 | if isinstance(index, int): 77 | item = preprocess.slice_recording(self._recordings, self._slices[index]) 78 | elif isinstance(index, slice): 79 | item = preprocess.recs2dataset(self._recordings, self._slices[index], ignore_dims=self.ignore_dims) 80 | elif isinstance(index, np.ndarray): 81 | item = preprocess.recs2dataset(self._recordings, [self._slices[i] for i in index], 82 | ignore_dims=self.ignore_dims) 83 | else: 84 | raise IndexError(f"Only ints, np.int arrays and slices are accepted, got {type(index)}") 85 | return item 86 | 87 | def __getitem__(self, index): 88 | item = self._access_item(index) 89 | labels = self._get_labels(item).astype(self.ltype).values 90 | features_ds = preprocess.apply_processing(self._get_features(item), self._process_funcs) 91 | if self.channel_dim is not None: 92 | features_ds = features_ds.expand_dims(dim="channel", axis=self.channel_dim) 93 | 94 | features = [f.astype(self.ftype).values for f in features_ds.values()] 95 | return features, labels 96 | 97 | def plot(self, index, axes=None, **kwargs): 98 | features = self.dataset[index] 99 | spec_plot(features, axes=axes, **kwargs) 100 | 101 | def __len__(self): 102 | return len(self._slices) 103 | 104 | def _get_features(self, dataset): 105 | return dataset.drop(self.label) 106 | 107 | def _get_labels(self, dataset): 108 | return dataset[self.label] 109 | 110 | @property 111 | def feature_shapes(self): 112 | features, _ = self.__getitem__(0) 113 | length = (self.__len__(),) 114 | shapes = [length + f.shape for f in features] 115 | return shapes 116 | 117 | def scope(self, unit): 118 | item = self._access_item(0) 119 | return preprocess.get_scope(item, unit) 120 | 121 | @property 122 | def loaded_bytes(self): 123 | return preprocess.get_size(self._recordings, loaded_only=True) 124 | 125 | @property 126 | def class_num(self): 127 | return len(self.attrs[self.class_attr]) 128 | 129 | @property 130 | def branches(self): 131 | return 
len(self.feature_shapes) 132 | 133 | class Iterable(Sequence): 134 | def __init__(self, radar_dataset): 135 | self._rds = radar_dataset 136 | super(RadarDataset.Iterable, self).__init__() 137 | 138 | def __len__(self): 139 | assert isinstance(self._rds, RadarDataset) 140 | return len(self._rds._slices) 141 | 142 | def __getitem__(self, index): 143 | assert isinstance(self._rds, RadarDataset) 144 | item = self._rds._access_item(index) 145 | ds = preprocess.apply_processing(self._rds._get_features(item), self._rds._process_funcs) 146 | return ds 147 | 148 | 149 | def load_rd(config, preprocessed_path, spec_length, stride, 150 | range_length=None, doppler_length=None, train_load=0.8, gpu=False, split=None): 151 | """ 152 | Split, slice and load preprocessed data as range/Doppler spectrograms to be used with Pytorch/Tensorflow 153 | :param config: Radar configuration to load, e.g. "E" 154 | :param preprocessed_path: Path to the preprocessed data 155 | :param spec_length: Length of the spectrograms in bins 156 | :param stride: Spectrogram stride or hop length for the slicing 157 | :param range_length: If provided, data will be cropped on the range axis from 0 to range_length 158 | :param doppler_length: If provided, data will be cropped on the doppler axis from 0 to doppler_length 159 | :param train_load: Target train load to perform train-test split 160 | :param gpu: If True, load the xarray.Datasets. This is normally desirable if training in the GPU cluster. 161 | :param split: Type of split. 162 | It can be None (no split), "deterministic", single", "double" (see train_test_slice) or a dictionary 163 | containing already split segments (useful for reproducibility) 164 | :return: RadarDataset objects with the data. Segments are also returned for reproducibility 165 | """ 166 | recordings = preprocess.open_recordings(config, preprocessed_path, load=gpu, 167 | range_length=range_length, doppler_length=doppler_length) 168 | 169 | if split is None: 170 | slices = train_test_slice(recordings, spec_length, stride, train_load, split=split) 171 | rds = RadarDataset(recordings, slices=slices) 172 | return rds 173 | elif isinstance(split, dict): 174 | slices = train_test_slice(recordings, spec_length, stride, train_load, 175 | split=split, return_segments=False) 176 | rd_datasets = [RadarDataset(recordings, slices=s) for s in slices] 177 | return rd_datasets 178 | else: 179 | slices = train_test_slice(recordings, spec_length, stride, train_load, 180 | split=split, return_segments=True) 181 | segments = slices.pop(-1) 182 | rd_datasets = [RadarDataset(recordings, slices=s) for s in slices] 183 | return rd_datasets, segments 184 | 185 | 186 | def dict_intersection(dictionaries): 187 | intersection = set() | dictionaries[0].items() 188 | for d in dictionaries: 189 | intersection &= d.items() 190 | return dict(intersection) 191 | -------------------------------------------------------------------------------- /utils/synthesize.py: -------------------------------------------------------------------------------- 1 | from utils import radar, skeletons as sk 2 | import numpy as np 3 | # import scipy.io 4 | 5 | 6 | def synthetic_radar(skeletons, config, frame_times=None): 7 | """ 8 | :param skeletons: skeleton data as provided by the Daq module 9 | :param config: the configuration of the radar 10 | :param frame_times: the start times of all frames in the format [f1,f2,...] 
11 | :return: Synthetic data in the format (nFrames, 1, nChirps, nSamples) 12 | """ 13 | 14 | radarLoc = [p[0] for p in config["position"]] 15 | timestamps = skeletons.index 16 | startTime = timestamps[0].total_seconds() 17 | endTime = timestamps[-1].total_seconds() 18 | 19 | name = config["cfg"] 20 | f_low = config["LowerFrequency"] * 1e6 21 | f_up = config["UpperFrequency"] * 1e6 22 | numChirps = config["ChirpsPerFrame"] 23 | numSamples = config["SamplesPerChirp"] 24 | framePeriod = config["FramePeriod"] * 1e-6 25 | chirpPeriod = config["ChirpToChirpTime"] * 1e-9 26 | samplePeriod = 1e-3 / config["AdcSamplerate"] 27 | 28 | print("Create synthetic data for configuration {}".format(name)) 29 | if frame_times is None: 30 | frame_times = np.arange(startTime, endTime, framePeriod) 31 | print("Creating {} frames of synthetic data with framePeriod {}".format(len(frame_times), framePeriod)) 32 | else: 33 | print("Creating {} frames of synthetic data from provided frameTimes".format(len(frame_times))) 34 | 35 | C = 299792458 # m / s 36 | 37 | numFrames = len(frame_times) 38 | 39 | # array containing the start time of all chirps in the format [frames, chirps] 40 | chirpTimes = np.linspace(frame_times, frame_times + (numChirps - 1) * chirpPeriod, numChirps).T 41 | 42 | # Interpolation to increase temporal resolution of the skeleton data 43 | # returns segment position at each chirp start time 44 | sk_data = sk.interpolate(skeletons, chirpTimes.flatten()) 45 | 46 | # if saveMatlab: 47 | # print("Save skeletons for matlab to location: {}".format(saveMatlab.absolute())) 48 | # scipy.io.savemat(saveMatlab, {'data': data, 49 | # 'timestamp': timestamps.to_numpy()}) 50 | 51 | # Height of the person 52 | heightNose = np.mean([sk_keypoints[2, 0] for sk_keypoints in skeletons.Data]) 53 | Height = heightNose + 0.16 54 | # body segments length (meter) 55 | headlen = 0.130 * Height 56 | shoulderlen = (0.259 / 2) * Height 57 | torsolen = 0.288 * Height 58 | hiplen = (0.191 / 2) * Height 59 | upperleglen = 0.245 * Height 60 | # lowerleglen = 0.246 * Height 61 | # footlen = 0.143 * Height 62 | upperarmlen = 0.188 * Height 63 | lowerarmlen = 0.152 * Height 64 | # Ht = upperleglen + lowerleglen 65 | 66 | # Get coordinates of the person 67 | head = (sk_data[:, :, 3] + sk_data[:, :, 4]) / 2 68 | 69 | neck = head.copy() 70 | neck[:, 2] -= 0.17 71 | 72 | base = (sk_data[:, :, 11] + sk_data[:, :, 12]) / 2 73 | base[:, 2] += 0.1 74 | 75 | lshoulder = sk_data[:, :, 5] 76 | lelbow = sk_data[:, :, 7] 77 | lhand = sk_data[:, :, 9] 78 | 79 | lhip = sk_data[:, :, 11] 80 | lknee = sk_data[:, :, 13] 81 | lankle = sk_data[:, :, 15] 82 | # ltoe = lankle.copy() 83 | 84 | rshoulder = sk_data[:, :, 6] 85 | relbow = sk_data[:, :, 8] 86 | rhand = sk_data[:, :, 10] 87 | 88 | rhip = sk_data[:, :, 12] 89 | rknee = sk_data[:, :, 14] 90 | rankle = sk_data[:, :, 16] 91 | # rtoe = rankle.copy() 92 | 93 | torso = (neck + base) / 2 94 | lupperarm = (lshoulder + lelbow) / 2 95 | rupperarm = (rshoulder + relbow) / 2 96 | lupperleg = (lhip + lknee) / 2 97 | rupperleg = (rhip + rknee) / 2 98 | llowerleg = (lankle + lknee) / 2 99 | rlowerleg = (rankle + rknee) / 2 100 | 101 | # Based on "A global human walking model with real-time kinematic personification", 102 | # by R. Boulic, N.M. Thalmann, and D. 
Thalmann % The Visual Computer, vol .6, pp .344 - 358, 1990 103 | segments = [ 104 | sk.Segment( 105 | segmentPositions=head, 106 | ellipsoidParams=[0.1, 0.1, headlen / 2], 107 | aspect=head - neck 108 | ), 109 | sk.Segment( 110 | segmentPositions=torso, 111 | ellipsoidParams=[0.15, 0.15, torsolen / 2], 112 | aspect=neck - base 113 | ), 114 | sk.Segment( 115 | segmentPositions=lshoulder, 116 | ellipsoidParams=[0.06, 0.06, shoulderlen / 2], 117 | aspect=lshoulder - neck 118 | ), 119 | sk.Segment( 120 | segmentPositions=rshoulder, 121 | ellipsoidParams=[0.06, 0.06, shoulderlen / 2], 122 | aspect=rshoulder - neck 123 | ), 124 | sk.Segment( 125 | segmentPositions=lupperarm, 126 | ellipsoidParams=[0.06, 0.06, upperarmlen / 2], 127 | aspect=lshoulder - lelbow 128 | ), 129 | sk.Segment( 130 | segmentPositions=rupperarm, 131 | ellipsoidParams=[0.06, 0.06, upperarmlen / 2], 132 | aspect=rshoulder - relbow 133 | ), 134 | sk.Segment( 135 | segmentPositions=lhand, 136 | ellipsoidParams=[0.05, 0.05, lowerarmlen / 2], 137 | aspect=lelbow - lhand 138 | ), 139 | sk.Segment( 140 | segmentPositions=rhand, 141 | ellipsoidParams=[0.05, 0.05, lowerarmlen / 2], 142 | aspect=relbow - rhand 143 | ), 144 | sk.Segment( 145 | segmentPositions=lhip, 146 | ellipsoidParams=[0.07, 0.07, hiplen / 2], 147 | aspect=lhip - base 148 | ), 149 | sk.Segment( 150 | segmentPositions=rhip, 151 | ellipsoidParams=[0.07, 0.07, hiplen / 2], 152 | aspect=rhip - base 153 | ), 154 | sk.Segment( 155 | segmentPositions=lupperleg, 156 | ellipsoidParams=[0.07, 0.07, upperleglen / 2], 157 | aspect=lknee - lhip 158 | ), 159 | sk.Segment( 160 | segmentPositions=rupperleg, 161 | ellipsoidParams=[0.07, 0.07, upperleglen / 2], 162 | aspect=rknee - rhip 163 | ), 164 | sk.Segment( 165 | segmentPositions=llowerleg, 166 | ellipsoidParams=[0.06, 0.06, upperleglen / 2], 167 | aspect=lankle - lknee 168 | ), 169 | sk.Segment( 170 | segmentPositions=rlowerleg, 171 | ellipsoidParams=[0.06, 0.06, upperleglen / 2], 172 | aspect=rankle - rknee 173 | ) 174 | ] 175 | 176 | # define timestamps of a single chirp 177 | t = np.linspace(0, (numSamples - 1) * samplePeriod, numSamples) 178 | 179 | bw = f_up - f_low 180 | alpha = bw / t[-1] 181 | f_c = f_low + (f_up - f_low) / 2 182 | s = np.zeros((numSamples, np.size(chirpTimes))) 183 | 184 | dr, r_max = radar.range_axis(bw, numSamples) 185 | 186 | for segment in segments: 187 | r_total = segment.calculateRange(radarLoc) 188 | phi, theta = segment.calculateAngles(radarLoc) 189 | rcs = segment.calculateRCS(phi, theta) 190 | amp = np.sqrt(rcs) 191 | 192 | # Formally it should not be squared, but squaring resembles the effect of the low pass filter 193 | fspl = (4 * np.pi * r_total * f_c / C) ** 2 194 | 195 | # calculates the signal for all chirps c and samples s in the format [[chirp1s1, chirp2s1, ..., chirpns1] 196 | # [chirp1s2, chirp2s2, ..., chripns2 ] 197 | # [ ... 
] 198 | # [chrip1sm, chrip2sm, ..., chripnsm]] 199 | s_segment = (amp / fspl) * np.cos(2 * np.pi * (2 * f_low * r_total / C + 200 | 2 * alpha * np.outer(t, r_total) / C - 201 | 2 * alpha * r_total.reshape(1, -1) ** 2 / C ** 2)) 202 | 203 | # print("amp shape: {}".format(amp.shape)) 204 | # print("np.outer shape: {}".format(np.outer(t, r_total).shape)) 205 | # print("r_total.reshape: {}".format(r_total.reshape(1,-1).shape)) 206 | 207 | # set whole chirp to zero if segment is out of sight 208 | s_segment[:, r_total > r_max] = 0 209 | s += s_segment 210 | 211 | s = s.T.reshape(numFrames, numChirps, numSamples) 212 | s = np.expand_dims(s, axis=1) 213 | 214 | print("Synthetic data with shape {} successfully created".format(s.shape)) 215 | return s 216 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import pandas as pd 4 | 5 | import matplotlib.pyplot as plt 6 | plt.rcParams['svg.fonttype'] = 'none' 7 | import matplotlib.ticker as mticker 8 | import matplotlib.animation as manimation 9 | 10 | from mpl_toolkits import mplot3d # noqa: F401 unused import 11 | 12 | from utils import skeletons as sk 13 | 14 | # Put all visualizations here 15 | 16 | 17 | def spec_plot(ds, axes=None, ax_xlabel=None, cbar_global=None, **kwargs): 18 | """ 19 | Plot radar spectrograms using xarray functionality 20 | :param ds: xarray.Dataset with spectrograms to plot 21 | :param axes: Matplotlib Axes to plot onto. If not given, a new Matplotlib Figure will be created 22 | :param axes: Matplotlib Axes where x label should be present. Default to the last Axes 23 | :param cbar_global: Title of a global colorbar 24 | :param kwargs: Keyword arguments for xarray.DataArray.plot 25 | :return: None 26 | """ 27 | ds_time = ds["time"] 28 | ds['time'] = tdelta2secs(ds_time) 29 | ds.time.attrs = {"long_name": "time offset", "units": "min:sec"} 30 | 31 | title = create_title(ds) 32 | 33 | num_features = len(ds) 34 | 35 | if axes is None: 36 | fig, axes = plt.subplots(num_features, 1) 37 | else: 38 | fig = axes[-1].figure 39 | 40 | if cbar_global: 41 | kwargs["add_colorbar"] = False 42 | 43 | if ax_xlabel is None: 44 | ax_xlabel = [axes[-1]] 45 | for ax, feat in zip(axes, ds.values()): 46 | plt.sca(ax) 47 | cbar_kwargs = {"label": f"Spectrogram [{feat.units}]"} 48 | try: 49 | if not kwargs["add_colorbar"]: 50 | cbar_kwargs = {} 51 | except KeyError: 52 | pass 53 | feat.plot.imshow(x="time", cbar_kwargs=cbar_kwargs, **kwargs) 54 | ax.xaxis.set_major_formatter(mticker.FuncFormatter(format_seconds)) 55 | if ax not in ax_xlabel: 56 | ax.set_xlabel(None) 57 | 58 | fig.suptitle(title, wrap=True) 59 | 60 | if cbar_global: 61 | im = axes[-1].get_images()[0] 62 | fig.subplots_adjust(right=0.8) 63 | cbar_ax = fig.add_axes([0.85, 0.15, 0.03, 0.7]) 64 | cbar = fig.colorbar(im, cax=cbar_ax) 65 | cbar.set_label(cbar_global) 66 | 67 | ds['time'] = ds_time 68 | 69 | 70 | def format_seconds(x, _pos): 71 | timestamp = "{:02d}:{:02d}".format(int(x // 60), int(x % 60)) 72 | rest = x % 1.0 73 | if rest != 0.0: 74 | return '' 75 | else: 76 | return timestamp 77 | 78 | 79 | def animation_real_synthetic(real_rd, synthetic_rd, skeletons=None, 80 | sensor_loc=None, notebook=False, save_path=None, max_z=2.5, **kwargs): 81 | """ 82 | Compare real and synthetic data with an animated plot 83 | :param real_rd: Real RDM sequence (in dB) as xarray.DataArray 84 | :param synthetic_rd: Synthetic RDM 
sequence (in dB) as xarray.DataArray. 85 | The time dimension must be consistent across all data arrays 86 | :param skeletons: Skeleton data as xarray.DataArray (optional) 87 | :param sensor_loc: Radar Sensor position as given by daq.env 88 | :param notebook: if True, return `anim_obj` to be used in a Jupyter notebook 89 | :param save_path: if given, save the video under this path 90 | :param max_z: Maximum height in meters 91 | :param kwargs: keyword arguments for RDM plotting. 92 | :return: `anim_obj` if `notebook=True` 93 | """ 94 | fig = plt.figure(figsize=(10, 5)) 95 | 96 | sk_lines = False 97 | if skeletons is None: 98 | subplot_real = 121 99 | subplot_synth = 122 100 | else: 101 | subplot_real = 222 102 | subplot_synth = 224 103 | 104 | ax_sk = fig.add_subplot(121, projection='3d') 105 | room_size = skeletons.room_size + [max_z] 106 | sk_edges = sk.get_edges(skeletons) 107 | sk_lines = plt_skeleton(sk_edges[0, ], room_size=room_size, axis=ax_sk, sensor_loc=sensor_loc) 108 | 109 | if "vmin" not in kwargs: 110 | kwargs["vmin"] = min(real_rd.min(), synthetic_rd.min()) 111 | if "vmax" not in kwargs: 112 | kwargs["vmax"] = max(real_rd.max(), synthetic_rd.max()) 113 | ax_real = fig.add_subplot(subplot_real) 114 | im_real = real_rd.isel(time=0).plot.imshow(add_colorbar=False, **kwargs) 115 | if skeletons is not None: 116 | ax_real.set_xticklabels([]) 117 | ax_real.set_xlabel(None) 118 | plt.title("Real Data") 119 | ax_synth = fig.add_subplot(subplot_synth) 120 | im_synth = synthetic_rd.isel(time=0).plot.imshow(add_colorbar=False, **kwargs) 121 | plt.title("Synthetic Data") 122 | 123 | ts_format = "%M:%S.%f" 124 | timestamps = pd.to_timedelta(real_rd.time.values) + pd.Timestamp(0) 125 | frame_period_ms = real_rd.FramePeriod // 1000 126 | num_frames = len(timestamps) 127 | 128 | title = create_title(real_rd) 129 | 130 | def suptitle(timestamp): 131 | return f"{title}+{timestamp.strftime(ts_format)[:-3]}" 132 | 133 | fig.suptitle(suptitle(timestamps[0]), wrap=True) 134 | 135 | cbar_orient = 'horizontal' if skeletons is None else 'vertical' 136 | cbar = fig.colorbar(im_synth, ax=(ax_real, ax_synth), orientation=cbar_orient) 137 | cbar.set_label('Amplitude [dB]') 138 | 139 | # animation function. This is called sequentially 140 | def animate(i): 141 | im_real.set_array(real_rd.isel(time=i)) 142 | im_synth.set_array(synthetic_rd.isel(time=i)) 143 | if sk_lines: 144 | update_skeleton(sk_lines, sk_edges[i, ]) 145 | fig.suptitle(suptitle(timestamps[i])) 146 | return [im_real, im_synth] 147 | 148 | anim_obj = manimation.FuncAnimation( 149 | fig, 150 | # The function that does the updating of the Figure 151 | animate, 152 | frames=num_frames, 153 | # Frame-time in ms 154 | interval=frame_period_ms, 155 | blit=True 156 | ) 157 | if save_path is not None: 158 | anim_obj.save(save_path) 159 | elif notebook: 160 | return anim_obj 161 | else: 162 | plt.show() 163 | 164 | 165 | def animation(real_rd, skeletons=None, 166 | sensor_loc=None, notebook=False, save_path=None, max_z=2.5, **kwargs): 167 | """ 168 | Plot range doppler maps 169 | :param real_rd: Real RDM sequence (in dB) as xarray.DataArray 170 | :param skeletons: Skeleton data as xarray.DataArray (optional) 171 | :param sensor_loc: Radar Sensor position as given by daq.env 172 | :param notebook: if True, return `anim_obj` to be used in a Jupyter notebook 173 | :param save_path: if given, save the video under this path 174 | :param max_z: Maximum height in meters 175 | :param kwargs: keyword arguments for RDM plotting. 
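    Example (illustrative sketch only; assumes `rdm_db` is a dB-scaled RDM sequence as xarray.DataArray,
    analogous to how the other animation helpers are used in the visualization notebook)::

        from IPython.display import HTML
        anim = animation(rdm_db, notebook=True)
        HTML(anim.to_html5_video())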
176 | :return: `anim_obj` if `notebook=True` 177 | """ 178 | fig = plt.figure(figsize=(11, 5)) 179 | 180 | sk_lines = False 181 | if skeletons is None: 182 | subplot_real = 111 183 | else: 184 | subplot_real = 122 185 | 186 | ax_sk = fig.add_subplot(121, projection='3d') 187 | room_size = skeletons.room_size + [max_z] 188 | sk_edges = sk.get_edges(skeletons) 189 | sk_lines = plt_skeleton(sk_edges[0, ], room_size=room_size, axis=ax_sk, sensor_loc=sensor_loc) 190 | 191 | if "vmin" not in kwargs: 192 | kwargs["vmin"] = real_rd.min() 193 | if "vmax" not in kwargs: 194 | kwargs["vmax"] = real_rd.max() 195 | ax_real = fig.add_subplot(subplot_real) 196 | im_real = real_rd.isel(time=0).plot.imshow(add_colorbar=False, **kwargs) 197 | plt.title("Range Doppler Map") 198 | 199 | ts_format = "%M:%S.%f" 200 | timestamps = pd.to_timedelta(real_rd.time.values) + pd.Timestamp(0) 201 | frame_period_ms = real_rd.FramePeriod // 1000 202 | num_frames = len(timestamps) 203 | 204 | title = create_title(real_rd) 205 | 206 | def suptitle(timestamp): 207 | return f"{title}+{timestamp.strftime(ts_format)[:-3]}" 208 | 209 | fig.suptitle(suptitle(timestamps[0]), wrap=True) 210 | 211 | cbar = fig.colorbar(im_real, ax=ax_real, orientation='vertical') 212 | cbar.set_label('Amplitude [dB]') 213 | 214 | # animation function. This is called sequentially 215 | def animate(i): 216 | im_real.set_array(real_rd.isel(time=i)) 217 | if sk_lines: 218 | update_skeleton(sk_lines, sk_edges[i, ]) 219 | fig.suptitle(suptitle(timestamps[i])) 220 | return [im_real] 221 | 222 | anim_obj = manimation.FuncAnimation( 223 | fig, 224 | # The function that does the updating of the Figure 225 | animate, 226 | frames=num_frames, 227 | # Frame-time in ms 228 | interval=frame_period_ms, 229 | blit=True 230 | ) 231 | if save_path is not None: 232 | anim_obj.save(save_path) 233 | elif notebook: 234 | return anim_obj 235 | else: 236 | plt.show() 237 | 238 | 239 | def animation_spec(rdm, spectrograms, gap=False, marker_color='white', marker_style="d", 240 | notebook=False, save_path=None, **kwargs): 241 | """ 242 | Animation to show how spectrograms are extracted from the RDM sequence 243 | :param rdm: Range Doppler Map sequence as xarray.DataArray 244 | :param spectrograms: Range spectrograms as xarray.DataSet, extracted from `rdm` 245 | :param gap: If True, leave a gap between RDM and spectrograms 246 | :param marker_color: Color of the time marker on the spectrograms 247 | :param marker_style: Style of the marker tips on the spectrograms 248 | :param notebook: if True, return `anim_obj` to be used in a Jupyter notebook 249 | :param save_path: if given, save the video under this path 250 | :param kwargs: keyword arguments for RDM plotting. 
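    Example (sketch mirroring the visualization notebook; assumes `rdm_real_db` holds the dB-scaled RDM
    sequence and `real_spects` the spectrograms extracted from it)::

        from IPython.display import HTML
        anim = animation_spec(rdm_real_db, real_spects, vmin=-40, notebook=True)
        HTML(anim.to_html5_video())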
251 | :return: `anim_obj` if `notebook=True` 252 | """ 253 | 254 | if gap: 255 | fig = plt.figure(figsize=(16, 5)) 256 | subplot_rdm = 131 257 | subplot_range = 233 258 | subplot_doppler = 236 259 | else: 260 | fig = plt.figure(figsize=(10, 5)) 261 | subplot_rdm = 121 262 | subplot_range = 222 263 | subplot_doppler = 224 264 | 265 | ts_format = "%M:%S.%f" 266 | timestamps = pd.to_timedelta(rdm.time.values) + pd.Timestamp(0) 267 | frame_period_ms = rdm.FramePeriod // 1000 268 | num_frames = len(timestamps) 269 | 270 | time_spect = tdelta2secs(spectrograms.time) 271 | rng = spectrograms.range.values 272 | dopp = spectrograms.doppler.values 273 | 274 | def set_rdm_title(timestamp): 275 | return f"{rdm_title}at time offset {timestamp.strftime(ts_format)[:-3]}" 276 | 277 | if "vmin" not in kwargs: 278 | kwargs["vmin"] = min(rdm.min(), min(s.min() for s in spectrograms.values())) 279 | if "vmax" not in kwargs: 280 | kwargs["vmax"] = max(rdm.max(), max(s.max() for s in spectrograms.values())) 281 | ax_rdm = fig.add_subplot(subplot_rdm) 282 | im_rdm = rdm.isel(time=0).plot.imshow(add_colorbar=False, **kwargs) 283 | 284 | rdm_title = "Range Doppler Map\n" if gap else "RDM " 285 | plt.title(set_rdm_title(timestamps[0])) 286 | 287 | ax_range = fig.add_subplot(subplot_range) 288 | ax_doppler = fig.add_subplot(subplot_doppler) 289 | 290 | spec_plot(spectrograms, axes=[ax_range, ax_doppler], add_colorbar=False, **kwargs) 291 | ax_range.set_title("Radar Spectrograms") 292 | fig.align_ylabels([ax_range, ax_doppler]) 293 | plt.sca(ax_range) 294 | r_marker, = plt.plot([time_spect[0], time_spect[0]], [rng[0], rng[-1]], c=marker_color, marker=marker_style) 295 | d_marker, = ax_doppler.plot([time_spect[0], time_spect[0]], 296 | [dopp[0], dopp[-1]], c=marker_color, marker=marker_style) 297 | 298 | fig.suptitle(create_title(rdm), wrap=True) 299 | 300 | if gap: 301 | cbar_rdm = fig.colorbar(im_rdm, ax=ax_rdm, orientation='vertical') 302 | cbar_rdm.set_label('Amplitude [dB]') 303 | cbar_spect = fig.colorbar(ax_range.get_images()[0], ax=(ax_range, ax_doppler), orientation='vertical') 304 | cbar_spect.set_label('Amplitude [dB]') 305 | 306 | # animation function. 
This is called sequentially 307 | def animate(i): 308 | im_rdm.set_array(rdm.isel(time=i)) 309 | ax_rdm.set_title(set_rdm_title(timestamps[i])) 310 | t_marker = [time_spect[i]] * 2 311 | r_marker.set_xdata(t_marker) 312 | d_marker.set_xdata(t_marker) 313 | return [im_rdm, r_marker, d_marker] 314 | 315 | anim_obj = manimation.FuncAnimation( 316 | fig, 317 | # The function that does the updating of the Figure 318 | animate, 319 | frames=num_frames, 320 | # Frame-time in ms 321 | interval=frame_period_ms, 322 | blit=True 323 | ) 324 | if save_path is not None: 325 | anim_obj.save(save_path) 326 | elif notebook: 327 | return anim_obj 328 | else: 329 | plt.show() 330 | 331 | 332 | def plt_skeleton(sk_edges, room_size=None, axis=None, sensor_loc=None): 333 | if axis is None: 334 | fig = plt.figure() 335 | axis = fig.add_subplot(projection='3d') 336 | sk_lines = [] 337 | for edge in sk_edges: 338 | line, = axis.plot(*edge, color='black', marker='o') 339 | sk_lines.append(line) 340 | if sensor_loc is not None: 341 | sensor_x, sensor_y, sensor_z = (sl[0] for sl in sensor_loc) 342 | axis.scatter(sensor_x, sensor_y, sensor_z, color="red", marker="*") 343 | axis.text(sensor_x, sensor_y, sensor_z, "Radar") 344 | 345 | axis.set_xlabel('x') 346 | axis.set_ylabel('y') 347 | axis.set_zlabel('z') 348 | if room_size is not None: 349 | axis.set_xlim(0, room_size[0]) 350 | axis.set_ylim(0, room_size[1]) 351 | axis.set_zlim(0, room_size[2]) 352 | 353 | plt.title("Room reconstruction") 354 | 355 | return sk_lines 356 | 357 | 358 | def update_skeleton(sk_lines, sk_edges): 359 | for line, edge in zip(sk_lines, sk_edges): 360 | line.set_xdata(edge[0]) 361 | line.set_ydata(edge[1]) 362 | line.set_3d_properties(edge[2]) 363 | 364 | 365 | def create_title(ds): 366 | act = ds.attrs['activity'] 367 | cfg = ds.attrs['cfg'] 368 | date = pd.to_datetime(ds.date).strftime("on %d-%m-%Y at %H:%M") 369 | title = f"{act} from confg. 
{cfg} {date}" 370 | return title 371 | 372 | 373 | def tdelta2secs(time_delta): 374 | return time_delta / np.timedelta64(1, 's') 375 | -------------------------------------------------------------------------------- /utils/slicer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.random import default_rng 3 | 4 | from utils.preprocess import recs2dataset 5 | 6 | rng = default_rng() 7 | 8 | 9 | def verbose_split(train_idx, test_idx): 10 | i = 0 11 | for tr, ts in zip(train_idx, test_idx): 12 | i += 1 13 | 14 | partial_load = sum(len(t) for t in tr) 15 | rec_len = partial_load + sum(len(t) for t in ts) 16 | print(f"Recording {i}:") 17 | print(f"\tPartial load:\t{partial_load / rec_len:02},\t{partial_load}/{rec_len} frames") 18 | 19 | total_load = sum(sum(len(t) for t in tr) for tr in train_idx) 20 | total_length = total_load + sum(sum(len(t) for t in ts) for ts in test_idx) 21 | print(f"Obtained load:\t{total_load / total_length:02},\t{total_load}/{total_length} frames") 22 | 23 | 24 | def train_split_no_cut(recording_lengths, train_load, labels, verbose=True): 25 | """ 26 | Perform test-train split of a list of recordings without cutting recordings, only assigning recordings 27 | to different sets in a label-balanced way 28 | Args: 29 | recording_lengths: List of int, each one of them represents the length of a recording file 30 | train_load: float, Desired train load in [0.5, 1) 31 | labels: List of int representing the label assigned to each recording to take into account for splitting 32 | verbose: Verbose output to verify the split 33 | 34 | Returns: 35 | train_idx, test_idx 36 | The indices for the train and test indices as two nested lists of range objects 37 | 38 | """ 39 | 40 | if not (0.5 <= train_load < 1): 41 | raise ValueError(f"Invalid value for train_load={train_load}.\n" 42 | f"It must lie within [0.5 , 1).") 43 | 44 | label_lengths = {} 45 | for rec_len, lbl in zip(recording_lengths, labels): 46 | try: 47 | label_lengths[lbl].append(rec_len) 48 | except KeyError: 49 | label_lengths[lbl] = [rec_len] 50 | total_loads = {k: round(sum(v) * train_load) for k, v in label_lengths.items()} 51 | 52 | train_idx = [] 53 | test_idx = [] 54 | acc_lengths = {k: 0 for k in total_loads} 55 | for r_len, lbl in zip(recording_lengths, labels): 56 | if acc_lengths[lbl] < total_loads[lbl]: 57 | train_idx.append([range(0, r_len)]) 58 | test_idx.append([range(0, 0)]) 59 | else: 60 | train_idx.append([range(0, 0)]) 61 | test_idx.append([range(0, r_len)]) 62 | acc_lengths[lbl] += r_len 63 | 64 | if verbose: 65 | print(f"Required load:\t{train_load}") 66 | verbose_split(train_idx, test_idx) 67 | 68 | return train_idx, test_idx 69 | 70 | 71 | def train_split_singlecut(recording_lengths, train_load, beta=10.0, verbose=True): 72 | """ 73 | Perform test-train split of a list of recordings through a single random cut on every recording 74 | Args: 75 | recording_lengths: List of int, each one of them represents the length of a recording file 76 | train_load: float, Desired train load in [0.5, 1) 77 | beta: Parameter of the beta distribution. The bigger, the less standard deviation over the random cuts. 
78 | If None, it divides every measurement deterministically according to train_load 79 | verbose: Verbose output to verify the split 80 | 81 | Returns: 82 | train_idx, test_idx 83 | The indices for the train and test indices as two nested lists of range objects 84 | 85 | """ 86 | 87 | if not (0.5 <= train_load < 1): 88 | raise ValueError(f"Invalid value for train_load={train_load}.\n" 89 | f"It must lie within [0.5 , 1).") 90 | 91 | if beta is None: 92 | deterministic = True 93 | pass 94 | elif beta < 1: 95 | raise ValueError(f"Invalid value for beta={beta}.\n" 96 | f"It must be equal or greater than one to enforce unimodality.") 97 | else: 98 | deterministic = False 99 | 100 | rec_lengths = np.array(recording_lengths) 101 | n_recs = rec_lengths.size 102 | total_length = rec_lengths.sum() 103 | 104 | # Use the longest recording to even out the obtained load 105 | i_max = rec_lengths.argmax() 106 | max_length = rec_lengths[i_max] 107 | max_out_idx = np.arange(n_recs) != i_max 108 | 109 | total_load = np.round(total_length * train_load) 110 | 111 | max_out_loads = None 112 | max_load = -1 113 | if deterministic: 114 | max_out_loads = np.round(train_load * rec_lengths[max_out_idx]) 115 | max_load = total_load - max_out_loads.sum() 116 | else: # Use a beta distribution to generate random loads with rejection 117 | while not 0 < max_load < max_length: 118 | alpha = beta * train_load / (1 - train_load) 119 | beta_loads = rng.beta(alpha, beta, size=n_recs - 1) 120 | max_out_loads = np.round(beta_loads * rec_lengths[max_out_idx]) 121 | max_load = total_load - max_out_loads.sum() 122 | 123 | partial_loads = np.zeros(n_recs, dtype=np.int32) 124 | partial_loads[max_out_idx] = max_out_loads 125 | partial_loads[i_max] = max_load 126 | 127 | head_tail = rng.binomial(1, 0.5, size=n_recs) # Choose randomly the head or the tail of the chunk 128 | 129 | train_idx = [] 130 | test_idx = [] 131 | 132 | for head_train, partial_load, rec_len in zip(head_tail, partial_loads, rec_lengths): 133 | if head_train: 134 | train_idx.append([range(0, partial_load)]) 135 | test_idx.append([range(partial_load, rec_len)]) 136 | else: 137 | train_idx.append([range(rec_len - partial_load, rec_len)]) 138 | test_idx.append([range(0, rec_len - partial_load)]) 139 | 140 | if verbose: 141 | print(f"Longest recording is {i_max + 1}, {rec_lengths[i_max]} frames") 142 | print(f"Required load:\t{train_load}") 143 | verbose_split(train_idx, test_idx) 144 | 145 | return train_idx, test_idx 146 | 147 | 148 | def train_split_doublecut(recording_lengths, train_load, min_len=1, stride=1, verbose=True): 149 | """ 150 | Perform test-train split of a list of recordings through a double cut of fixed length 151 | at a random position of every recording 152 | Args: 153 | recording_lengths: List of int, each one of them represents the length of a recording file 154 | train_load: float, Desired train load in [0.5, 1) 155 | min_len: int, Minimum length of a chunk, important if the data will be further sliced 156 | stride: int, Stride to consider when choosing the random position of the cuts 157 | verbose: Verbose output to verify the split 158 | 159 | Returns: 160 | train_idx, test_idx 161 | The indices for the train and test indices as two nested lists of range objects 162 | 163 | """ 164 | 165 | if not (0.5 <= train_load < 1): 166 | raise ValueError(f"Invalid value for train_load={train_load}.\n" 167 | f"It must lie within [0.5 , 1).") 168 | 169 | rec_lengths = np.array(recording_lengths) 170 | n_recs = rec_lengths.size 171 | total_length = 
rec_lengths.sum() 172 | 173 | # Use the longest recording to even out the obtained load 174 | i_max = rec_lengths.argmax() 175 | max_out_idx = np.arange(n_recs) != i_max 176 | 177 | total_load = np.round(total_length * train_load) 178 | 179 | max_out_loads = np.round(train_load * rec_lengths[max_out_idx]) 180 | max_load = total_load - max_out_loads.sum() 181 | 182 | partial_loads = np.zeros(n_recs, dtype=np.int32) 183 | partial_loads[max_out_idx] = max_out_loads 184 | partial_loads[i_max] = max_load 185 | 186 | train_idx = [] 187 | test_idx = [] 188 | 189 | for partial_load, rec_len in zip(partial_loads, rec_lengths): 190 | 191 | trimmed_length = rec_len - partial_load 192 | 193 | offset_choices = [0, trimmed_length] # head and tail offsets 194 | offset_choices.extend(range(min_len, trimmed_length - min_len, stride)) # centered train chunk 195 | offset_choices.extend(range(trimmed_length + min_len, rec_len - min_len, stride)) # centered test chunk 196 | 197 | offset = rng.choice(offset_choices) 198 | second_cut = (offset + partial_load) % rec_len 199 | 200 | if offset == 0: # head train chunk 201 | train_idx.append([range(offset, partial_load)]) 202 | test_idx.append([range(partial_load, rec_len)]) 203 | elif offset == trimmed_length: # tail train chunk 204 | train_idx.append([range(offset, rec_len)]) 205 | test_idx.append([range(0, offset)]) 206 | elif offset < second_cut: # centered train chunk 207 | train_idx.append([range(offset, second_cut)]) 208 | test_idx.append([range(0, offset), 209 | range(second_cut, rec_len)]) 210 | else: # centered test chunk 211 | train_idx.append([range(0, second_cut), 212 | range(offset, rec_len)]) 213 | test_idx.append([range(second_cut, offset)]) 214 | 215 | if verbose: 216 | print(f"Longest recording is {i_max + 1}, {rec_lengths[i_max]} frames") 217 | print(f"Required load:\t{train_load}") 218 | verbose_split(train_idx, test_idx) 219 | 220 | return train_idx, test_idx 221 | 222 | 223 | def serialize_ranges(ranges): 224 | if isinstance(ranges, range): 225 | return {"start": ranges.start, "stop": ranges.stop} 226 | elif isinstance(ranges, (tuple, list)): 227 | return [serialize_ranges(r) for r in ranges] 228 | else: 229 | raise TypeError(f"Unexpected class {type(ranges)}") 230 | 231 | 232 | def deserialize_ranges(ranges): 233 | if isinstance(ranges, dict): 234 | return range(ranges["start"], ranges["stop"]) 235 | elif isinstance(ranges, (tuple, list)): 236 | return [deserialize_ranges(r) for r in ranges] 237 | else: 238 | raise TypeError(f"Unexpected class {type(ranges)}") 239 | 240 | 241 | def index_ranges(ranges, recordings): 242 | lut_date = {r.date: i for i, r in enumerate(recordings)} 243 | new_ranges = [] 244 | for date, r in ranges.items(): 245 | new_ranges.append((lut_date[date], deserialize_ranges(r))) 246 | return new_ranges 247 | 248 | 249 | def train_test_slice(recordings, time_length, time_hop, train_load, split="single", beta=20.0, 250 | verbose=True, merge=False, return_segments=False): 251 | """Slice recordings into train and test datasets 252 | 253 | Args: 254 | recordings: Sequence of Dataset objects holding the preprocessed recordings 255 | time_length: Length of the data examples along the time axis 256 | time_hop: Hop size between the slices across the time axis 257 | train_load: float, Desired train load in [0.5, 1) 258 | return_segments: Return serialized segments 259 | verbose: Verbose output to verify the split 260 | merge: If True, merge slices into a new lazy dataset, otherwise return a flat indexed slice list 261 | split: Split 
variant, either "deterministic", single" or "double" or None for no train-test split 262 | beta: float, beta value for the random single split 263 | 264 | Returns: train_idx, test_idx 265 | Train and test Dataset objects or slice lists, depending on 'merge' 266 | 267 | """ 268 | 269 | rec_lens = [r.sizes["time"] for r in recordings] 270 | rec_labels = [r.label.item() for r in recordings] 271 | 272 | def slice_indices(indices): 273 | sliced_indices = [] 274 | for i, rec_segments in enumerate(indices): 275 | for range_seg in rec_segments: 276 | for offset in range(0, len(range_seg) - time_length + 1, time_hop): 277 | sliced_indices.append((i, range_seg[offset:offset+time_length])) 278 | return sliced_indices 279 | 280 | train_splits = {"deterministic": lambda rcl, tld: train_split_singlecut(rcl, tld, beta=None, verbose=verbose), 281 | "no-cut": lambda rcl, tld: train_split_no_cut(rcl, tld, labels=rec_labels, verbose=verbose), 282 | "single": lambda rcl, tld: train_split_singlecut(rcl, tld, beta=beta, verbose=verbose), 283 | "double": lambda rcl, tld: train_split_doublecut(rcl, tld, min_len=time_length, 284 | stride=time_hop, verbose=verbose)} 285 | 286 | if split is None: 287 | full_idx = [[range(rl)] for rl in rec_lens] 288 | full_slices = slice_indices(full_idx) 289 | if merge: 290 | return recs2dataset(recordings, full_slices) 291 | else: 292 | return full_slices 293 | else: 294 | if isinstance(split, dict): 295 | return_segments = False 296 | rec_dates = [r.date for r in recordings] 297 | train_split = split['train'] 298 | test_split = split['test'] 299 | 300 | train_idx = [deserialize_ranges(train_split[d]) for d in rec_dates] 301 | test_idx = [deserialize_ranges(test_split[d]) for d in rec_dates] 302 | else: 303 | try: 304 | train_idx, test_idx = train_splits[split](rec_lens, train_load) 305 | except KeyError as ke: 306 | ts_keys = tuple(train_splits.keys()) 307 | raise ValueError(f"Unrecognized split argument '{split}'." 308 | f" Argument must belong to {ts_keys}") from ke 309 | 310 | train_sliced = slice_indices(train_idx) 311 | test_sliced = slice_indices(test_idx) 312 | 313 | if merge: 314 | train_ds = recs2dataset(recordings, train_sliced) 315 | test_ds = recs2dataset(recordings, test_sliced) 316 | ret = [train_ds, test_ds] 317 | else: 318 | ret = [train_sliced, test_sliced] 319 | 320 | if return_segments: 321 | serial_segments = {"train": {r.date: s for r, s in zip(recordings, serialize_ranges(train_idx))}, 322 | "test": {r.date: s for r, s in zip(recordings, serialize_ranges(test_idx))}} 323 | ret.append(serial_segments) 324 | 325 | return ret 326 | -------------------------------------------------------------------------------- /visualize/notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "581bcae7-0e04-478c-8d92-ea4f253b86bc", 6 | "metadata": { 7 | "pycharm": { 8 | "name": "#%% md\n" 9 | }, 10 | "tags": [] 11 | }, 12 | "source": [ 13 | "# RACPIT visualization" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "id": "b2f5d92a-8f06-458e-9f98-f362d750e100", 19 | "metadata": { 20 | "pycharm": { 21 | "name": "#%% md\n" 22 | }, 23 | "tags": [] 24 | }, 25 | "source": [ 26 | "![logo](../images/logo128.png) This notebook demonstrates our simulation, preprocessing and visualization pipeline for radar data.\n", 27 | "See it rendered\n", 28 | "[here](https://fraunhoferhhi.github.io/racpit/visualize)." 
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "fe5aa707", 35 | "metadata": { 36 | "collapsed": false, 37 | "jupyter": { 38 | "outputs_hidden": false 39 | }, 40 | "pycharm": { 41 | "name": "#%%\n" 42 | } 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "import os\n", 47 | "os.chdir(\"..\")" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "cca2b70b", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "from pathlib import Path\n", 58 | "import pandas as pd\n", 59 | "\n", 60 | "import numpy as np\n", 61 | "import xarray as xr\n", 62 | "\n", 63 | "from utils.synthesize import synthetic_radar\n", 64 | "from utils.preprocess import open_recordings, identify_config, raw2rdm\n", 65 | "\n", 66 | "from utils import radar\n", 67 | "from utils import skeletons as sk\n", 68 | "\n", 69 | "from utils.visualization import spec_plot, animation_real_synthetic, animation_spec\n", 70 | "\n", 71 | "from ifxaion.daq import Daq\n", 72 | "\n", 73 | "import matplotlib.pyplot as plt\n", 74 | "from IPython.display import HTML\n", 75 | "\n", 76 | "from networks.img_transf import MultiTransformNet\n", 77 | "\n", 78 | "import torch\n", 79 | "from torch.autograd import Variable" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "5c85fb6b-358a-4850-9968-6c20eab5e09b", 86 | "metadata": { 87 | "tags": [] 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "data_path = Path(\"/mnt/infineon-radar\")\n", 92 | "\n", 93 | "raw_dir = data_path / \"daq_x-har\"\n", 94 | "activity = \"5_Walking_Boxing\"\n", 95 | "path = raw_dir / f\"{activity}_converted/recording-2020-01-28_12-37-12\"\n", 96 | "\n", 97 | "real_path = data_path / \"preprocessed/fixed_size/real\"\n", 98 | "synth_path = data_path / \"preprocessed/fixed_size/synthetic\"\n", 99 | "\n", 100 | "itn_config = \"D\"\n", 101 | "itn_recording = \"2020-02-05T15:16:08\"\n", 102 | "itn_path = \"models/publication_I.model\"" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "e9862806-9955-4297-b062-e3717f5cc0c7", 108 | "metadata": { 109 | "pycharm": { 110 | "name": "#%% md\n" 111 | }, 112 | "tags": [] 113 | }, 114 | "source": [ 115 | "## Simulation" 116 | ] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "id": "161437b9-1232-442c-b1ff-653b573d8262", 121 | "metadata": { 122 | "pycharm": { 123 | "name": "#%% md\n" 124 | }, 125 | "tags": [] 126 | }, 127 | "source": [ 128 | "### Radar data loading" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "05eec437", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "daq = Daq(rec_dir=path)\n", 139 | "env = daq.env\n", 140 | "recording = daq.radar[2]\n", 141 | "rec_config = daq.radar[2].cfg\n", 142 | "timestamps = recording.data.index\n", 143 | "ts_seconds = timestamps.total_seconds()" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "id": "6195148a", 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "config_name = identify_config(rec_config)\n", 154 | "\n", 155 | "rec_config['RadarName'] = rec_config.pop(\"Name\")\n", 156 | "rec_config['cfg'] = config_name\n", 157 | "rec_config[\"activity\"] = activity\n", 158 | "\n", 159 | "n_samples = rec_config['SamplesPerChirp']\n", 160 | "m_chirps = rec_config['ChirpsPerFrame']\n", 161 | "\n", 162 | "print(f\"Synthetizing data for configuration {config_name}\")" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | 
"execution_count": null, 168 | "id": "6c39f9e6", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "frame_interval_ms = np.mean((timestamps[1:] - timestamps[:-1]).total_seconds()) * 1e3\n", 173 | "duration_sec = (timestamps[-1] - timestamps[0]).total_seconds()\n", 174 | "print(f'Mean frame interval:\\t{frame_interval_ms} ms')\n", 175 | "print(f'Total duration:\\t{duration_sec} seconds')" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "id": "42f66451-ea13-4591-8f4f-b28ae83fcad4", 181 | "metadata": { 182 | "pycharm": { 183 | "name": "#%% md\n" 184 | }, 185 | "tags": [] 186 | }, 187 | "source": [ 188 | "### Data synthesis" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "id": "e6717ffc-ea38-4325-90d1-8ea074be7e58", 194 | "metadata": { 195 | "pycharm": { 196 | "name": "#%% md\n" 197 | }, 198 | "tags": [] 199 | }, 200 | "source": [ 201 | "Load skeleton data" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "id": "611fca25", 208 | "metadata": { 209 | "tags": [] 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "skeletons = sk.load(path, verbose=True)\n", 214 | "sk_interp = sk.interpolate(skeletons, timestamps.total_seconds())\n", 215 | "sk_da = sk.to_xarray(sk_interp, timestamps, attrs=env)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "id": "9310c5a7", 221 | "metadata": { 222 | "pycharm": { 223 | "name": "#%% md\n" 224 | } 225 | }, 226 | "source": [ 227 | "Synthesize raw data from skeleton points using a radar configuration" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "id": "8d1f79ce", 234 | "metadata": { 235 | "pycharm": { 236 | "name": "#%%\n" 237 | } 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "syntheticData = synthetic_radar(skeletons, rec_config, ts_seconds)\n", 242 | "\n", 243 | "assert syntheticData.shape[-2] == m_chirps, \"Number of chirps of synthetic data not correct\"\n", 244 | "assert syntheticData.shape[-1] == n_samples, \"Number of samples per chirp of synthetic data not correct\"\n", 245 | "\n", 246 | "smin = syntheticData.min()\n", 247 | "smax = syntheticData.max()\n", 248 | "snorm = (syntheticData - smin) / (smax - smin)\n", 249 | "raw_synth = pd.DataFrame({\"Timestamps\": timestamps, \"NormData\": [sn for sn in snorm]}).set_index(\"Timestamps\")" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "id": "fd06c539", 255 | "metadata": { 256 | "pycharm": { 257 | "name": "#%% md\n" 258 | } 259 | }, 260 | "source": [ 261 | "The result is a DataFrame" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "id": "d5ec385d", 268 | "metadata": { 269 | "pycharm": { 270 | "name": "#%%\n" 271 | } 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | "raw_synth" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "id": "5df5924d-3e2a-47c6-90b8-a3ca70d26d16", 281 | "metadata": { 282 | "pycharm": { 283 | "name": "#%% md\n" 284 | }, 285 | "tags": [] 286 | }, 287 | "source": [ 288 | "### Preprocessing" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "id": "c46026bd-5dd5-4008-be07-64f9f4c48c2d", 294 | "metadata": { 295 | "pycharm": { 296 | "name": "#%% md\n" 297 | } 298 | }, 299 | "source": [ 300 | "The raw data is processed and converted to an [*x*array](http://xarray.pydata.org/en/stable/) DataArray" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "id": "c777f60f", 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 
310 | "rdm_synth = raw2rdm(raw_synth, rec_config, env, name=f\"{activity}-{config_name}\")" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "id": "97315310", 316 | "metadata": { 317 | "pycharm": { 318 | "name": "#%% md\n" 319 | } 320 | }, 321 | "source": [ 322 | "The Range Doppler Maps can be converted to dB" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "id": "06c3813f", 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "rdm_db = xr.apply_ufunc(radar.mag2db, rdm_synth, keep_attrs=True, kwargs={\"normalize\": True})\n", 333 | "rdm_db.assign_attrs(units=\"dB\")" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "id": "01230d8e", 339 | "metadata": {}, 340 | "source": [ 341 | "Range & Doppler spectrograms in dB can also be calculated" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "id": "153dd8c5", 348 | "metadata": { 349 | "collapsed": false, 350 | "jupyter": { 351 | "outputs_hidden": false 352 | }, 353 | "pycharm": { 354 | "name": "#%%\n" 355 | } 356 | }, 357 | "outputs": [], 358 | "source": [ 359 | "rdm_abs = np.abs(rdm_synth)\n", 360 | "rspect = rdm_abs.sum(dim=\"doppler\").assign_attrs({\"long_name\": \"Range spectrogram\", \"units\": \"dB\"})\n", 361 | "dspect = rdm_abs.sum(dim=\"range\").assign_attrs({\"long_name\": \"Doppler spectrogram\", \"units\": \"dB\"})\n", 362 | "synth_spects = xr.Dataset({\"range_spect\": rspect, \"doppler_spect\": dspect}, attrs=rdm_synth.attrs)\n", 363 | "synth_spects = xr.apply_ufunc(radar.mag2db, synth_spects, keep_attrs=True, kwargs={\"normalize\": True})\n", 364 | "\n", 365 | "synth_spects" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "id": "9d5011b5-7fc7-4ba5-94f9-03c79190c9b2", 371 | "metadata": { 372 | "tags": [] 373 | }, 374 | "source": [ 375 | "### Data animations" 376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "id": "325edabd-469d-44ff-b603-114fe1c41d77", 381 | "metadata": { 382 | "tags": [] 383 | }, 384 | "source": [ 385 | "Process range & Doppler information from the real recording and extract a short time slice from it" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": null, 391 | "id": "22205cdc", 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "time_slice = slice(\"00:00:52\", \"00:01:07\")\n", 396 | "\n", 397 | "skeleton_slice = sk_da.sel(time=time_slice)\n", 398 | "\n", 399 | "rdm_real = raw2rdm(recording.data, rec_config, env, name=f\"{activity}-{config_name}\")\n", 400 | "rdm_rabs = np.abs(rdm_real.sel(time=time_slice))\n", 401 | "\n", 402 | "rng_spect = rdm_rabs.sum(dim=\"doppler\").assign_attrs({\"long_name\": \"Range spectrogram\", \"units\": \"dB\"})\n", 403 | "dopp_spect = rdm_rabs.sum(dim=\"range\").assign_attrs({\"long_name\": \"Doppler spectrogram\", \"units\": \"dB\"})\n", 404 | "real_spects = xr.Dataset({\"range_spect\": rng_spect, \"doppler_spect\": dopp_spect}, attrs=rdm_real.attrs)\n", 405 | "real_spects = xr.apply_ufunc(radar.mag2db, real_spects, keep_attrs=True, kwargs={\"normalize\": True})\n", 406 | "\n", 407 | "rdm_real_db = xr.apply_ufunc(radar.mag2db, rdm_rabs, keep_attrs=True, kwargs={\"normalize\": True})\n", 408 | "rdm_synth_db = xr.apply_ufunc(radar.normalize_db, rdm_db.sel(time=time_slice), keep_attrs=True)\n" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "id": "540ab81f", 415 | "metadata": { 416 | "tags": [] 417 | }, 418 | "outputs": [], 419 | "source": [ 420 | 
"%%capture\n", 421 | "anim_rs = animation_real_synthetic(rdm_real_db, rdm_synth_db, skeleton_slice,\n", 422 | " sensor_loc=rec_config[\"position\"], vmin=-40, notebook=True)\n", 423 | "anim_spects = animation_spec(rdm_real_db, real_spects, vmin=-40, notebook=True)" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "id": "031dfc22", 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [ 433 | "HTML(anim_rs.to_html5_video())" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": null, 439 | "id": "f6623b79", 440 | "metadata": {}, 441 | "outputs": [], 442 | "source": [ 443 | "HTML(anim_spects.to_html5_video())" 444 | ] 445 | }, 446 | { 447 | "cell_type": "markdown", 448 | "id": "9398ee61-47c8-4398-9ee7-9660e497e54a", 449 | "metadata": { 450 | "pycharm": { 451 | "name": "#%% md\n" 452 | }, 453 | "tags": [] 454 | }, 455 | "source": [ 456 | "## Image transformation" 457 | ] 458 | }, 459 | { 460 | "cell_type": "markdown", 461 | "id": "69c57ed0-5214-4e4d-915d-6693f2bfa976", 462 | "metadata": { 463 | "pycharm": { 464 | "name": "#%% md\n" 465 | } 466 | }, 467 | "source": [ 468 | "![RACPIT Architecture](../images/model.png)" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "id": "f9e98f4f-e139-47f9-81cd-4b9996a4f35e", 474 | "metadata": { 475 | "pycharm": { 476 | "name": "#%% md\n" 477 | }, 478 | "tags": [] 479 | }, 480 | "source": [ 481 | "### Open recordings as a list of `Xarray.Datasets`" 482 | ] 483 | }, 484 | { 485 | "cell_type": "markdown", 486 | "id": "a918b594-c353-48bb-824c-d24010a84a78", 487 | "metadata": { 488 | "pycharm": { 489 | "name": "#%% md\n" 490 | } 491 | }, 492 | "source": [ 493 | "The example uses lazy load, but in the GPU `load=True` boosts performance" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": null, 499 | "id": "1b4be7b0-291d-4e12-b0ef-818b80a2ad73", 500 | "metadata": {}, 501 | "outputs": [], 502 | "source": [ 503 | "real_recs = open_recordings(itn_config, real_path, load=False)\n", 504 | "synth_recs = open_recordings(itn_config, synth_path, load=False)" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": null, 510 | "id": "20c2487e", 511 | "metadata": { 512 | "tags": [] 513 | }, 514 | "outputs": [], 515 | "source": [ 516 | "print(f\"{len(real_recs)} recordings have been lazy loaded\")\n", 517 | "i_rec = [i for i, rec in enumerate(real_recs) if rec.date == itn_recording][0]\n", 518 | "print(f\"Recording {itn_recording} found at index {i_rec}\")" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": null, 524 | "id": "13139269", 525 | "metadata": {}, 526 | "outputs": [], 527 | "source": [ 528 | "real_rec = real_recs[i_rec]\n", 529 | "synth_rec = synth_recs[i_rec]" 530 | ] 531 | }, 532 | { 533 | "cell_type": "markdown", 534 | "id": "679b6f57-36f4-44a9-a224-0a08048663c3", 535 | "metadata": {}, 536 | "source": [ 537 | "Extract short spectrograms from the recordings" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": null, 543 | "id": "15ebc9b2-8dd9-4df8-9a93-fae0425775ca", 544 | "metadata": {}, 545 | "outputs": [], 546 | "source": [ 547 | "tslice = slice(\"00:01:39\", None)\n", 548 | "time_length = 64\n", 549 | "\n", 550 | "real_spects = real_rec.drop_vars('label').sel(time=tslice).isel(time=slice(0,time_length))\n", 551 | "synth_spects = synth_rec.drop_vars('label').sel(time=tslice).isel(time=slice(0,time_length))\n", 552 | "\n", 553 | "real_spects = xr.apply_ufunc(radar.normalize_db, 
real_spects.load(), keep_attrs=True)\n", 554 | "synth_spects = xr.apply_ufunc(radar.normalize_db, synth_spects.load(), keep_attrs=True)" 555 | ] 556 | }, 557 | { 558 | "cell_type": "markdown", 559 | "id": "89442fe7-ba23-4d46-90ac-75b3431056e4", 560 | "metadata": { 561 | "tags": [] 562 | }, 563 | "source": [ 564 | "### Transform images" 565 | ] 566 | }, 567 | { 568 | "cell_type": "markdown", 569 | "id": "732e14cf-44c5-46b3-ac74-54a327eecbd6", 570 | "metadata": { 571 | "tags": [] 572 | }, 573 | "source": [ 574 | "Load Image Transformation Network" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": null, 580 | "id": "2961332a-afef-410b-a0dd-df8d72415632", 581 | "metadata": { 582 | "tags": [] 583 | }, 584 | "outputs": [], 585 | "source": [ 586 | "dtype = torch.FloatTensor\n", 587 | "\n", 588 | "transformer = MultiTransformNet(num_inputs=2, num_channels=1)\n", 589 | "transformer.load_state_dict(torch.load(itn_path))\n", 590 | "_ = transformer.eval()" 591 | ] 592 | }, 593 | { 594 | "cell_type": "markdown", 595 | "id": "2d575fb9-3dc6-4c73-a6cf-7747e28cf742", 596 | "metadata": {}, 597 | "source": [ 598 | "Transform real data with the ITN" 599 | ] 600 | }, 601 | { 602 | "cell_type": "code", 603 | "execution_count": null, 604 | "id": "7fe74785-25d5-4de3-a45f-b7a108812d42", 605 | "metadata": {}, 606 | "outputs": [], 607 | "source": [ 608 | "range_inp = torch.from_numpy(real_spects.range_spect.values[None, None, :, :])\n", 609 | "dopp_inp = torch.from_numpy(real_spects.doppler_spect.values[None, None, :, :])\n", 610 | "spec_input = [Variable(range_inp, requires_grad=False).type(dtype), Variable(dopp_inp, requires_grad=False).type(dtype)]\n", 611 | "\n", 612 | "range_hat, doppler_hat = transformer(spec_input)\n", 613 | "range_trans = range_hat.detach().numpy()\n", 614 | "doppler_trans = doppler_hat.detach().numpy()" 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": null, 620 | "id": "e148ead0-3356-4d10-bd54-22dfde6e114f", 621 | "metadata": {}, 622 | "outputs": [], 623 | "source": [ 624 | "spectrograms = xr.merge([real_spects.rename_vars(range_spect=\"range_real\", doppler_spect=\"doppler_real\"),\n", 625 | " synth_spects.rename_vars(range_spect=\"range_synth\", doppler_spect=\"doppler_synth\")],\n", 626 | " combine_attrs=\"drop_conflicts\")\n", 627 | "spectrograms[\"range_trans\"] = (['time', 'range'], np.squeeze(range_trans), {\"units\": \"dB\"})\n", 628 | "spectrograms[\"doppler_trans\"] = (['time', 'doppler'], np.squeeze(doppler_trans), {\"units\": \"dB\"})" 629 | ] 630 | }, 631 | { 632 | "cell_type": "markdown", 633 | "id": "ba9a2ac9-334d-43aa-b7f5-446b989a43ff", 634 | "metadata": { 635 | "jupyter": { 636 | "outputs_hidden": false 637 | }, 638 | "pycharm": { 639 | "name": "#%%\n" 640 | }, 641 | "tags": [] 642 | }, 643 | "source": [ 644 | "### Plotting" 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "execution_count": null, 650 | "id": "a4e3609c-fcb3-4bcd-b5de-8460c966aff8", 651 | "metadata": { 652 | "collapsed": false, 653 | "jupyter": { 654 | "outputs_hidden": false 655 | }, 656 | "pycharm": { 657 | "name": "#%%\n" 658 | } 659 | }, 660 | "outputs": [], 661 | "source": [ 662 | "fig, axes = plt.subplots(2, 3, figsize=(12, 8))\n", 663 | "\n", 664 | "spec_plot(spectrograms[[\"range_real\", \"range_trans\", \"range_synth\", \"doppler_real\", \"doppler_trans\", \"doppler_synth\"]],\n", 665 | " axes=axes.flatten(), ax_xlabel=axes[-1], vmin=-40, vmax=0, add_colorbar=False)\n", 666 | "\n", 667 | "im = axes[-1][-1].get_images()[0]\n", 668 | 
"fig.align_ylabels(axes[:,0])\n", 669 | "_ = [ax.set_ylabel(None) for ax in axes[:,1:].flatten()]\n", 670 | "\n", 671 | "cbar = plt.colorbar(im, ax=axes, orientation=\"horizontal\")\n", 672 | "cbar.set_label(\"Amplitude [dB]\")\n", 673 | "\n", 674 | "for title, ax in zip((\"Real data $x$\", \"Transformed data $\\widehat{y}$\", \"Synthetic data $y$\"), axes[0]):\n", 675 | " ax.set_title(title)" 676 | ] 677 | } 678 | ], 679 | "metadata": { 680 | "kernelspec": { 681 | "display_name": "Python 3 (ipykernel)", 682 | "language": "python", 683 | "name": "python3" 684 | }, 685 | "language_info": { 686 | "codemirror_mode": { 687 | "name": "ipython", 688 | "version": 3 689 | }, 690 | "file_extension": ".py", 691 | "mimetype": "text/x-python", 692 | "name": "python", 693 | "nbconvert_exporter": "python", 694 | "pygments_lexer": "ipython3", 695 | "version": "3.8.12" 696 | }, 697 | "pycharm": { 698 | "stem_cell": { 699 | "cell_type": "raw", 700 | "source": [], 701 | "metadata": { 702 | "collapsed": false 703 | } 704 | } 705 | }, 706 | "toc-autonumbering": false, 707 | "toc-showcode": false, 708 | "toc-showmarkdowntxt": true 709 | }, 710 | "nbformat": 4, 711 | "nbformat_minor": 5 712 | } -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | from ifxaion.daq import Daq 2 | from ifxaion.radar.utils import processing_functions as pf 3 | 4 | from utils import radar 5 | 6 | import json 7 | import numpy as np 8 | import pandas as pd 9 | import xarray as xr 10 | from dask import is_dask_collection 11 | 12 | from pathlib import Path 13 | from datetime import datetime 14 | 15 | from utils.synthesize import synthetic_radar 16 | from utils.skeletons import load as skload 17 | 18 | _configs_path = Path.cwd() / 'configurations' / 'radar_configs.csv' 19 | _interference_path = None 20 | 21 | _base_path = Path("/mnt/infineon-radar") 22 | _raw_path = _base_path / "daq_x-har" 23 | _preprocessed_path = _base_path / "preprocessed" / "daq_x-har" 24 | 25 | _margin_opts = ("none", "coherent", "incoherent") 26 | _margin_default = _margin_opts[-1] 27 | 28 | _radar_transformations = { 29 | 'complex': { 30 | "func": lambda d: xr.apply_ufunc(radar.complex2vector, d, 31 | output_core_dims=[["complex"]], keep_attrs=True), 32 | "units": "Complex amplitude"}, 33 | 'magnitude': { 34 | "func": lambda d: xr.apply_ufunc(radar.absolute, d, 35 | keep_attrs=True, kwargs={"normalize": True}), 36 | "units": "Magnitude"}, 37 | 'db': { 38 | "func": lambda d: xr.apply_ufunc(radar.mag2db, d, 39 | keep_attrs=True, kwargs={"normalize": True}), 40 | "units": "dB"}} 41 | _radar_val_keys = tuple(_radar_transformations.keys()) 42 | _radar_val_default = _radar_val_keys[-1] 43 | 44 | configs = pd.read_csv(_configs_path, index_col=0, sep=",").T 45 | config_names = configs.index 46 | 47 | 48 | def dict_included(subset_dict, other_dict): 49 | try: 50 | return all(other_dict[k] == v for k, v in subset_dict.items()) 51 | except KeyError: 52 | return False 53 | 54 | 55 | def identify_config(radar_config): 56 | found_config = configs[configs.apply(dict_included, axis=1, args=[radar_config])] 57 | try: 58 | config_name = found_config.index[0] 59 | except IndexError: 60 | config_name = "UNKNOWN" 61 | return config_name 62 | 63 | 64 | class StrEncoder(json.JSONEncoder): 65 | _str_cls = (datetime, Path) 66 | 67 | def default(self, obj): 68 | if any(isinstance(obj, c) for c in self._str_cls): 69 | return str(obj) # Let the base class 
default method raise the TypeError 70 | return json.JSONEncoder.default(self, obj) 71 | 72 | 73 | def preprocess(raw_path=None, preprocessed_path=None, synthetic=False, 74 | range_length=None, doppler_length=None, 75 | range_scope=None, doppler_scope=None, resample_ms=None, 76 | value=_radar_val_default, marginalize=_margin_default, 77 | interference_path=_interference_path): 78 | """Preprocess and save radar data for later training 79 | 80 | :param raw_path: path to dataset 81 | :param preprocessed_path: path to output dataset 82 | :param synthetic: if True, generate radar signals from skeleton data and use it instead of real data 83 | :param range_length: Length of range FFT 84 | :param doppler_length: Length of Doppler FFT 85 | :param range_scope: If given (in meters), range_length will be referred only until this point 86 | :param doppler_scope: If given (in m/s), doppler_length will be referred only until this point 87 | :param resample_ms: New frame period to resample in milliseconds 88 | :param marginalize: If a non-empty string, save range and doppler spectrograms (in)coherently, 89 | otherwise range doppler maps 90 | :param value: Save data as either complex amplitude, magnitude or decibels 91 | :param interference_path: Path to a JSON file indicating interference chunks to discard 92 | :return: metadata dictionary 93 | """ 94 | 95 | created_at = datetime.now() 96 | 97 | if raw_path is None: 98 | raw_path = _raw_path 99 | else: 100 | raw_path = Path(raw_path) 101 | 102 | metadata = locals() 103 | del metadata["preprocessed_path"] 104 | 105 | if marginalize not in _margin_opts: 106 | raise ValueError(f"Unrecognized marginalize option '{marginalize}', choose from {_margin_opts}") 107 | 108 | try: 109 | rdm_transform = _radar_transformations[value] 110 | except KeyError as ke: 111 | raise ValueError(f"Unrecognized value '{value}'. 
Value must belong to {_radar_val_keys}") from ke 112 | 113 | if preprocessed_path is None: 114 | preprocessed_path = _preprocessed_path 115 | else: 116 | preprocessed_path = Path(preprocessed_path) 117 | 118 | metadata_path = preprocessed_path / "metadata.json" 119 | 120 | try: 121 | with open(metadata_path, "w") as wf: 122 | json.dump(metadata, wf, cls=StrEncoder, indent=4) 123 | except FileNotFoundError: 124 | preprocessed_path.mkdir(parents=True) 125 | with open(metadata_path, "w") as wf: 126 | json.dump(metadata, wf, cls=StrEncoder, indent=4) 127 | 128 | if interference_path is None: 129 | interferences = None 130 | else: 131 | with open(interference_path, 'r') as rf: 132 | interferences = json.load(rf) 133 | 134 | label_suffix = 10 135 | unknown_config = "UNKNOWN" 136 | 137 | for activity_path in raw_path.iterdir(): 138 | activity_dir = str(activity_path.relative_to(raw_path)) 139 | activity = activity_dir[:-label_suffix] 140 | print(f"Converting {activity}") 141 | for rec_path in activity_path.iterdir(): 142 | rec_name = str(rec_path.relative_to(activity_path).with_suffix('')) 143 | 144 | if synthetic: 145 | skeletons = skload(rec_path, verbose=True) 146 | else: 147 | skeletons = None 148 | 149 | daq = Daq(rec_dir=rec_path) 150 | env = daq.env 151 | radars = daq.radar 152 | n = len(radars) 153 | print(f"{n} radar files were found in {rec_path}") 154 | print("Environment:", env, sep="\n") 155 | del env["software"] 156 | del env["room_size"] 157 | env["synthetic"] = synthetic 158 | 159 | rec_signals = [] 160 | 161 | for recording in radars: 162 | rec_config = recording.cfg 163 | config_name = identify_config(rec_config) 164 | 165 | rdm_dir = preprocessed_path / config_name / activity 166 | rdm_path = rdm_dir / rec_name 167 | 168 | rec_config['RadarName'] = rec_config.pop("Name") 169 | rec_config['cfg'] = config_name 170 | 171 | if not rdm_dir.is_dir(): 172 | try: 173 | rdm_dir.mkdir(parents=False) 174 | except FileNotFoundError: 175 | rdm_dir.mkdir(parents=True) 176 | if config_name != unknown_config: 177 | with open(rdm_dir.parent / "config.json", 'w') as fp: 178 | json.dump(rec_config, fp, indent=4) 179 | 180 | rec_config["activity"] = activity 181 | 182 | print(f"Reading data with configuration {config_name}") 183 | 184 | interf_slices = [] 185 | if interferences is not None: 186 | interf_slices = interferences[activity_dir][rec_name][config_name] 187 | 188 | named_slices = [] 189 | if len(interf_slices) == 0: 190 | named_slices.append((rdm_path, slice(None, None))) 191 | else: 192 | print("Removing interferences") 193 | data_slices = complementary_slices(interf_slices) 194 | path_base = str(rdm_path) 195 | for i, data_slice in enumerate(data_slices): 196 | rdm_file = Path(f"{path_base}_{i+1}") 197 | named_slices.append((rdm_file, data_slice)) 198 | 199 | for rdm_file, time_slice in named_slices: 200 | rec_data = recording.data[time_slice] 201 | 202 | timestamps = rec_data.index 203 | frame_interval_ms = np.mean((timestamps[1:] - timestamps[:-1]).total_seconds()) * 1e3 204 | duration_sec = (timestamps[-1] - timestamps[0]).total_seconds() 205 | print(f'Mean frame interval:\t{frame_interval_ms} ms') 206 | print(f'Total duration:\t{duration_sec} seconds') 207 | 208 | if synthetic: 209 | n_samples = rec_config['SamplesPerChirp'] 210 | m_chirps = rec_config['ChirpsPerFrame'] 211 | 212 | sk_slice = skeletons[time_slice] 213 | syntheticData = synthetic_radar(sk_slice, rec_config, timestamps.total_seconds()) 214 | 215 | assert syntheticData.shape[-2] == m_chirps, "Incorrect #chirps of 
synthetic data" 216 | assert syntheticData.shape[-1] == n_samples, "Incorrect #samples per chirp of synthetic data" 217 | 218 | sMin = syntheticData.min() 219 | sMax = syntheticData.max() 220 | dynamic_range = sMax - sMin 221 | if dynamic_range > 0: 222 | sNorm = (syntheticData - sMin) / dynamic_range 223 | else: 224 | print(f"\n*** WARNING ***\nSynthetic {rdm_file}-{config_name} " 225 | f"has null dynamic range\n***************\n") 226 | sNorm = syntheticData 227 | rec_data = pd.DataFrame({"Timestamps": timestamps, 228 | "NormData": [sn for sn in sNorm]}).set_index("Timestamps") 229 | rec_signals.append((rdm_file, rec_config, rec_data)) 230 | 231 | for rdm_file, rec_config, rData in rec_signals: 232 | 233 | n_samples = rec_config['SamplesPerChirp'] 234 | m_chirps = rec_config['ChirpsPerFrame'] 235 | frame_period_ms = rec_config['FramePeriod'] // 1000 236 | bw_MHz = rec_config['UpperFrequency'] - rec_config['LowerFrequency'] 237 | prt_ns = rec_config['ChirpToChirpTime'] 238 | 239 | config_name = rec_config['cfg'] 240 | rdm_name = f"{rdm_file}-config_{config_name}" 241 | 242 | dr, r_max = radar.range_axis(bw_MHz * 10 ** 6, n_samples) 243 | dv, v_max = radar.doppler_axis(prt_ns * 10 ** (-9), m_chirps) 244 | 245 | if range_scope is None: 246 | r_length = range_length 247 | else: 248 | if r_max < range_scope: 249 | raise ValueError(f"Configuration {config_name} has a max. range of {r_max} m, " 250 | f"below the desired scope of {range_scope} m") 251 | r_length = round(range_length * r_max / range_scope) 252 | if doppler_scope is None: 253 | d_length = doppler_length 254 | else: 255 | if v_max < doppler_scope: 256 | raise ValueError(f"Configuration {config_name} has a max. velocity of {v_max} m/s, " 257 | f"below the desired scope of {doppler_scope} m/s") 258 | d_length = round(doppler_length * v_max / doppler_scope) 259 | 260 | if range_scope or doppler_scope: 261 | print(f"FFT parameters:\n\tRange scope: {range_scope} m,\tDoppler scope: {doppler_scope} m/s") 262 | print(f"\tRange length: {range_length} under scope, {r_length} in total") 263 | print(f"\tDoppler length: {doppler_length} under scope, {d_length} in total") 264 | 265 | rdm_da = raw2rdm(rData, rec_config, env, r_length=r_length, d_length=d_length, name=rdm_name) 266 | 267 | if marginalize == _margin_opts[0]: # none 268 | ds = xr.Dataset({"rdm": rdm_da}, attrs=rdm_da.attrs) 269 | ds.rdm.attrs["long_name"] = "Range Doppler map" 270 | else: 271 | if marginalize == _margin_opts[2]: # incoherent 272 | print("Incoherent marginalization") 273 | rdm_da = np.abs(rdm_da) 274 | else: # coherent 275 | print("Coherent marginalization") 276 | rspect = rdm_da.sum(dim="doppler") 277 | dspect = rdm_da.sum(dim="range") 278 | ds = xr.Dataset({"range_spect": rspect, "doppler_spect": dspect}, attrs=rdm_da.attrs) 279 | ds.range_spect.attrs["long_name"] = "Range spectrogram" 280 | ds.doppler_spect.attrs["long_name"] = "Doppler spectrogram" 281 | 282 | if resample_ms is not None and resample_ms != frame_period_ms: 283 | old_frames = ds.sizes["time"] 284 | print(f"Interpolating frame period from {frame_period_ms}ms to {resample_ms}ms") 285 | ds = ds.resample(time=f"{resample_ms}ms").interpolate("cubic") 286 | new_frames = ds.sizes["time"] 287 | print(f"{old_frames} frames resampled into {new_frames} frames") 288 | ds = ds.assign_attrs(frame_period_ms=resample_ms) 289 | else: 290 | ds = ds.assign_attrs(frame_period_ms=frame_period_ms) 291 | 292 | del ds.attrs["units"] 293 | for k, da in ds.data_vars.items(): 294 | ds[k] = rdm_transform["func"](da) 295 | 
ds[k].attrs["units"] = rdm_transform["units"] 296 | 297 | if "complex" in ds.coords: 298 | ds["complex"] = ["real", "imag"] 299 | 300 | rdm_path = rdm_file.with_suffix(".nc") 301 | ds.to_netcdf(rdm_path) 302 | print(f"Data saved under {rdm_path}\n") 303 | return metadata 304 | 305 | 306 | def raw2rdm(raw_data, rec_config, env, antennas=0, r_length=None, d_length=None, name="rdm"): 307 | """ 308 | Convert raw_data to a sequence of Range Doppler Maps (RDM) and return it as an xarray.DataArray 309 | :param raw_data: Radar raw data as returned by daq 310 | :param rec_config: RadarConfig of the recording as returned by daq 311 | :param env: Environment data as returned by daq 312 | :param antennas: Antenna indices to be preprocessed (Only a single index is supported) 313 | :param r_length: Effective length of the FFT over the range axis 314 | :param d_length: Effective length of the FFT over the doppler axis 315 | :param name: Name of the resulting xarray.DataArray 316 | :return: xarray.DataArray with a sequence of RDMs and embedded metadata. 317 | The dimensions of each RDM are (doppler_length, range_length) if used, otherwise (ChirpsPerFrame, SamplesPerChirp/2) 318 | """ 319 | rdm_data = pf.preprocess_radar_data(raw_data, 320 | range_length=r_length, 321 | doppler_length=d_length, 322 | antennas=antennas).squeeze() 323 | doppler_bins, range_bins = rdm_data.shape[-2:] 324 | 325 | print(f"Data shape: {rdm_data.shape}") 326 | 327 | timestamps = raw_data.index 328 | 329 | n_samples = rec_config['SamplesPerChirp'] 330 | m_chirps = rec_config['ChirpsPerFrame'] 331 | 332 | bw_MHz = rec_config['UpperFrequency'] - rec_config['LowerFrequency'] 333 | prt_ns = rec_config['ChirpToChirpTime'] 334 | 335 | dr, r_max = radar.range_axis(bw_MHz * 10 ** 6, n_samples) 336 | dv, v_max = radar.doppler_axis(prt_ns * 10 ** (-9), m_chirps) 337 | 338 | range_coords = np.linspace(0, r_max, range_bins) 339 | doppler_coords = np.linspace(-v_max, v_max, doppler_bins) 340 | 341 | rdm_da = xr.DataArray(rdm_data, dims=("time", "doppler", "range"), 342 | name=name, attrs={"units": "Complex amplitude"}, 343 | coords={"time": timestamps.to_numpy(), 344 | "doppler": doppler_coords, 345 | "range": range_coords}) 346 | 347 | rdm_da.range.attrs["units"] = 'm' 348 | rdm_da.doppler.attrs["units"] = "m/s" 349 | 350 | rdm_da = rdm_da.assign_attrs(rec_config) 351 | rdm_da = rdm_da.assign_attrs(env) 352 | 353 | # Delete problematic attributes 354 | del rdm_da.attrs["position"] 355 | del rdm_da.attrs["orientation"] 356 | del rdm_da.attrs["transformation"] 357 | 358 | return rdm_da 359 | 360 | 361 | def open_recordings(radar_configs, preprocessed_path=None, range_length=None, doppler_length=None, 362 | categorical=True, load=False): 363 | """ 364 | Open preprocessed recordings for all activities of a certain (or several) radar configuration using xarray 365 | :param radar_configs: One or more radar configurations, e.g. 
"E" or ["A", "B", "C"] 366 | :param preprocessed_path: Path to the preprocessed data 367 | :param range_length: If provided, data will be cropped on the range axis from 0 to range_length 368 | :param doppler_length: If provided, data will be cropped on the doppler axis from 0 to doppler_length 369 | :param categorical: If True, the activities are categorized as int labels 370 | :param load: If True, Datasets are loaded, otherwise use xarray's lazy load 371 | :return: List of xarray.datasets with the recordings 372 | """ 373 | ds_kwargs = {} if load else dict(chunks="auto", cache=False) 374 | 375 | if preprocessed_path is None: 376 | preprocessed_path = _preprocessed_path 377 | else: 378 | preprocessed_path = Path(preprocessed_path) 379 | 380 | def concat_recs(cfg): 381 | cfg_path = preprocessed_path / cfg 382 | activities = [a for a in cfg_path.iterdir() if a.is_dir()] 383 | cfg_recs = [] 384 | for act_path in activities: 385 | for rec_file in act_path.iterdir(): 386 | with xr.open_dataset(rec_file, **ds_kwargs) as rec: 387 | if range_length is not None: 388 | rec = rec.isel(range=slice(range_length)) 389 | if doppler_length is not None: 390 | rec = rec.isel(doppler=slice(doppler_length)) 391 | if load: 392 | rec = rec.persist() 393 | cfg_recs.append(rec) 394 | if categorical: 395 | cfg_recs = categorize(cfg_recs, "activity", map_name="activities") 396 | return cfg_recs 397 | 398 | if type(radar_configs) == str: 399 | recordings = concat_recs(radar_configs) 400 | else: 401 | recordings = {c: concat_recs(c) for c in radar_configs} 402 | 403 | return recordings 404 | 405 | 406 | def categorize(recordings, attr, map_name="categories", var_name="label"): 407 | """ Add a global variable from a list of datasets as a categoric variable 408 | 409 | :param recordings: List of Datasets holding processed recordings 410 | :param attr: str, global attribute to categorize as variable 411 | :param map_name: str, name of the global attribute holding the category map 412 | :param var_name: str, name of the index variable 413 | :return: List of the same Datasets enriched with the categorized attribute 414 | """ 415 | indices, category_map = pd.factorize([r.attrs[attr] for r in recordings], sort=True) 416 | return [r.assign({var_name: i}).assign_attrs({map_name: tuple(category_map)}) 417 | for r, i in zip(recordings, indices)] 418 | 419 | 420 | def recs2dataset(recordings, sliced_indices, ignore_dims=None): 421 | """ Concatenate recording slices into a Dataset 422 | 423 | :param recordings: List of Datasets holding processed recordings 424 | :param sliced_indices: Double nested list of Range objects to slice recordings 425 | :param ignore_dims: list with the names of dimensions to be reset to plain indices 426 | :return: Concatenated Dataset with all sliced recordings 427 | """ 428 | if ignore_dims is not None: 429 | ignore_dims = {dim: f"{dim}_" for dim in ignore_dims} 430 | rec_slices = [] 431 | for sl in sliced_indices: 432 | rec = slice_recording(recordings, sl) 433 | rec = reset_time(rec) 434 | if ignore_dims is not None: 435 | rec = rec.rename_vars(ignore_dims).reset_coords(ignore_dims.values(), drop=True) 436 | rec_slices.append(rec) 437 | ds = concat_slices(rec_slices, combine_attrs="drop_conflicts") 438 | ds.label.attrs["long_name"] = "Activity label" 439 | return ds 440 | 441 | 442 | def slice_recording(recordings, indexed_slice): 443 | i, r_slice = indexed_slice 444 | try: 445 | sliced_rec = recordings[i].isel(time=r_slice) 446 | except IndexError as ie: 447 | t_len = recordings[i].sizes["time"] 
448 | shift = t_len - r_slice.stop 449 | print(f"Caught index error: {ie}, shifting slice by {shift}") 450 | sliced_rec = recordings[i].isel(time=range(r_slice.start + shift, r_slice.stop + shift)) 451 | return sliced_rec 452 | 453 | 454 | def concat_slices(slices, dim="batch_index", combine_attrs="drop"): 455 | return xr.concat(slices, dim=dim, combine_attrs=combine_attrs) 456 | 457 | 458 | def reset_time(dataset, add_offset=True): 459 | if add_offset: 460 | dataset = dataset.assign_attrs(date=pd.to_datetime(dataset.date) + dataset.time[0].values) 461 | return dataset.assign_coords(time=lambda d: np.arange(d.sizes["time"]) * np.timedelta64(d.frame_period_ms, 'ms')) 462 | 463 | 464 | def apply_processing(dataset, funcs, dask="allowed"): 465 | """Apply a sequence of processing functions to an xarray in a lazy fashion 466 | 467 | Args: 468 | dataset: xarray Object 469 | funcs: sequence of ufuncs 470 | dask: Dask option to pass to apply_ufunc 471 | 472 | Returns: The processed xarray object 473 | 474 | """ 475 | for fn, args, kwargs in funcs: 476 | dataset = xr.apply_ufunc(fn, dataset, *args, keep_attrs=True, kwargs=kwargs, dask=dask) 477 | return dataset 478 | 479 | 480 | def proc_register(processing_functions, func, *args, **kwargs): 481 | processing_functions.append((func, args, kwargs)) 482 | 483 | 484 | def get_size(dataset, loaded_only=True, human=False): 485 | def get_bytes(ds, human_b=False): 486 | byte_size = 0 487 | for var in ds.variables.values(): 488 | if not loaded_only or not is_dask_collection(var): 489 | byte_size += var.nbytes 490 | if human_b: 491 | byte_size = human_bytes(byte_size) 492 | return byte_size 493 | 494 | if any(isinstance(dataset, c) for c in (list, tuple)): 495 | summed_bytes = sum(get_bytes(d) for d in dataset) 496 | if human: 497 | summed_bytes = human_bytes(summed_bytes) 498 | return summed_bytes 499 | else: 500 | return get_bytes(dataset, human_b=human) 501 | 502 | 503 | def human_bytes(byte_size, digits=2, units=('bytes', 'kB', 'MB', 'GB', 'TB')): 504 | if byte_size <= 0: 505 | return False 506 | exponent = int(np.floor(np.log2(byte_size) / 10)) 507 | mantissa = byte_size / (1 << (exponent * 10)) 508 | hbytes = f"{mantissa:.{digits}f}{units[exponent]}" 509 | 510 | return hbytes 511 | 512 | 513 | def get_scope(ds, unit): 514 | with xr.set_options(keep_attrs=True): 515 | return ds.coords[unit][-1] - ds.coords[unit][0] 516 | 517 | 518 | def complementary_slices(slices): 519 | comp_slices = [] 520 | next_start = None 521 | for sl in slices: 522 | if isinstance(sl, slice): 523 | sl = {'start': sl.start, 'stop': sl.stop} 524 | if next_start != sl["start"]: 525 | comp_slices.append(slice(next_start, sl["start"])) 526 | next_start = sl["stop"] 527 | if next_start is not None: 528 | comp_slices.append(slice(next_start, None)) 529 | return comp_slices 530 | 531 | 532 | if __name__ == '__main__': 533 | 534 | import argparse 535 | 536 | parser = argparse.ArgumentParser(description='FMCW radar preprocessing') 537 | parser.add_argument("--raw", type=str, default=str(_raw_path), help="Path to the raw radar data") 538 | parser.add_argument("--output", type=str, default=str(_preprocessed_path), 539 | help="Path to the output the preprocessed data") 540 | parser.add_argument("--interference", type=str, default=str(_interference_path), 541 | help="Path to a JSON file indicating interference chunks to discard") 542 | parser.add_argument("--synthetic", action='store_true', help="Use to create radar signals from skeleton data") 543 | parser.add_argument("--range-length", 
type=int, default=None, 544 | help="Length of range FFT, default to half the number of samples") 545 | parser.add_argument("--doppler-length", type=int, default=None, 546 | help="Length of Doppler FFT, default to number of chirps") 547 | parser.add_argument("--range-scope", type=float, default=None, 548 | help="If given (in meters), range-length will be referred only until this point") 549 | parser.add_argument("--doppler-scope", type=float, default=None, 550 | help="If given (in m/s), doppler-length will be referred only until this point") 551 | parser.add_argument("--resample", type=float, default=None, help="New frame period to resample in milliseconds") 552 | parser.add_argument("--value", type=str, choices=_radar_val_keys, default=_radar_val_default, 553 | help="Save data as either complex amplitude, magnitude or decibels") 554 | parser.add_argument("--marginalize", type=str, choices=_margin_opts, default=_margin_default, 555 | help ="'none' to save range Doppler maps, " 556 | "otherwise (in)coherent range and Doppler spectrograms") 557 | 558 | p_args = parser.parse_args() 559 | preprocess_kwargs = vars(p_args) 560 | 561 | preprocess_kwargs["raw_path"] = preprocess_kwargs.pop("raw") 562 | preprocess_kwargs["preprocessed_path"] = preprocess_kwargs.pop("output") 563 | preprocess_kwargs["interference_path"] = preprocess_kwargs.pop("interference") 564 | preprocess_kwargs["resample_ms"] = preprocess_kwargs.pop("resample") 565 | 566 | preprocess(**preprocess_kwargs) 567 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xarray as xr 3 | import torch 4 | import os 5 | import argparse 6 | import time 7 | import json 8 | 9 | import matplotlib.pyplot as plt 10 | 11 | from torch.autograd import Variable 12 | from torch.optim import Adam 13 | from torch.utils.data import DataLoader 14 | import torch.nn as nn 15 | 16 | from networks.img_transf import ImageTransformNet, MultiTransformNet 17 | from networks.perceptual import Vgg16, RDNet, RDPerceptual, RACPIT 18 | 19 | from utils.slicer import train_test_slice 20 | from utils.provider import RadarDataset 21 | from utils.preprocess import open_recordings 22 | from utils.visualization import spec_plot 23 | 24 | # Global Variables 25 | BATCH_TRANSFER = 4 26 | BATCH_CLASSIFY = 32 27 | LEARNING_RATE = 1e-3 28 | EPOCHS = 100 29 | CONTENT_WEIGHT = 1e0 30 | TV_WEIGHT = 1e-7 31 | 32 | REAL_PATH = "/mnt/infineon-radar/preprocessed/real" 33 | SYNTH_PATH = "/mnt/infineon-radar/preprocessed/synthetic" 34 | 35 | # Radar Processing variables 36 | range_length = 128 37 | doppler_length = 128 38 | time_length = 64 39 | hop_length = 8 40 | ignore_dims = False 41 | 42 | 43 | def train_transfer(args): 44 | # GPU enabling 45 | if args.gpu is None: 46 | use_cuda = False 47 | dtype = torch.FloatTensor 48 | label_type = torch.LongTensor 49 | print("No GPU training") 50 | else: 51 | use_cuda = True 52 | dtype = torch.cuda.FloatTensor 53 | label_type = torch.cuda.LongTensor 54 | torch.cuda.set_device(args.gpu) 55 | print("Current device: %d" % torch.cuda.current_device()) 56 | 57 | epochs = args.epochs 58 | 59 | # visualization of training controlled by flag 60 | visualize = 0 if args.visualize is None else args.visualize 61 | 62 | # define network 63 | if args.range: 64 | image_transformer = MultiTransformNet(num_inputs=2, num_channels=1).type(dtype) 65 | else: 66 | image_transformer = 
ImageTransformNet(num_channels=1).type(dtype) 67 | optimizer = Adam(image_transformer.parameters(), LEARNING_RATE) 68 | criterion = nn.CrossEntropyLoss() 69 | 70 | loss_mse = torch.nn.MSELoss() 71 | 72 | # get training dataset 73 | config = args.config 74 | input_path = args.input 75 | output_path = args.output 76 | print(f"Training with configurations {', '.join(config)}.") 77 | print(f"Using data from {input_path} as the input and data from {output_path} as an output.") 78 | 79 | train_load = 0.8 80 | num_workers = 4 81 | 82 | if args.recordings is not None: 83 | split = 'no-cut' 84 | train_load = 0.5 85 | elif args.segments is None: 86 | split = 'single' 87 | else: 88 | with open(args.segments, "r") as f: 89 | split = json.load(f) 90 | 91 | recordings_input = open_recordings(config, input_path, 92 | load=True, range_length=range_length, doppler_length=doppler_length) 93 | recordings_output = open_recordings(config, output_path, 94 | load=True, range_length=range_length, doppler_length=doppler_length) 95 | 96 | # Merge recordings from all configs 97 | recordings_input = [r for recs in recordings_input.values() for r in recs] 98 | recordings_output = [r for recs in recordings_output.values() for r in recs] 99 | if args.range: 100 | recordings_input = [r.rename_vars(range_spect="range_real") for r in recordings_input] 101 | recordings_output = [r.rename_vars(range_spect="range_synth") for r in recordings_output] 102 | else: # drop range spectrograms 103 | recordings_input = [r.drop_vars("range_spect") for r in recordings_input] 104 | recordings_output = [r.drop_vars("range_spect") for r in recordings_output] 105 | 106 | recordings = [xr.merge([rec_real.rename_vars(doppler_spect="doppler_real"), 107 | rec_synth.rename_vars(doppler_spect="doppler_synth")], combine_attrs="drop_conflicts") 108 | for rec_real, rec_synth in zip(recordings_input, recordings_output)] 109 | 110 | if args.classes is not None: 111 | classes = args.classes 112 | print(f"Selecting classes {classes}") 113 | new_recs = [] 114 | for r in recordings: 115 | if r.label in classes: 116 | new_recs.append(r) 117 | print(f"{len(new_recs)} out of {len(recordings)} selected") 118 | recordings = new_recs 119 | 120 | slice_kwargs = dict(spec_length=time_length, stride=hop_length, train_load=train_load, copy_split=0) 121 | loader_kwargs = dict(batch_size=BATCH_TRANSFER, shuffle=True, num_workers=num_workers, pin_memory=True) 122 | 123 | print("Preloading datasets...") 124 | slice_output = slice_datasets(recordings, split=split, **slice_kwargs) 125 | if args.recordings is None: 126 | if args.segments is None: 127 | [train_dataset, test_dataset], tgt_segments = slice_output 128 | else: 129 | train_dataset, test_dataset = slice_output 130 | elif args.recordings == 'first': 131 | [train_dataset, test_dataset], tgt_segments = slice_output 132 | elif args.recordings == 'last': 133 | [test_dataset, train_dataset], tgt_segments = slice_output 134 | else: 135 | raise ValueError(f"Unrecognized recordings option {args.recordings}") 136 | 137 | train_loader = DataLoader(train_dataset, **loader_kwargs) 138 | 139 | test_indices = np.random.choice(len(test_dataset), size=visualize, replace=False) 140 | 141 | class_num = train_dataset.class_num 142 | input_shapes = train_dataset.feature_shapes 143 | input_shapes = input_shapes[:len(input_shapes)//2] 144 | print(f"Number of classes: {class_num}") 145 | print(f"Feature shapes: {input_shapes}\n") 146 | 147 | # load perceptual loss network 148 | if args.model is None: 149 | perceptual_net = 
Vgg16().type(dtype) # Only works with single input i.e. without range 150 | else: 151 | perceptual_net = RDPerceptual(args.model, input_shapes=input_shapes, class_num=class_num).type(dtype) 152 | 153 | log_id = args.log 154 | 155 | # calculate gram matrices for style feature layer maps we care about 156 | # style_features = vgg(style) 157 | # style_gram = [utils.gram(fmap) for fmap in style_features] 158 | 159 | loss_logs = [] 160 | 161 | for e in range(epochs): 162 | 163 | # track values for... 164 | img_count = 0 165 | aggregate_content_loss = 0.0 166 | aggregate_classify_loss = 0.0 167 | aggregate_tv_loss = 0.0 168 | batch_num = 0 169 | 170 | # train network 171 | image_transformer.train() 172 | for batch_num, (feature_batch, label) in enumerate(train_loader): 173 | img_batch_read = len(label) 174 | img_count += img_batch_read 175 | 176 | if args.range: 177 | [real_range, real_doppler, synth_range, synth_doppler] = feature_batch 178 | x = [Variable(real_feat).type(dtype) for real_feat in (real_range, real_doppler)] 179 | y_c = [Variable(synth_feat).type(dtype) for synth_feat in (synth_range, synth_doppler)] 180 | pass 181 | else: 182 | [real_batch, synth_batch] = feature_batch 183 | x = Variable(real_batch).type(dtype) 184 | y_c = Variable(synth_batch).type(dtype) 185 | label_true = Variable(label).type(label_type) 186 | 187 | # zero out gradients 188 | optimizer.zero_grad() 189 | 190 | # input batch to transformer network 191 | y_hat = image_transformer(x) 192 | 193 | # get vgg features 194 | y_c_features = perceptual_net(y_c) 195 | y_hat_features = perceptual_net(y_hat) 196 | 197 | # calculate classification loss w.r.t. input 198 | label_pred = y_hat_features[0] 199 | classify_loss = CONTENT_WEIGHT*criterion(label_pred, label_true) 200 | aggregate_classify_loss += classify_loss.item() 201 | 202 | # calculate content loss (h_relu_2_2) 203 | recon = y_c_features[1] 204 | recon_hat = y_hat_features[1] 205 | content_loss = CONTENT_WEIGHT*loss_mse(recon_hat, recon) 206 | aggregate_content_loss += content_loss.item() 207 | 208 | # calculate total variation regularization (anisotropic version) 209 | # https://www.wikiwand.com/en/Total_variation_denoising 210 | if args.range: 211 | diff_i = 0.0 212 | diff_j = 0.0 213 | for y_h in y_hat: 214 | diff_i += torch.sum(torch.abs(y_h[:, :, :, 1:] - y_h[:, :, :, :-1])) 215 | diff_j += torch.sum(torch.abs(y_h[:, :, 1:, :] - y_h[:, :, :-1, :])) 216 | else: 217 | diff_i = torch.sum(torch.abs(y_hat[:, :, :, 1:] - y_hat[:, :, :, :-1])) 218 | diff_j = torch.sum(torch.abs(y_hat[:, :, 1:, :] - y_hat[:, :, :-1, :])) 219 | tv_loss = args.tv_weight*(diff_i + diff_j) 220 | aggregate_tv_loss += tv_loss.item() 221 | 222 | # total loss 223 | total_loss = content_loss + tv_loss 224 | 225 | # backprop 226 | total_loss.backward() 227 | optimizer.step() 228 | 229 | # print out status message 230 | if (batch_num + 1) % 100 == 0: 231 | status = f"{time.ctime()} Epoch {e + 1}: " \ 232 | f"[{img_count}/{len(train_dataset)}] Batch:[{batch_num+1}] " \ 233 | f"agg_content: {aggregate_content_loss/(batch_num+1.0):.6f} " \ 234 | f"agg_class: {aggregate_classify_loss / (batch_num + 1.0):.6f} " \ 235 | f"agg_tv: {aggregate_tv_loss/(batch_num+1.0):.6f} " \ 236 | f"content: {content_loss:.6f} class: {classify_loss:.6f} tv: {tv_loss:.6f} " 237 | print(status) 238 | 239 | if ((batch_num + 1) % 1000 == 0) and visualize is not None: 240 | image_transformer.eval() 241 | 242 | if not os.path.exists("visualization"): 243 | os.makedirs("visualization") 244 | if not 
os.path.exists("visualization/%s" %log_id): 245 | os.makedirs("visualization/%s" %log_id) 246 | 247 | for img_index in test_indices: 248 | test_ds = test_dataset.dataset[int(img_index)] 249 | doppler_test = torch.from_numpy(test_ds.doppler_real.values[None, None, :, :]) 250 | 251 | plt_path = f"visualization/{log_id}/" \ 252 | f"{test_ds.activity}_{test_ds.date.replace(':','-')}_e{e+1}_b{batch_num+1}.png" 253 | 254 | x_test = Variable(doppler_test, requires_grad=False).type(dtype) 255 | titles = ("Real data", "Synthetic data", "Generated data") 256 | if args.range: 257 | range_test = torch.from_numpy(test_ds.range_real.values[None, None, :, :]) 258 | x_test = [Variable(range_test, requires_grad=False).type(dtype), x_test] 259 | range_hat, doppler_hat = image_transformer(x_test) 260 | range_hat = range_hat.cpu().detach().numpy() 261 | doppler_hat = doppler_hat.cpu().detach().numpy() 262 | test_ds["range_gen"] = (['time', 'range'], np.squeeze(range_hat), {"units": "dB"}) 263 | test_ds["doppler_gen"] = (['time', 'doppler'], np.squeeze(doppler_hat), {"units": "dB"}) 264 | range_ds = test_ds[["range_real", "range_synth", "range_gen"]] 265 | doppler_ds = test_ds[["doppler_real", "doppler_synth", "doppler_gen"]] 266 | fig, axes = plt.subplots(3, 2, figsize=(11, 6)) 267 | spec_plot(range_ds, axes=[ax[0] for ax in axes], vmin=-40, vmax=0, add_colorbar=False) 268 | spec_plot(doppler_ds, axes=[ax[1] for ax in axes], vmin=-40, vmax=0, add_colorbar=False) 269 | for ax_pair, title in zip(axes, titles): 270 | if title != titles[-1]: 271 | for ax in ax_pair: 272 | ax.axes.get_xaxis().set_visible(False) 273 | ax_pair[0].set_title(title) 274 | cbar = fig.colorbar(axes[0][0].get_images()[0], ax=axes, orientation='vertical') 275 | cbar.set_label('Amplitude [dB]') 276 | else: 277 | y_hat_test = image_transformer(x_test).cpu().detach().numpy() 278 | test_ds["doppler_gen"] = (['time', 'doppler'], np.squeeze(y_hat_test), {"units": "dB"}) 279 | spec_plot(test_ds, vmin=-40, vmax=0, cbar_global="Amplitude [dB]") 280 | 281 | axes = plt.gcf().axes 282 | for ax, title in zip(axes, titles): 283 | if title != titles[-1]: 284 | ax.axes.get_xaxis().set_visible(False) 285 | ax.set_title(title) 286 | 287 | plt.savefig(plt_path) 288 | plt.close() 289 | 290 | print("images saved") 291 | image_transformer.train() 292 | 293 | loss_logs.append({'content': aggregate_content_loss/(batch_num+1.0), 294 | 'class': aggregate_classify_loss/(batch_num+1.0), 295 | 'tv': aggregate_tv_loss/(batch_num+1.0)}) 296 | 297 | # save model 298 | image_transformer.eval() 299 | 300 | if use_cuda: 301 | image_transformer.cpu() 302 | 303 | with open(f"log/{args.log}_loss.json", "w") as wf: 304 | json.dump(loss_logs, wf, indent=4) 305 | 306 | if args.plot: 307 | content_loss = [log['content'] for log in loss_logs] 308 | class_loss = [log['class'] for log in loss_logs] 309 | fig, [ax_content, ax_class] = plt.subplots(2, 1) 310 | 311 | ax_content.plot(content_loss) 312 | ax_content.set_xlabel("Epochs") 313 | ax_content.set_ylabel("Content loss") 314 | 315 | ax_class.plot(class_loss) 316 | ax_class.set_xlabel("Epochs") 317 | ax_class.set_ylabel("Classification loss") 318 | 319 | plt.savefig(f"log/{args.log}.png") 320 | 321 | filename = "models/" + args.log + ".model" 322 | log_dir = os.path.dirname(filename) 323 | if not os.path.exists(log_dir): 324 | os.makedirs(log_dir) 325 | torch.save(image_transformer.state_dict(), filename) 326 | 327 | if use_cuda: 328 | image_transformer.cuda() 329 | 330 | 331 | def train_classify(args): 332 | # GPU enabling 333 | 
if args.gpu is None: 334 | use_cuda = False 335 | in_type = torch.FloatTensor 336 | out_type = torch.LongTensor 337 | print("No GPU training") 338 | else: 339 | use_cuda = True 340 | in_type = torch.cuda.FloatTensor 341 | out_type = torch.cuda.LongTensor 342 | torch.cuda.set_device(args.gpu) 343 | print("Current device: %d" % torch.cuda.current_device()) 344 | 345 | # get training dataset 346 | config = args.config 347 | data_path = args.dataset 348 | print(f"Training with configurations {', '.join(config)} from the dataset under {data_path}") 349 | 350 | epochs = args.epochs 351 | 352 | num_workers = 4 353 | 354 | if args.no_split: 355 | train_load = 'fake-load' 356 | split = None 357 | elif args.recordings is None: 358 | split = 'single' 359 | train_load = 0.8 360 | else: 361 | split = 'no-cut' 362 | train_load = 0.5 363 | 364 | recordings = open_recordings(config, data_path, load=True, range_length=range_length, doppler_length=doppler_length) 365 | # Merge recordings from all configs 366 | train_recordings = [r for recs in recordings.values() for r in recs] 367 | if not args.range: # drop range spectrograms 368 | train_recordings = [r.drop_vars("range_spect") for r in train_recordings] 369 | 370 | slice_kwargs = dict(spec_length=time_length, stride=hop_length, train_load=train_load, copy_split=0) 371 | loader_kwargs = dict(batch_size=BATCH_CLASSIFY, shuffle=True, num_workers=num_workers, pin_memory=True) 372 | 373 | print("Preloading datasets...") 374 | if args.no_split: 375 | train_dataset = slice_datasets(train_recordings, split=split, **slice_kwargs) 376 | test_dataset = train_dataset 377 | tgt_segments = {} 378 | elif args.recordings is None or args.recordings == 'first': 379 | [train_dataset, test_dataset], tgt_segments = slice_datasets(train_recordings, split=split, **slice_kwargs) 380 | elif args.recordings == 'last': 381 | [test_dataset, train_dataset], tgt_segments = slice_datasets(train_recordings, split=split, **slice_kwargs) 382 | else: 383 | raise ValueError(f"Unexpected recordings argument {args.recordings}") 384 | with open(f"log/{args.log}_segments.json", "w") as wf: 385 | json.dump(tgt_segments, wf, indent=4) 386 | 387 | train_loader = DataLoader(train_dataset, **loader_kwargs) 388 | test_loader = DataLoader(test_dataset, **loader_kwargs) 389 | 390 | class_num = train_dataset.class_num 391 | input_shapes = train_dataset.feature_shapes 392 | print(f"Number of classes: {class_num}") 393 | print(f"Feature shapes: {input_shapes}\n") 394 | 395 | # define network 396 | c_net = RDNet(input_shapes=input_shapes, class_num=class_num).type(in_type) 397 | 398 | optimizer = Adam(c_net.parameters(), LEARNING_RATE) 399 | criterion = nn.CrossEntropyLoss() 400 | 401 | loss_logs = [] 402 | 403 | early_stop_thresh = .96 404 | 405 | for e in range(epochs): 406 | 407 | # track values for... 
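        # ...per-epoch statistics: processed sample count, cumulative training loss, and train/test accuracy.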
408 | img_count = 0 409 | train_loss = 0.0 410 | train_acc = 0.0 411 | test_acc = 0.0 412 | batch_num = 0 413 | 414 | # train network 415 | c_net.train() 416 | for batch_num, batch in enumerate(train_loader): 417 | feature_batch, label_batch = batch 418 | if args.range: 419 | x = [Variable(feat).type(in_type) for feat in feature_batch] 420 | else: 421 | x = Variable(feature_batch[0]).type(in_type) 422 | 423 | img_batch_read = len(label_batch) 424 | img_count += img_batch_read 425 | 426 | # zero out gradients 427 | optimizer.zero_grad() 428 | 429 | # input batch to classifier network 430 | y_true = Variable(label_batch).type(out_type) 431 | y_pred = c_net(x) 432 | 433 | loss = criterion(y_pred, y_true) 434 | loss.backward() 435 | optimizer.step() 436 | 437 | # print statistics 438 | train_loss += loss.item() 439 | train_acc += evaluate(c_net, [batch]).item() * img_batch_read 440 | if (batch_num + 1) % 100 == 0: 441 | test_acc = evaluate(c_net, test_loader).item() 442 | status = f"{time.ctime()} Epoch {e + 1}: " \ 443 | f"[{img_count}/{len(train_dataset)}] Batch:[{batch_num + 1}] " \ 444 | f"train_loss: {train_loss / (batch_num + 1.0):.6f} " \ 445 | f"train_acc: {train_acc / img_count:.6f} test_acc: {test_acc:.6f}" 446 | print(status) 447 | 448 | loss_logs.append({'train_loss': train_loss / (batch_num + 1.0), 449 | 'train_acc': train_acc / img_count, 'test_acc': test_acc}) 450 | 451 | if train_acc / img_count > early_stop_thresh: 452 | print("***Early stopping training***") 453 | break 454 | 455 | # save model 456 | c_net.eval() 457 | 458 | if use_cuda: 459 | c_net.cpu() 460 | 461 | with open(f"log/{args.log}_loss.json", "w") as wf: 462 | json.dump(loss_logs, wf, indent=4) 463 | 464 | if args.plot: 465 | loss = [] 466 | train_acc = [] 467 | test_acc = [] 468 | for log in loss_logs: 469 | loss.append(log['train_loss']) 470 | train_acc.append(log['train_acc']) 471 | test_acc.append(log['test_acc']) 472 | fig, [ax_loss, ax_acc] = plt.subplots(2, 1) 473 | ax_loss.plot(loss) 474 | ax_loss.set_xlabel("Epochs") 475 | ax_loss.set_ylabel("Train loss") 476 | 477 | ax_acc.plot(train_acc) 478 | ax_acc.plot(test_acc) 479 | ax_acc.set_xlabel("Epochs") 480 | ax_acc.set_ylabel("Accuracy") 481 | ax_acc.legend(["train", "test"]) 482 | plt.savefig(f"log/{args.log}.png") 483 | 484 | filename = "models/" + args.log + ".model" 485 | log_dir = os.path.dirname(filename) 486 | if not os.path.exists(log_dir): 487 | os.makedirs(log_dir) 488 | torch.save(c_net.state_dict(), filename) 489 | 490 | if use_cuda: 491 | c_net.cuda() 492 | 493 | return loss_logs[-1]["train_acc"] 494 | 495 | 496 | def test(args): 497 | # GPU enabling 498 | if args.gpu is None: 499 | dtype = torch.FloatTensor 500 | print("No GPU in use") 501 | else: 502 | dtype = torch.cuda.FloatTensor 503 | torch.cuda.set_device(args.gpu) 504 | print("Current device: %d" % torch.cuda.current_device()) 505 | 506 | # get training dataset 507 | config = args.config 508 | data_path = args.dataset 509 | 510 | train_load = 0.5 511 | num_workers = 4 512 | 513 | recordings = open_recordings(config, data_path, load=True, range_length=range_length, doppler_length=doppler_length) 514 | # Merge recordings from all configs 515 | recordings = [r for recs in recordings.values() for r in recs] 516 | if not args.range: # drop range spectrograms 517 | recordings = [r.drop_vars("range_spect") for r in recordings] 518 | 519 | slice_kwargs = dict(spec_length=time_length, stride=hop_length, train_load=train_load) 520 | loader_kwargs = dict(batch_size=BATCH_TRANSFER, shuffle=False, 
num_workers=num_workers, pin_memory=True)
521 | 
522 |     print("Preloading datasets...")
523 |     if args.recordings is None:
524 |         if args.segments is None:
525 |             test_dataset = slice_datasets(recordings, split=None, **slice_kwargs)
526 |         else:
527 |             with open(args.segments, "r") as f:
528 |                 segments = json.load(f)
529 |             _, test_dataset = slice_datasets(recordings, split=segments, **slice_kwargs)
530 |     elif args.recordings == 'first':
531 |         [test_dataset, _], tgt_segments = slice_datasets(recordings, split='no-cut', **slice_kwargs)
532 |     elif args.recordings == 'last':
533 |         [_, test_dataset], tgt_segments = slice_datasets(recordings, split='no-cut', **slice_kwargs)
534 |     else:
535 |         raise ValueError(f"Unrecognized recordings option {args.recordings}")
536 | 
537 |     test_loader = DataLoader(test_dataset, **loader_kwargs)
538 | 
539 |     class_num = test_dataset.class_num
540 |     input_shapes = test_dataset.feature_shapes
541 |     print(f"Number of classes: {class_num}")
542 |     print(f"Feature shapes: {input_shapes}\n")
543 | 
544 |     # load network including transformer and classifier
545 |     if args.transformer is None:
546 |         trans_c_net = RDNet(input_shapes=input_shapes, class_num=class_num).type(dtype)
547 |         trans_c_net.load_state_dict(torch.load(args.classifier))
548 |     else:
549 |         trans_c_net = RACPIT(trans_path=args.transformer, model_path=args.classifier,
550 |                              input_shapes=input_shapes, class_num=class_num).type(dtype)
551 |     trans_c_net.eval()
552 | 
553 |     accuracy, predict_info = evaluate(trans_c_net, test_loader, predict_info=True)
554 |     accuracy = accuracy.item()
555 |     print(f"{accuracy:.6f} accuracy on configurations {', '.join(config)} from {data_path}")
556 | 
557 |     true_labels = predict_info["real"]
558 |     predictions = predict_info["predict"]
559 |     correct = predict_info["correct"].cpu().detach().numpy().astype(bool)
560 |     misclassified = np.nonzero(np.logical_not(correct))[0]
561 | 
562 |     confusion_matrix(predictions, true_labels, nb_classes=class_num)
563 | 
564 |     visualize = args.visualize
565 | 
566 |     if visualize <= 0:
567 |         return accuracy
568 | 
569 |     if not os.path.exists("visualization/%s" % args.log):
570 |         os.makedirs("visualization/%s" % args.log)
571 | 
572 |     true_labels = true_labels.cpu().detach().numpy().astype(int)
573 |     predictions = predictions.cpu().detach().numpy().astype(int)
574 | 
575 |     test_indices = np.random.choice(misclassified, size=visualize, replace=False)
576 |     activities = test_dataset.attrs['activities']
577 | 
578 |     for img_index in test_indices:
579 |         test_ds = test_dataset.dataset[int(img_index)]
580 |         doppler_test = torch.from_numpy(test_ds.doppler_spect.values[None, None, :, :])
581 | 
582 |         true_activity = activities[true_labels[img_index]]
583 |         pred_activity = activities[predictions[img_index]]
584 | 
585 |         assert true_activity != pred_activity, "Prediction is correct"
586 |         assert true_activity == test_ds.activity, "The true activity does not coincide with the embedded activity"
587 | 
588 |         plt_path = f"visualization/{args.log}/" \
589 |                    f"{true_activity}_{test_ds.date.replace(':', '-')}_{pred_activity}.png"
590 | 
591 |         x_test = Variable(doppler_test, requires_grad=False).type(dtype)
592 |         titles = ("Real data", f"Generated data, classified as {pred_activity}")
593 |         if args.range:
594 |             range_test = torch.from_numpy(test_ds.range_spect.values[None, None, :, :])
595 |             x_test = [Variable(range_test, requires_grad=False).type(dtype), x_test]
596 |             range_hat, doppler_hat = trans_c_net.transformer(x_test)
597 |             range_hat = range_hat.cpu().detach().numpy()
598 | 
doppler_hat = doppler_hat.cpu().detach().numpy() 599 | test_ds["range_gen"] = (['time', 'range'], np.squeeze(range_hat), {"units": "dB"}) 600 | test_ds["doppler_gen"] = (['time', 'doppler'], np.squeeze(doppler_hat), {"units": "dB"}) 601 | range_ds = test_ds[["range_spect", "range_gen"]] 602 | doppler_ds = test_ds[["doppler_spect", "doppler_gen"]] 603 | fig, axes = plt.subplots(2, 2, figsize=(11, 6)) 604 | spec_plot(range_ds, axes=[ax[0] for ax in axes], vmin=-40, vmax=0, add_colorbar=False) 605 | spec_plot(doppler_ds, axes=[ax[1] for ax in axes], vmin=-40, vmax=0, add_colorbar=False) 606 | for ax_pair, title in zip(axes, titles): 607 | if title != titles[-1]: 608 | for ax in ax_pair: 609 | ax.axes.get_xaxis().set_visible(False) 610 | ax_pair[0].set_title(title) 611 | cbar = fig.colorbar(axes[0][0].get_images()[0], ax=axes, orientation='vertical') 612 | cbar.set_label('Amplitude [dB]') 613 | else: 614 | y_hat_test = trans_c_net.transformer(x_test).cpu().detach().numpy() 615 | test_ds["doppler_gen"] = (['time', 'doppler'], np.squeeze(y_hat_test), {"units": "dB"}) 616 | spec_plot(test_ds, vmin=-40, vmax=0, cbar_global="Amplitude [dB]") 617 | 618 | axes = plt.gcf().axes 619 | for ax, title in zip(axes, titles): 620 | if title != titles[-1]: 621 | ax.axes.get_xaxis().set_visible(False) 622 | ax.set_title(title) 623 | 624 | plt.savefig(plt_path) 625 | plt.close() 626 | 627 | return accuracy 628 | 629 | 630 | def slice_datasets(recordings, spec_length, stride, train_load=0.8, split=None, copy_split=0): 631 | if copy_split > 1: 632 | effective_len = len(recordings) // copy_split 633 | else: 634 | effective_len = len(recordings) 635 | if split is None: 636 | slices = train_test_slice(recordings, spec_length, stride, train_load, split=split) 637 | rd_dataset = RadarDataset(recordings, slices=slices, ignore_dims=ignore_dims) 638 | return rd_dataset 639 | elif isinstance(split, dict): 640 | slices = train_test_slice(recordings[:effective_len], spec_length, stride, train_load, verbose=False, 641 | split=split, return_segments=False) 642 | if copy_split > 1: 643 | slices = [copy_slices(sl, effective_len, copy_split) for sl in slices] 644 | rd_datasets = [RadarDataset(recordings, slices=s, ignore_dims=ignore_dims) for s in slices] 645 | # assert set(index for index, sl in slices[0]) == set(range(len(recordings))), \ 646 | # "Slices do not include all recordings" 647 | return rd_datasets 648 | else: 649 | slices = train_test_slice(recordings[:effective_len], spec_length, stride, train_load, verbose=False, 650 | split=split, return_segments=True) 651 | segments = slices.pop(-1) 652 | if copy_split > 1: 653 | slices = [copy_slices(sl, effective_len, copy_split) for sl in slices] 654 | rd_datasets = [RadarDataset(recordings, slices=s, ignore_dims=ignore_dims) for s in slices] 655 | # assert set(index for index, sl in slices[0]) == set(range(len(recordings))), \ 656 | return rd_datasets, segments 657 | 658 | 659 | def copy_slices(slices, num_recs, repeat): 660 | new_slices = [] 661 | for n in range(repeat): 662 | new_slices += [(index + num_recs * n, sl) for index, sl in slices] 663 | return new_slices 664 | 665 | 666 | # ============== eval 667 | def evaluate(model_instance, input_loader, gpu=True, predict_info=False): 668 | ori_train_state = model_instance.training 669 | model_instance.eval() 670 | first_test = True 671 | all_probs, all_labels = None, None 672 | 673 | for data in input_loader: 674 | inputs = data[0] 675 | labels = data[1] 676 | if gpu: 677 | inputs = [inp.cuda() for inp in inputs] 678 | 
labels = labels.cuda() 679 | 680 | probabilities = model_instance.predict(inputs) 681 | 682 | probabilities = probabilities.data.float() 683 | labels = labels.data.float() 684 | if first_test: 685 | all_probs = probabilities 686 | all_labels = labels 687 | first_test = False 688 | else: 689 | all_probs = torch.cat((all_probs, probabilities), 0) 690 | all_labels = torch.cat((all_labels, labels), 0) 691 | 692 | _, predict = torch.max(all_probs, 1) 693 | predict = torch.squeeze(predict).float() 694 | correct = predict == all_labels 695 | accuracy = torch.sum(correct) / float(all_labels.size()[0]) 696 | 697 | model_instance.train(ori_train_state) 698 | 699 | if predict_info: 700 | predictions = {"predict": predict, "real": all_labels, "correct": correct} 701 | return accuracy, predictions 702 | else: 703 | return accuracy 704 | 705 | 706 | def confusion_matrix(predicted, true_label, nb_classes): 707 | cm = torch.zeros(nb_classes, nb_classes) 708 | with torch.no_grad(): 709 | for t, p in zip(true_label.view(-1), predicted.view(-1)): 710 | cm[t.long(), p.long()] += 1 711 | print(cm) 712 | 713 | 714 | def save_params(parameters): 715 | log_file = f"log/{parameters['log']}_params.json" 716 | log_dir = os.path.dirname(log_file) 717 | if not os.path.exists(log_dir): 718 | os.makedirs(log_dir) 719 | with open(log_file, "w") as wf: 720 | json.dump(parameters, wf, indent=4) 721 | 722 | 723 | def main(): 724 | parser = argparse.ArgumentParser(description='style transfer in pytorch') 725 | parser.add_argument("--log", type=str, default=None, help="ID to mark output files and logs. Default to timestamp") 726 | subparsers = parser.add_subparsers(title="subcommands", dest="subcommand") 727 | 728 | train_parser = subparsers.add_parser("train-transfer", help="train a model to do style transfer") 729 | train_parser.add_argument("--plot", action='store_true', help="Plot and save a training report") 730 | train_parser.add_argument("--gpu", type=int, default=None, help="ID of GPU to be used") 731 | train_parser.add_argument("--epochs", type=int, default=EPOCHS, help="Number of epochs for training") 732 | train_parser.add_argument("--tv-weight", type=float, default=TV_WEIGHT, help="Weight of TV regularization") 733 | train_parser.add_argument("--model", type=str, default=None, help="Path to a saved model to use " 734 | "for perceptual loss. 
VGG16 as default.") 735 | train_parser.add_argument("--recordings", type=str, default=None, help="Select to use 'first' or 'last' recordings") 736 | train_parser.add_argument("--segments", type=str, default=None, help="path to a segment file") 737 | train_parser.add_argument("--visualize", type=int, default=None, help="Set to 1 if you want to visualize training") 738 | train_parser.add_argument("--input", type=str, default=REAL_PATH, help="Path to input training dataset") 739 | train_parser.add_argument("--output", type=str, default=SYNTH_PATH, help="Path to output training dataset") 740 | train_parser.add_argument("--config", type=str, nargs='*', default=["F"], help="Radar configurations to train with") 741 | train_parser.add_argument("--range", action='store_true', help="Use range information alongside doppler") 742 | train_parser.add_argument("--classes", type=int, nargs='*', default=None, 743 | help="Classes to train with, default to all") 744 | 745 | classify_parser = subparsers.add_parser("train-classify", help="train a model to classify human activity") 746 | classify_parser.add_argument("--plot", action='store_true', help="Plot and save a training report") 747 | classify_parser.add_argument("--gpu", type=int, default=None, help="ID of GPU to be used") 748 | classify_parser.add_argument("--epochs", type=int, default=EPOCHS, help="Number of epochs for training") 749 | classify_parser.add_argument("--min-acc", type=float, default=0.0, help="Retrain until min. accuracy is reached") 750 | classify_parser.add_argument("--recordings", type=str, default=None, 751 | help="Select to use 'first' or 'last' recordings") 752 | classify_parser.add_argument("--dataset", type=str, default=SYNTH_PATH, help="Path to training dataset") 753 | classify_parser.add_argument("--config", type=str, nargs='*', default=["E", "F"], 754 | help="Radar configurations to train with") 755 | classify_parser.add_argument("--range", action='store_true', help="Use range information alongside doppler") 756 | classify_parser.add_argument("--no-split", action='store_true', help="Do not split data into test/train") 757 | 758 | test_parser = subparsers.add_parser("test", help="test a model to apply human activity classification") 759 | test_parser.add_argument("--gpu", type=int, default=None, help="ID of GPU to be used") 760 | test_parser.add_argument("--visualize", type=int, default=0, 761 | help="Number of misclassified spectrograms to show") 762 | test_parser.add_argument("--transformer", type=str, default=None, help="Path to a saved model to use " 763 | "for image transform") 764 | test_parser.add_argument("--classifier", type=str, required=True, help="Path to a saved model to use " 765 | "for classification.") 766 | test_parser.add_argument("--recordings", type=str, default=None, help="Select to use 'first' or 'last' recordings") 767 | test_parser.add_argument("--segments", type=str, default=None, help="path to a segment file") 768 | test_parser.add_argument("--dataset", type=str, default=REAL_PATH, help="Path to the dataset " 769 | "to feed the transformer") 770 | test_parser.add_argument("--config", type=str, nargs='*', default=["E"], help="Radar configurations to test with") 771 | test_parser.add_argument("--range", action='store_true', help="Use range information alongside doppler") 772 | 773 | args = parser.parse_args() 774 | 775 | params = {"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S")} 776 | print(f"Process started at {params['timestamp']}") 777 | if args.log is None: 778 | args.log = 
params["timestamp"].replace(':', '') 779 | params.update(vars(args)) 780 | save_params(params) 781 | 782 | # command 783 | if args.subcommand == "train-transfer": 784 | print("Training image transfer!") 785 | train_transfer(args) 786 | elif args.subcommand == "train-classify": 787 | print("Training classifier!") 788 | train_classify(args) 789 | acc = train_classify(args) 790 | while acc < args.min_acc: 791 | acc = train_classify(args) 792 | elif args.subcommand == "test": 793 | print("Testing!") 794 | acc = test(args) 795 | params["accuracy"] = acc 796 | save_params(params) 797 | else: 798 | print("invalid command") 799 | 800 | 801 | if __name__ == '__main__': 802 | main() 803 | --------------------------------------------------------------------------------